|
4 | 4 |
|
5 | 5 | import torch
|
6 | 6 |
|
7 |
| -from .common import dtype_dict, use_torch_compile |
| 7 | +from .common import dtype_dict, compile_func |
8 | 8 | from .packed_int import pack_int_symetric, unpack_int_symetric, pack_int_asymetric, unpack_int_asymetric
|
9 | 9 |
|
10 | 10 |
|
@@ -226,21 +226,11 @@ def forward(self, weight, skip_quantized_matmul=False, **kwargs): # pylint: disa
|
226 | 226 | }
|
227 | 227 |
|
228 | 228 |
|
229 |
| -if use_torch_compile: |
230 |
| - dequantize_asymmetric_compiled = torch.compile(dequantize_asymmetric, fullgraph=True, dynamic=False) |
231 |
| - dequantize_symmetric_compiled = torch.compile(dequantize_symmetric, fullgraph=True, dynamic=False) |
232 |
| - dequantize_packed_int_asymmetric_compiled = torch.compile(dequantize_packed_int_asymmetric, fullgraph=True, dynamic=False) |
233 |
| - dequantize_packed_int_symmetric_compiled = torch.compile(dequantize_packed_int_symmetric, fullgraph=True, dynamic=False) |
234 |
| - re_quantize_matmul_asymmetric_compiled = torch.compile(re_quantize_matmul_asymmetric, fullgraph=True, dynamic=False) |
235 |
| - re_quantize_matmul_symmetric_compiled = torch.compile(re_quantize_matmul_symmetric, fullgraph=True, dynamic=False) |
236 |
| - re_quantize_matmul_packed_int_asymmetric_compiled = torch.compile(re_quantize_matmul_packed_int_asymmetric, fullgraph=True, dynamic=False) |
237 |
| - re_quantize_matmul_packed_int_symmetric_compiled = torch.compile(re_quantize_matmul_packed_int_symmetric, fullgraph=True, dynamic=False) |
238 |
| -else: |
239 |
| - dequantize_asymmetric_compiled = dequantize_asymmetric |
240 |
| - dequantize_symmetric_compiled = dequantize_symmetric |
241 |
| - dequantize_packed_int_asymmetric_compiled = dequantize_packed_int_asymmetric |
242 |
| - dequantize_packed_int_symmetric_compiled = dequantize_packed_int_symmetric |
243 |
| - re_quantize_matmul_asymmetric_compiled = re_quantize_matmul_asymmetric |
244 |
| - re_quantize_matmul_symmetric_compiled = re_quantize_matmul_symmetric |
245 |
| - re_quantize_matmul_packed_int_asymmetric_compiled = re_quantize_matmul_packed_int_asymmetric |
246 |
| - re_quantize_matmul_packed_int_symmetric_compiled = re_quantize_matmul_packed_int_symmetric |
| 229 | +dequantize_asymmetric_compiled = compile_func(dequantize_asymmetric) |
| 230 | +dequantize_symmetric_compiled = compile_func(dequantize_symmetric) |
| 231 | +dequantize_packed_int_asymmetric_compiled = compile_func(dequantize_packed_int_asymmetric) |
| 232 | +dequantize_packed_int_symmetric_compiled = compile_func(dequantize_packed_int_symmetric) |
| 233 | +re_quantize_matmul_asymmetric_compiled = compile_func(re_quantize_matmul_asymmetric) |
| 234 | +re_quantize_matmul_symmetric_compiled = compile_func(re_quantize_matmul_symmetric) |
| 235 | +re_quantize_matmul_packed_int_asymmetric_compiled = compile_func(re_quantize_matmul_packed_int_asymmetric) |
| 236 | +re_quantize_matmul_packed_int_symmetric_compiled = compile_func(re_quantize_matmul_packed_int_symmetric) |
0 commit comments