Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,40 @@ def test_matmul(self, device, dtype):
# make sure it runs
torch.matmul(x, w.t())

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
@skip_if_rocm("ROCm enablement in progress")
def test_slice_and_copy_int4wo(self, device, dtype):
    """Check that slicing an int4 weight-only quantized weight yields a view.

    Narrowing the quantized parameter along dim 0 must share storage with the
    original ``packed_weight`` and ``scale_and_zero`` tensors. Calling
    ``copy_`` on that narrowed view must then update the original parameter
    in place.

    Args:
        device: device every tensor in this test is created on.
        dtype: floating dtype of the linear layers before quantization.
    """
    linear = torch.nn.Linear(1024, 1024).to(device=device, dtype=dtype)
    # Start from an all-zero weight so any non-zero value observed later can
    # only have come from the copy_ below.
    linear.weight = torch.nn.Parameter(
        torch.zeros(1024, 1024, dtype=dtype, device=device)
    )
    quantize_(linear, Int4WeightOnlyConfig())
    param = linear.weight

    # narrow() on the quantized tensor is expected to be served by the
    # layout's slice handler. The result must alias the parent's buffers.
    param_data = param.data.narrow(0, 0, 512)
    assert (
        param.data.tensor_impl.packed_weight.data_ptr()
        == param_data.tensor_impl.packed_weight.data_ptr()
    )
    assert (
        param.data.tensor_impl.scale_and_zero.data_ptr()
        == param_data.tensor_impl.scale_and_zero.data_ptr()
    )
    assert param.data.dequantize()[0][0] == 0

    # A second, default-initialized (random) layer supplies non-zero
    # quantized data to copy into the zero-initialized slice.
    source_linear = torch.nn.Linear(1024, 1024).to(device=device, dtype=dtype)
    quantize_(source_linear, Int4WeightOnlyConfig())
    quantized = source_linear.weight.narrow(0, 0, 512)

    param_data.copy_(quantized)

    # The copy into the view must be visible through the original parameter.
    # Inspect the whole copied region rather than a single element: one random
    # weight could in principle quantize/dequantize to exactly 0.
    assert param.data.dequantize()[:512].ne(0).any()


# Expand the @common_utils.parametrize decorators used on the methods of these
# test classes into concrete per-parameter test methods. This presumably follows
# the torch.testing common_utils convention; without it the parametrized tests
# would not be collected (verify against the common_utils version in use).
common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
Expand Down
66 changes: 45 additions & 21 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,30 +350,54 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

if func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim in [0, 1]:
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
scale_len = scale.shape[dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
zero_point = aten.slice.Tensor(
zero_point, dim, start_scale, end_scale, step
)
# this is to handle padding
int_data, scale, zero_point = self._layout.post_process(
int_data, scale, zero_point, self.block_size
)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape
sz_dim1, sz_dim0, _ = self.scale_and_zero.shape
data_len = self.shape[dim]
assert dim in [0, 1], (
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

if dim == 0:
pw_len = n_by_8
sz_len = sz_dim0
else:
raise NotImplementedError(
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
pw_len = k_by_inner_tiles
sz_len = sz_dim1

if pw_len == 0 or sz_len == 0:
return return_and_correct_aliasing(
func,
args,
kwargs,
TensorCoreTiledAQTTensorImpl(
self.packed_weight,
self.scale_and_zero,
self.transposed,
self._layout,
),
)

pw_ratio = data_len / pw_len
start_pw = int(start / pw_ratio)
end_pw = int(end / pw_ratio)

sz_ratio = data_len / sz_len
start_sz = int(start / sz_ratio)
end_sz = int(end / sz_ratio)

packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step)
scale_and_zero = aten.slice(
self.scale_and_zero, 1 - dim, start_sz, end_sz, step
)
return return_and_correct_aliasing(
func,
args,
kwargs,
TensorCoreTiledAQTTensorImpl(
packed_weight, scale_and_zero, self.transposed, self._layout
),
)

raise NotImplementedError(
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)
Expand Down
Loading