Commit dc3fa73

mengluy0125 authored and facebook-github-bot committed
[Optimus][Auto-AC] Support activation quantization (#148380)
Summary:
We enable activation quantization in the forward pass, and users can customize the dtype they want to quantize to.

Test Plan:

# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/fc8469a3-54f7-425d-9b1f-e54840c0793a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/3377699989853946
Network: Up: 10KiB Down: 0B (reSessionID-ab248457-6ac0-4b72-96da-0d3c427e260a)
Executing actions. Remaining 0/3, 0.7s exec time total
Command: test. Finished 2 local
Time elapsed: 1:03.1s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### How to enable (you can override the dtype and clamp range; if nothing is given, the default is fp8)

```
post_grad_fusion_options={
    "activation_quantization_aten_pass": {
        "quant_type": torch.float8_e5m2,
        "clamp_min": -57344.0,
        "clamp_max": 57344.0,
    }
},
```

Differential Revision: D70522237
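For reference, here is a hedged end-to-end sketch of enabling the pass on an arbitrary model. The model, device, shapes, and bf16 dtype below are illustrative and not part of this commit; the config keys (`post_grad_fusion_options`, `activation_quantization_aten_pass`, `quant_type`, `clamp_min`, `clamp_max`) come from the summary and the unit test in this diff.

```python
# Sketch only: enable the activation-quantization post-grad pass and run a
# compiled forward/backward. Assumes a GPU build of PyTorch with fp8 support.
import torch
import torch._inductor.config

cfg = {
    "activation_quantization_aten_pass": {
        "quant_type": torch.float8_e5m2,  # dtype used for the saved activations
        "clamp_min": -57344.0,            # optional clamp range (fp8 e5m2 limits)
        "clamp_max": 57344.0,
    }
}

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
).to("cuda", dtype=torch.bfloat16)

with torch._inductor.config.patch(post_grad_fusion_options=cfg):
    compiled = torch.compile(model)
    out = compiled(torch.randn(16, 64, device="cuda", dtype=torch.bfloat16))
    out.sum().backward()  # backward consumes the activations saved by the forward
```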
1 parent 3d62e81 commit dc3fa73

File tree: 7 files changed (+387 / -6 lines)


test/inductor/test_quantization.py

Lines changed: 134 additions & 0 deletions
```python
# Owner(s): ["module: inductor"]

import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu

log = logging.getLogger(__name__)


class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class FeedforwardNN(torch.nn.Module):
    def __init__(self):
        super(FeedforwardNN, self).__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    f"None found with key1 and value1: {key1, ref_dict[key1]}, "
                    f"key2 and value2: {key2, res_dict[key2]}"
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    f"gradient mismatch for eager and compiled modules, "
                    f"with eager: {ref_dict[key1]} and compiled: {res_dict[key2]}"
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {"quant_type": torch.float8_e5m2}
        },
    )
    def test_activation_quantization_aten(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(counters["inductor"]["activation_quantization_aten_pass"], 3)
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(counters["inductor"]["activation_quantization_aten_pass"], 4)
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    run_tests()
```

torch/_functorch/partitioners.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -1994,6 +1994,14 @@ def classify_nodes(joint_module):
         joint_module, fw_module, bw_module, len(saved_sym_nodes)
     )
     bw_module = reordering_to_mimic_autograd_engine(bw_module)
+    # tag all activation nodes as quantized nodes; we can customize this later
+    for output in fw_module.graph.find_nodes(op="output"):
+        for node in output.args[0]:
+            if node.target in [torch.ops.aten.relu.default, torch.ops.aten.tanh.default]:
+                node.meta["saved_for_quantization"] = True
+    for placeholder in bw_module.graph.find_nodes(op="placeholder"):
+        if any(name in str(placeholder.target) for name in ["relu", "tanh", "sigmoid", "gelu"]):
+            placeholder.meta["saved_for_quantization"] = True

     if AOT_PARTITIONER_DEBUG:
         # Calculate sorted sizes of saved values
```
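The partitioner change tags forward-graph outputs produced by `relu`/`tanh` (and the matching backward placeholders) so a later pass knows those saved activations are candidates for quantization. Below is a small, self-contained illustration of that tagging idea; it is not code from this commit, and the function `f` and the use of `make_fx` are only for demonstration.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # Two activations whose results are graph outputs, mirroring the kind of
    # saved activations the partitioner change targets.
    relu_out = torch.relu(x)
    tanh_out = torch.tanh(relu_out)
    return relu_out, tanh_out


gm = make_fx(f)(torch.randn(4))

# Same loop shape as the partitioner change: mark activation outputs so a
# later pass can decide to store them in a lower-precision dtype.
for output in gm.graph.find_nodes(op="output"):
    for node in output.args[0]:
        if node.target in [torch.ops.aten.relu.default, torch.ops.aten.tanh.default]:
            node.meta["saved_for_quantization"] = True

tagged = [n.name for n in gm.graph.nodes if n.meta.get("saved_for_quantization")]
print(tagged)  # expected to list the relu and tanh nodes, e.g. ['relu', 'tanh']
```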

torch/_inductor/compile_fx.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -370,16 +370,22 @@ def _recursive_joint_graph_passes(gm: GraphModule) -> None:
     joint_graph_passes(gm)


-def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None:
+def _recursive_post_grad_passes(
+    gm: GraphModule,
+    is_inference: bool = False,
+    is_backward: bool = False,
+) -> None:
     with dynamo_timed(
         "_recursive_post_grad_passes",
         log_pt2_compile_event=True,
         dynamo_compile_column_us="post_grad_pass_time_us",
     ):
         for subgraph_name in _get_subgraph_names(gm):
             subgraph = getattr(gm, subgraph_name)
-            _recursive_post_grad_passes(subgraph, is_inference)
-        post_grad_passes(gm, is_inference)
+            _recursive_post_grad_passes(
+                subgraph, is_inference, is_backward)
+
+        post_grad_passes(gm, is_inference, is_backward)


 def split_const_gm(
@@ -982,7 +988,7 @@ def log_graph_runnable() -> str:
         # has some issues with memory in training
         cuda_context = get_cuda_device_context(gm)
         with cuda_context:
-            _recursive_post_grad_passes(gm, is_inference=is_inference)
+            _recursive_post_grad_passes(gm, is_inference=is_inference, is_backward=is_backward)
         V.debug.fx_graph_transformed(gm, example_inputs)
         post_grad_graphs_log.debug(
             "%s",
```

torch/_inductor/fx_passes/post_grad.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -72,7 +72,7 @@
 ]


-def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
+def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool, is_backward: bool):
     """
     Passes that run on after grad. This is called once on the forwards
     graph and once on the backwards graph.
@@ -132,6 +132,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         if pass_name in POST_GRAD_FUSIONS:
             continue
         pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
+        pattern_matcher_pass.is_backward = is_backward
         inductor_before_change = save_inductor_dict(
             [pattern_matcher_pass.pass_name]
         )
```
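Together with the `compile_fx.py` change above, this threads an `is_backward` flag from `compile_fx` down to each configured pattern-matcher pass, so the same pass pipeline can distinguish the forward graph from the backward graph. A hedged illustration of what the new attribute amounts to follows; this is not code from the commit, and `run_configured_pass` is a hypothetical helper.

```python
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass

# A stand-in pass object; in post_grad_passes the real pass comes from
# POST_GRAD_PATTERNS[pass_name], keyed by post_grad_fusion_options.
pm_pass = PatternMatcherPass(pass_name="activation_quantization_aten_pass")


def run_configured_pass(gm: torch.fx.GraphModule, is_backward: bool) -> int:
    # Mirrors the loop in post_grad_passes: stamp the flag so the registered
    # patterns can tell whether they are rewriting the forward or backward
    # graph, then apply the pass and return the number of matches.
    pm_pass.is_backward = is_backward
    return pm_pass.apply(gm.graph)
```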
