From 83ca5c5851dd91b484cad7b1e6c10a8478be907b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 15:31:02 -0700 Subject: [PATCH 1/6] Fixing aliasing behavior for slice in AQT TensorCoreTiledLayout Summary: The slice op is supposed to preserve aliasing (the output of slice should alias the input), but this is not true for TensorCoreTiledLayout (used by int4wo) and some others like gemlite. The reason is that we currently do unpacking, padding and prepacking, which creates new tensors. We fix it in this PR by slicing the packed inner tensors directly, specifically packed_weight and scale_and_zero in TensorCoreTiledLayout. Test Plan: python test/dtypes/test_affine_quantized.py -k test_slice_and_copy_int4wo Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_affine_quantized.py | 25 +++++++++ .../dtypes/uintx/tensor_core_tiled_layout.py | 52 ++++++++++++------- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 9dc195a6da..35147af907 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -387,6 +387,31 @@ def test_matmul(self, device, dtype): # make sure it runs torch.matmul(x, w.t()) + @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16]) + @skip_if_no_cuda() + @skip_if_rocm("ROCm enablement in progress") + def test_slice_and_copy_int4wo(self, device, dtype): + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter(torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")) + quantize_(l, Int4WeightOnlyConfig()) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + assert param.data.dequantize()[0][0] == 0 + + # dummy_l has random input (shouldn't be 0) + dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + quantize_(dummy_l, Int4WeightOnlyConfig()) + quantized = dummy_l.weight 
+ quantized = quantized.narrow(0, 0, 512) + + param_data.copy_(quantized) + + # making sure param.data is updated + assert param.data.dequantize()[0][0] != 0 + + common_utils.instantiate_parametrized_tests(TestAffineQuantized) common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 240561b741..a3402faff7 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -350,25 +350,39 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim in [0, 1]: - int_data, scale, zero_point = self.get_plain() - data_len = int_data.shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor( - zero_point, dim, start_scale, end_scale, step - ) - # this is to handle padding - int_data, scale, zero_point = self._layout.post_process( - int_data, scale, zero_point, self.block_size - ) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) + n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape + sz_dim1, sz_dim0, _ = self.scale_and_zero.shape + data_len = self.shape[dim] + if dim == 0: + pw_len = n_by_8 + sz_len = sz_dim0 + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) + scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, 
start_sz, end_sz, step) + return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) + elif dim == 1: + pw_len = k_by_inner_tiles + sz_len = sz_dim1 + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) + scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) + return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) else: raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" From 801d76f54cf9f5e4edacbd17ff694dd84a0727ad Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 15:54:14 -0700 Subject: [PATCH 2/6] simplify code --- .../dtypes/uintx/tensor_core_tiled_layout.py | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index a3402faff7..6997fb038e 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -353,40 +353,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs): n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape sz_dim1, sz_dim0, _ = self.scale_and_zero.shape data_len = self.shape[dim] + assert dim in [0, 1], f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + if dim == 0: pw_len = n_by_8 sz_len = sz_dim0 - - pw_ratio = data_len / pw_len - start_pw = int(start / pw_ratio) - end_pw = int(end / pw_ratio) - - sz_ratio = data_len / sz_len - start_sz = int(start / 
sz_ratio) - end_sz = int(end / sz_ratio) - - packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) - scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) - return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) - elif dim == 1: + else: pw_len = k_by_inner_tiles sz_len = sz_dim1 - pw_ratio = data_len / pw_len - start_pw = int(start / pw_ratio) - end_pw = int(end / pw_ratio) + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) - sz_ratio = data_len / sz_len - start_sz = int(start / sz_ratio) - end_sz = int(end / sz_ratio) + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) - packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) - scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) - return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) - else: - raise NotImplementedError( - f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) + packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) + scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) + return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" From d06f1a15597c484e0f0a446bd6845a69d6d740a0 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 16:02:33 -0700 Subject: [PATCH 3/6] add check for data_ptr --- test/dtypes/test_affine_quantized.py | 2 ++ torchao/dtypes/uintx/tensor_core_tiled_layout.py | 2 +- 2 
files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 35147af907..374d936d26 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -398,6 +398,8 @@ def test_slice_and_copy_int4wo(self, device, dtype): param = l.weight param_data = param.data param_data = param_data.narrow(0, 0, 512) + assert param.data.tensor_impl.packed_weight.data_ptr() == param_data.tensor_impl.packed_weight.data_ptr() + assert param.data.tensor_impl.scale_and_zero.data_ptr() == param_data.tensor_impl.scale_and_zero.data_ptr() assert param.data.dequantize()[0][0] == 0 # dummy_l has random input (shouldn't be 0) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 6997fb038e..82ee529601 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -372,7 +372,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) - return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(self.packed_weight, self.scale_and_zero, self.transposed, self._layout)) + return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(packed_weight, scale_and_zero, self.transposed, self._layout)) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" From df064d00970e24782fb0fcda500ac4724e3388e9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 16:05:15 -0700 Subject: [PATCH 4/6] format --- test/dtypes/test_affine_quantized.py | 15 +++++++++++---- .../dtypes/uintx/tensor_core_tiled_layout.py | 17 ++++++++++++++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git 
a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 374d936d26..4ed39a0eff 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -393,13 +393,21 @@ def test_matmul(self, device, dtype): @skip_if_rocm("ROCm enablement in progress") def test_slice_and_copy_int4wo(self, device, dtype): l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - l.weight = torch.nn.Parameter(torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) quantize_(l, Int4WeightOnlyConfig()) param = l.weight param_data = param.data param_data = param_data.narrow(0, 0, 512) - assert param.data.tensor_impl.packed_weight.data_ptr() == param_data.tensor_impl.packed_weight.data_ptr() - assert param.data.tensor_impl.scale_and_zero.data_ptr() == param_data.tensor_impl.scale_and_zero.data_ptr() + assert ( + param.data.tensor_impl.packed_weight.data_ptr() + == param_data.tensor_impl.packed_weight.data_ptr() + ) + assert ( + param.data.tensor_impl.scale_and_zero.data_ptr() + == param_data.tensor_impl.scale_and_zero.data_ptr() + ) assert param.data.dequantize()[0][0] == 0 # dummy_l has random input (shouldn't be 0) @@ -414,7 +422,6 @@ def test_slice_and_copy_int4wo(self, device, dtype): assert param.data.dequantize()[0][0] != 0 - common_utils.instantiate_parametrized_tests(TestAffineQuantized) common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 82ee529601..f38fb0647e 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -353,7 +353,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape sz_dim1, sz_dim0, _ = self.scale_and_zero.shape data_len = self.shape[dim] - 
assert dim in [0, 1], f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + assert dim in [0, 1], ( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) if dim == 0: pw_len = n_by_8 @@ -371,8 +373,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): end_sz = int(end / sz_ratio) packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step) - scale_and_zero = aten.slice(self.scale_and_zero, 1-dim, start_sz, end_sz, step) - return return_and_correct_aliasing(func, args, kwargs, TensorCoreTiledAQTTensorImpl(packed_weight, scale_and_zero, self.transposed, self._layout)) + scale_and_zero = aten.slice( + self.scale_and_zero, 1 - dim, start_sz, end_sz, step + ) + return return_and_correct_aliasing( + func, + args, + kwargs, + TensorCoreTiledAQTTensorImpl( + packed_weight, scale_and_zero, self.transposed, self._layout + ), + ) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" From 9d70683fee3e4b411a9d1037d72b341888f30b9a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 16:32:50 -0700 Subject: [PATCH 5/6] avoid div by zero --- torchao/dtypes/uintx/tensor_core_tiled_layout.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index f38fb0647e..9b66688478 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -364,6 +364,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): pw_len = k_by_inner_tiles sz_len = sz_dim1 + if pw_len == 0 or sz_len == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + TensorCoreTiledAQTTensorImpl( + self.packed_weight, self.scale_and_zero, self.transposed, self._layout + ), + ) + pw_ratio = data_len / pw_len start_pw = int(start / 
pw_ratio) end_pw = int(end / pw_ratio) From 2fb26debde6bb218b62fde0b076235a62cd91aec Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 5 May 2025 16:37:07 -0700 Subject: [PATCH 6/6] format --- torchao/dtypes/uintx/tensor_core_tiled_layout.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 9b66688478..0ba2720ec1 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -370,7 +370,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): args, kwargs, TensorCoreTiledAQTTensorImpl( - self.packed_weight, self.scale_and_zero, self.transposed, self._layout + self.packed_weight, + self.scale_and_zero, + self.transposed, + self._layout, ), )