Commit e6715ba

Cleanup SDNQ compile
1 parent cd79f92 commit e6715ba

File tree

9 files changed: +28 additions, -37 deletions

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Subproject commit a33753321b914c6122df96d1dc0b5117d38af680

modules/sdnq/common.py

Lines changed: 6 additions & 0 deletions

@@ -2,6 +2,8 @@
 
 import os
 import torch
+from functools import partial
+
 from modules import shared
 
 torch_version = float(torch.__version__[:3])
@@ -42,3 +44,7 @@
 if use_torch_compile:
     torch._dynamo.config.cache_size_limit = max(8192, torch._dynamo.config.cache_size_limit)
     torch._dynamo.config.accumulated_recompile_limit = max(8192, torch._dynamo.config.accumulated_recompile_limit)
+    compile_func = partial(torch.compile, fullgraph=True, dynamic=False)
+else:
+    def compile_func(fn, **kwargs):
+        return fn
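
The net effect is a single compile toggle: when torch.compile is enabled, compile_func applies it with one shared configuration (fullgraph=True, dynamic=False); when it is disabled, compile_func degrades to an identity wrapper, so call sites never need their own if/else. A minimal standalone sketch of the pattern — the use_torch_compile flag here is a hypothetical stand-in for the one SDNQ derives from modules.shared:

from functools import partial

import torch

use_torch_compile = True  # hypothetical stand-in for the flag derived from modules.shared

if use_torch_compile:
    # one shared torch.compile configuration for every hot path
    compile_func = partial(torch.compile, fullgraph=True, dynamic=False)
else:
    # identity fallback: same call shape, no compilation
    def compile_func(fn, **kwargs):
        return fn

def scale_and_shift(x: torch.Tensor, scale: float, shift: float) -> torch.Tensor:
    return x * scale + shift

scale_and_shift = compile_func(scale_and_shift)
print(scale_and_shift(torch.ones(4), 2.0, 1.0))  # tensor([3., 3., 3., 3.])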

modules/sdnq/dequantizer.py

Lines changed: 9 additions & 19 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from .common import dtype_dict, use_torch_compile
+from .common import dtype_dict, compile_func
 from .packed_int import pack_int_symetric, unpack_int_symetric, pack_int_asymetric, unpack_int_asymetric
 
 
@@ -226,21 +226,11 @@ def forward(self, weight, skip_quantized_matmul=False, **kwargs): # pylint: disa
 }
 
 
-if use_torch_compile:
-    dequantize_asymmetric_compiled = torch.compile(dequantize_asymmetric, fullgraph=True, dynamic=False)
-    dequantize_symmetric_compiled = torch.compile(dequantize_symmetric, fullgraph=True, dynamic=False)
-    dequantize_packed_int_asymmetric_compiled = torch.compile(dequantize_packed_int_asymmetric, fullgraph=True, dynamic=False)
-    dequantize_packed_int_symmetric_compiled = torch.compile(dequantize_packed_int_symmetric, fullgraph=True, dynamic=False)
-    re_quantize_matmul_asymmetric_compiled = torch.compile(re_quantize_matmul_asymmetric, fullgraph=True, dynamic=False)
-    re_quantize_matmul_symmetric_compiled = torch.compile(re_quantize_matmul_symmetric, fullgraph=True, dynamic=False)
-    re_quantize_matmul_packed_int_asymmetric_compiled = torch.compile(re_quantize_matmul_packed_int_asymmetric, fullgraph=True, dynamic=False)
-    re_quantize_matmul_packed_int_symmetric_compiled = torch.compile(re_quantize_matmul_packed_int_symmetric, fullgraph=True, dynamic=False)
-else:
-    dequantize_asymmetric_compiled = dequantize_asymmetric
-    dequantize_symmetric_compiled = dequantize_symmetric
-    dequantize_packed_int_asymmetric_compiled = dequantize_packed_int_asymmetric
-    dequantize_packed_int_symmetric_compiled = dequantize_packed_int_symmetric
-    re_quantize_matmul_asymmetric_compiled = re_quantize_matmul_asymmetric
-    re_quantize_matmul_symmetric_compiled = re_quantize_matmul_symmetric
-    re_quantize_matmul_packed_int_asymmetric_compiled = re_quantize_matmul_packed_int_asymmetric
-    re_quantize_matmul_packed_int_symmetric_compiled = re_quantize_matmul_packed_int_symmetric
+dequantize_asymmetric_compiled = compile_func(dequantize_asymmetric)
+dequantize_symmetric_compiled = compile_func(dequantize_symmetric)
+dequantize_packed_int_asymmetric_compiled = compile_func(dequantize_packed_int_asymmetric)
+dequantize_packed_int_symmetric_compiled = compile_func(dequantize_packed_int_symmetric)
+re_quantize_matmul_asymmetric_compiled = compile_func(re_quantize_matmul_asymmetric)
+re_quantize_matmul_symmetric_compiled = compile_func(re_quantize_matmul_symmetric)
+re_quantize_matmul_packed_int_asymmetric_compiled = compile_func(re_quantize_matmul_packed_int_asymmetric)
+re_quantize_matmul_packed_int_symmetric_compiled = compile_func(re_quantize_matmul_packed_int_symmetric)
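
With compile_func centralized in common.py, the old eight-branch if/else collapses into one assignment per kernel. When compilation is disabled the wrapper returns its argument unchanged, so the *_compiled names are plain aliases of the eager functions. A small sketch of that equivalence, assuming the fallback branch of compile_func and a hypothetical demo function in place of the real kernel:

def compile_func(fn, **kwargs):  # fallback branch from common.py (compilation disabled)
    return fn

def dequantize_symmetric_demo(weight, scale):  # hypothetical stand-in for dequantize_symmetric
    return weight * scale

dequantize_symmetric_compiled = compile_func(dequantize_symmetric_demo)
assert dequantize_symmetric_compiled is dequantize_symmetric_demo  # no wrapper, zero overhead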

modules/sdnq/layers/conv/conv_fp8.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ..linear.linear_fp8 import quantize_fp8_matmul_input # noqa: TID252
 from .forward import get_conv_args, process_conv_input
 
@@ -68,5 +68,4 @@ def quantized_conv_forward_fp8_matmul(self, input) -> torch.FloatTensor:
     )
 
 
-if use_torch_compile:
-    conv_fp8_matmul = torch.compile(conv_fp8_matmul, fullgraph=True, dynamic=False)
+conv_fp8_matmul = compile_func(conv_fp8_matmul)

modules/sdnq/layers/conv/conv_fp8_tensorwise.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ...dequantizer import dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
 from ..linear.linear_fp8_tensorwise import quantize_fp8_matmul_input_tensorwise # noqa: TID252
 from .forward import get_conv_args, process_conv_input
@@ -63,5 +63,4 @@ def quantized_conv_forward_fp8_matmul_tensorwise(self, input) -> torch.FloatTensor:
     )
 
 
-if use_torch_compile:
-    conv_fp8_matmul_tensorwise = torch.compile(conv_fp8_matmul_tensorwise, fullgraph=True, dynamic=False)
+conv_fp8_matmul_tensorwise = compile_func(conv_fp8_matmul_tensorwise)

modules/sdnq/layers/conv/conv_int8.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ...packed_int import unpack_int_symetric # noqa: TID252
 from ...dequantizer import dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
 from ..linear.linear_int8 import quantize_int8_matmul_input # noqa: TID252
@@ -75,5 +75,4 @@ def quantized_conv_forward_int8_matmul(self, input) -> torch.FloatTensor:
     )
 
 
-if use_torch_compile:
-    conv_int8_matmul = torch.compile(conv_int8_matmul, fullgraph=True, dynamic=False)
+conv_int8_matmul = compile_func(conv_int8_matmul)

modules/sdnq/layers/linear/linear_fp8.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ...dequantizer import quantize_fp8 # noqa: TID252
 
 
@@ -34,5 +34,4 @@ def quantized_linear_forward_fp8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
     return fp8_matmul(input, self.weight, self.bias, self.sdnq_dequantizer.scale)
 
 
-if use_torch_compile:
-    fp8_matmul = torch.compile(fp8_matmul, fullgraph=True, dynamic=False)
+fp8_matmul = compile_func(fp8_matmul)

modules/sdnq/layers/linear/linear_fp8_tensorwise.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ...dequantizer import quantize_fp8, dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
 
 
@@ -39,5 +39,4 @@ def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:
     return fp8_matmul_tensorwise(input, self.weight, self.bias, self.sdnq_dequantizer.scale)
 
 
-if use_torch_compile:
-    fp8_matmul_tensorwise = torch.compile(fp8_matmul_tensorwise, fullgraph=True, dynamic=False)
+fp8_matmul_tensorwise = compile_func(fp8_matmul_tensorwise)

modules/sdnq/layers/linear/linear_int8.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 import torch
 
-from ...common import use_torch_compile # noqa: TID252
+from ...common import compile_func # noqa: TID252
 from ...packed_int import unpack_int_symetric # noqa: TID252
 from ...dequantizer import quantize_int8, dequantize_symmetric, dequantize_symmetric_with_bias # noqa: TID252
 
@@ -50,5 +50,4 @@ def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
     return int8_matmul(input, weight, self.bias, scale, quantized_weight_shape, self.sdnq_dequantizer.weights_dtype)
 
 
-if use_torch_compile:
-    int8_matmul = torch.compile(int8_matmul, fullgraph=True, dynamic=False)
+int8_matmul = compile_func(int8_matmul)
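
Each layer module follows the same shape: define the matmul helper, then rebind its name through compile_func at import time, so every forward call transparently uses the compiled version. The **kwargs on the fallback also means any extra torch.compile keyword arguments passed by a caller are simply ignored when compilation is off. A hedged sketch with a hypothetical int8 helper (not the actual SDNQ kernel, which takes bias, shape, and dtype arguments as well):

from functools import partial

import torch

compile_func = partial(torch.compile, fullgraph=True, dynamic=False)  # as in common.py when enabled

def int8_matmul_demo(x: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in: dequantize int8 weights on the fly, then matmul
    return torch.mm(x, w_int8.to(x.dtype) * scale)

int8_matmul_demo = compile_func(int8_matmul_demo)  # rebinding at import time

x = torch.randn(2, 8)
w = torch.randint(-128, 127, (8, 4), dtype=torch.int8)
scale = torch.full((4,), 0.05)
print(int8_matmul_demo(x, w, scale).shape)  # torch.Size([2, 4])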
