
Commit 059a135

mengluy0125 authored and facebook-github-bot committed
[Optimus][Auto-AC] Support activation quantization
Summary: We enable activation quantization in the forward pass, and users can customize the dtype they want to quantize to.

Test Plan:

# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/776d3911-bb86-4ac8-a527-540cf1510b9d
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4785074873051017
Network: Up: 4.3MiB, Down: 42MiB (reSessionID-fef7e727-68b1-4645-a519-5652854df38d)
Executing actions. Remaining 0/4, 6.7s total exec time.
Command: test. Finished 2 local.
Time elapsed: 3:11.5s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0.

# E2E

### How to enable (you can override the dtype; if nothing is given, the default is fp8)

```
post_grad_fusion_options={
    "activation_quantization_aten_pass": {"quant_type": torch.float8_e5m2}
},
```

Differential Revision: D70522237
1 parent 5887a2d commit 059a135
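
For reference, a minimal sketch of enabling the pass end to end. Only the `post_grad_fusion_options` config comes from this commit; the module, shapes, and device are illustrative and mirror the unit test below:

```python
# Minimal sketch, assuming a CUDA device and bf16 inputs as in the unit test below.
# The module and shapes are illustrative, not part of this commit.
import torch

@torch._inductor.config.patch(
    post_grad_fusion_options={
        # "quant_type" is optional; the pass defaults to torch.float8_e5m2
        "activation_quantization_aten_pass": {"quant_type": torch.float8_e5m2}
    },
)
def demo():
    mod = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
    mod = mod.to(device="cuda", dtype=torch.bfloat16)
    x = torch.rand(16, 10, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    out = torch.compile(mod)(x)
    out.sum().backward()  # saved activations cross the fw/bw boundary quantized
```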

File tree: 6 files changed (+221, −6 lines)


test/inductor/test_quantization.py

Lines changed: 81 additions & 0 deletions
```python
# Owner(s): ["module: inductor"]

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {"quant_type": torch.float8_e5m2}
        },
    )
    def test_activation_quantization_aten(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand((16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16),
            torch.rand((10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(counters["inductor"]["activation_quantization_aten_pass"], 2)
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    run_tests()
```
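
Why the test expects the counter to end at 2: both saved activations (the relu and tanh outputs) feed the forward graph's output node, so the forward pass inserts one quantize node per activation. A hedged before/after sketch of the compiled forward graph (node names are illustrative):

```python
# Illustrative forward graph, before the pass (names are made up):
#   relu = aten.relu.default(x1)
#   tanh = aten.tanh.default(relu)
#   mm   = aten.mm.default(tanh, x2)
#   return (mm, tanh, relu)              # activations saved for backward in bf16
#
# After quantize_activation_fw (counter increments once per activation, hence 2):
#   tanh_q = prims.convert_element_type.default(tanh, torch.float8_e5m2)
#   relu_q = prims.convert_element_type.default(relu, torch.float8_e5m2)
#   return (mm, tanh_q, relu_q)          # saved activations now stored as fp8
```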

torch/_inductor/compile_fx.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -362,16 +362,20 @@ def _recursive_joint_graph_passes(gm: GraphModule) -> None:
     joint_graph_passes(gm)
 
 
-def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None:
+def _recursive_post_grad_passes(
+    gm: GraphModule,
+    is_inference: bool = False,
+    is_backward: bool = False,
+) -> None:
     with dynamo_timed(
         "_recursive_post_grad_passes",
         log_pt2_compile_event=True,
         dynamo_compile_column_us="post_grad_pass_time_us",
     ):
         for subgraph_name in _get_subgraph_names(gm):
             subgraph = getattr(gm, subgraph_name)
-            _recursive_post_grad_passes(subgraph, is_inference)
-        post_grad_passes(gm, is_inference)
+            _recursive_post_grad_passes(subgraph, is_inference, is_backward)
+        post_grad_passes(gm, is_inference, is_backward)
 
 
 def split_const_gm(
@@ -990,7 +994,7 @@ def log_graph_runnable() -> str:
     # has some issues with memory in training
     cuda_context = get_cuda_device_context(gm)
     with cuda_context:
-        _recursive_post_grad_passes(gm, is_inference=is_inference)
+        _recursive_post_grad_passes(gm, is_inference=is_inference, is_backward=is_backward)
     V.debug.fx_graph_transformed(gm, example_inputs)
     post_grad_graphs_log.debug(
         "%s",
```

torch/_inductor/fx_passes/post_grad.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -70,7 +70,7 @@
 ]
 
 
-def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool, is_backward: bool):
     """
     Passes that run on after grad. This is called once on the forwards
     graph and once on the backwards graph.
@@ -130,6 +130,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         if pass_name in POST_GRAD_FUSIONS:
             continue
         pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
+        pattern_matcher_pass.is_backward = is_backward
         inductor_before_change = save_inductor_dict(
             [pattern_matcher_pass.pass_name]
         )
```
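
Together with the compile_fx.py change above, the plumbing is: `is_backward` flows from the partitioned forward/backward graphs into `post_grad_passes`, which stamps the flag onto each `PatternMatcherPass` so that per-pattern extra_checks can consult it at match time. A condensed sketch (paraphrased, not verbatim library code; `configured_passes` stands in for the real iteration):

```python
# Condensed control flow, paraphrasing the diffs above (not verbatim library code).
def post_grad_passes_sketch(gm, is_inference, is_backward):
    for pass_name in configured_passes:                 # stand-in for the real loop
        pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
        pattern_matcher_pass.is_backward = is_backward  # stamp the flag before matching
        pattern_matcher_pass.apply(gm.graph)            # extra_checks read the flag
```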

torch/_inductor/fx_passes/quantization.py

Lines changed: 113 additions & 1 deletion
```diff
@@ -11,12 +11,27 @@
 from torch._dynamo.utils import counters
 from torch.fx.experimental.symbolic_shapes import has_free_symbols
 from torch.fx.node import map_arg
+from torch.fx.passes.shape_prop import TensorMetadata
 
 from ..lowering import lowerings as L, require_channels_last
-from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
+from ..pattern_matcher import (
+    Arg,
+    CallFunction,
+    CallFunctionVarArgs,
+    filter_nodes,
+    is_backward_pattern,
+    KeywordArg,
+    ListOf,
+    Match,
+    MULTIPLE,
+    Placeholder,
+    register_graph_pattern,
+)
 from ..utils import pad_listlike
 from .freezing_patterns import register_freezing_graph_pattern
+from .group_batch_fusion import is_node_meta_valid
 from .post_grad import register_lowering_pattern
+from .split_cat import construct_pattern_matcher_pass
 
 
 aten = torch.ops.aten
@@ -3590,3 +3605,100 @@ def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
 
     graph_module.graph.lint()
     graph_module.recompile()
+
+
+activation_quantization_aten_pass = construct_pattern_matcher_pass(
+    "activation_quantization_aten_pass"
+)
+
+
+@register_graph_pattern(
+    CallFunctionVarArgs(
+        [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.tanh.default,
+            torch.ops.aten.sigmoid.default,
+            torch.ops.aten.gelu.default,
+        ],
+        users=MULTIPLE,
+    ),
+    pass_dict=activation_quantization_aten_pass,
+    extra_check=is_backward_pattern(activation_quantization_aten_pass, False),
+)
+def quantize_activation_fw(match: Match, *args, **kwargs):
+    graph = match.graph
+    activation_nodes = match.nodes
+    quant_type = torch._inductor.config.post_grad_fusion_options[
+        "activation_quantization_aten_pass"
+    ].get("quant_type", torch.float8_e5m2)
+    for activation_node in activation_nodes:
+        # check whether the activation flows into the graph output
+        users = list(activation_node.users.keys())
+        for user in users:
+            if user.op == "output":
+                if not is_node_meta_valid(activation_node):
+                    continue
+                # insert a quantization node right after the activation
+                with graph.inserting_after(activation_node):
+                    quant_activation_node = graph.call_function(
+                        torch.ops.prims.convert_element_type.default,
+                        args=(activation_node, quant_type),
+                    )
+                quant_activation_node.meta.update(activation_node.meta)
+                quant_activation_node.meta["val"] = quant_activation_node.meta[
+                    "val"
+                ].to(quant_type)
+                quant_activation_node.meta["tensor_meta"] = TensorMetadata(
+                    shape=quant_activation_node.meta["tensor_meta"].shape,
+                    dtype=quant_type,
+                    requires_grad=quant_activation_node.meta["tensor_meta"].requires_grad,
+                    stride=quant_activation_node.meta["tensor_meta"].stride,
+                    memory_format=quant_activation_node.meta["tensor_meta"].memory_format,
+                    is_quantized=quant_activation_node.meta["tensor_meta"].is_quantized,
+                    qparams=quant_activation_node.meta["tensor_meta"].qparams,
+                )
+                # only update the return node's args; leave all other users unchanged
+                user_updated_args = tuple(
+                    quant_activation_node if node == activation_node else node
+                    for node in user.args[0]
+                )
+                user.update_arg(0, user_updated_args)
+                if len(activation_node.users) == 0:
+                    graph.erase_node(activation_node)
+                counters["inductor"]["activation_quantization_aten_pass"] += 1
+                break
+
+
+@register_graph_pattern(
+    Placeholder(["tanh", "relu", "sigmoid", "gelu"], users=MULTIPLE),
+    pass_dict=activation_quantization_aten_pass,
+    extra_check=is_backward_pattern(activation_quantization_aten_pass, True),
+)
+def quantize_activation_bw(match: Match, *args, **kwargs):
+    graph = match.graph
+    inputs = match.nodes
+    quant_type = torch._inductor.config.post_grad_fusion_options[
+        "activation_quantization_aten_pass"
+    ].get("quant_type", torch.float8_e5m2)
+    for input in inputs:
+        if not is_node_meta_valid(input):
+            continue
+        # insert a dequantization node right after the placeholder
+        with graph.inserting_after(input):
+            dequant_activation_node = graph.call_function(
+                torch.ops.prims.convert_element_type.default,
+                args=(input, input.meta["val"].dtype),
+            )
+            input.replace_all_uses_with(dequant_activation_node)
+            # restore the dequant node's own input (it was rewired above as well)
+            dequant_activation_node.replace_input_with(dequant_activation_node, input)
+            dequant_activation_node.meta.update(input.meta)
+
+        # switch the input's meta to the quant type to stay in sync with the forward pass
+        input.meta["val"] = input.meta["val"].to(quant_type)
+        input.meta["tensor_meta"] = TensorMetadata(
+            shape=input.meta["tensor_meta"].shape,
+            dtype=quant_type,
+            requires_grad=input.meta["tensor_meta"].requires_grad,
+            stride=input.meta["tensor_meta"].stride,
+            memory_format=input.meta["tensor_meta"].memory_format,
+            is_quantized=input.meta["tensor_meta"].is_quantized,
+            qparams=input.meta["tensor_meta"].qparams,
+        )
+        counters["inductor"]["activation_quantization_aten_pass"] += 1
```

torch/_inductor/fx_passes/split_cat.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -71,6 +71,7 @@
     "pad_aten_mm_pass",
     "split_cat_aten_pass",
     "select_cat_aten_pass",
+    "activation_quantization_aten_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
```
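
For context, the registry above is how `construct_pattern_matcher_pass` (used in quantization.py) resolves a name to a `PatternMatcherPass` instance. A rough sketch of the presumed lookup, hedged since the helper's body is not part of this diff:

```python
# Presumed shape of the existing helper (not shown in this diff); the real
# implementation lives in split_cat.py and may differ in detail.
def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass:
    """Return the PatternMatcherPass registered under pass_name."""
    if pass_name in PRE_GRAD_PATTERNS:
        return PRE_GRAD_PATTERNS[pass_name]
    return POST_GRAD_PATTERNS[pass_name]
```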

torch/_inductor/pattern_matcher.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -778,6 +778,10 @@ class CallModuleVarArgs(_TargetExprVarArgs):
     op = "call_module"
 
 
+class Placeholder(_TargetExprVarArgs):
+    op = "placeholder"
+
+
 class ListOf(PatternExpr):
     """
     Matches a repeated pattern
@@ -1794,12 +1798,14 @@ class PatternMatcherPass:
     def __init__(
         self,
         pass_name: Optional[str] = None,
+        is_backward: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self.patterns: defaultdict[
             tuple[str, torch.fx.node.Target], list[PatternEntry]
         ] = defaultdict(list)
         self.pass_name = pass_name
+        self.is_backward = is_backward
 
         # For a particular generated pattern repr, store all of the str representations
         # of the graph used to generate them. Because we ignore certain patterns
@@ -2129,6 +2135,16 @@ def flag_check(match: Match) -> Any:
     return flag_check
 
 
+def is_backward_pattern(
+    pattern: PatternMatcherPass, backward: bool = True
+) -> Callable[[Match], Any]:
+    """Extra_check helper: match only when the pass is (or is not) running on the backward graph."""
+
+    def backward_check(match: Match) -> Any:
+        if backward:
+            return pattern.is_backward
+        return not pattern.is_backward
+
+    return backward_check
+
+
 def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
     class CopyGraph(Transformer):
         def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
```
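
A hedged usage sketch of the new helper; the pass name and op list are made up, but the shape mirrors the registrations in quantization.py above:

```python
# Illustrative registration; "my_aten_pass" and the op list are hypothetical.
# The name would need to be registered in split_cat.py for the lookup to succeed.
my_pass = construct_pattern_matcher_pass("my_aten_pass")

@register_graph_pattern(
    CallFunctionVarArgs([torch.ops.aten.relu.default], users=MULTIPLE),
    pass_dict=my_pass,
    extra_check=is_backward_pattern(my_pass, False),  # fire on the forward graph only
)
def my_forward_rewrite(match: Match, *args, **kwargs):
    ...  # post_grad_passes stamps my_pass.is_backward before the matcher runs
```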
