

Commit b254132

Merge branch 'qiqix/argmax_to_int32' into 'main'
Change result type for argreduce ops from int64 to int32.

See merge request dl/tileir/cutile-python!45
2 parents 09be170 + 444f67c commit b254132


2 files changed, +6 −6 lines changed


src/cuda/tile/_ir/ops.py

Lines changed: 1 addition & 1 deletion
@@ -2849,7 +2849,7 @@ def argreduce(fn: str, x: Var, axis: Optional[int], keepdims: bool) -> Var:
 
     x_dtype = datatype.default_int_type if datatype.is_boolean(x_type.dtype) else x_type.dtype
     x = _promote_and_broadcast_to(x, TileTy(x_dtype, x_shape))
-    output_dtype = datatype.int64
+    output_dtype = datatype.default_int_type
     output_shape = TupleTy([]) if axis is None else TupleTy(x_shape[:axis] + x_shape[axis + 1:])
     x = add_operation(
         TileArgReduce, TileTy(output_dtype, output_shape),
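
For context, switching the hard-coded datatype.int64 to datatype.default_int_type makes argreduce results 32-bit indices (per the merge request title, the default integer type resolves to int32). A minimal, illustrative PyTorch-only sketch, not taken from this repository, of why int32 is wide enough: argmax/argmin results are positions along the reduced axis, so they are bounded by that axis extent.

```python
import torch

# Illustrative only (not cutile code): argreduce results are indices along the
# reduced axis, so their magnitude is bounded by the axis extent and fits
# comfortably in int32 for any realistic tile or tensor dimension.
x = torch.randn(4, 8)
idx = torch.argmax(x, dim=1)                       # torch returns int64 indices by default
assert int(idx.max()) < torch.iinfo(torch.int32).max
idx32 = idx.to(torch.int32)                        # safe downcast to the new result type
assert torch.equal(idx32.long(), idx)
```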

test/test_reduction.py

Lines changed: 5 additions & 5 deletions
@@ -344,7 +344,7 @@ def test_reduce_flush_to_zero(shape, tile, dtype, reduce_op, tile_op, flush_to_z
 @pytest.mark.parametrize("reduce_op, torch_op", argmaxmin_cases)
 def test_reduce_argmaxmin(shape, tile, dtype, keepdims, reduce_op, torch_op):
     x = make_tensor(shape, dtype=dtype, device='cuda')
-    y = _squeezed_zeros_like(x, axis=1, keepdims=keepdims).to(torch.int64)
+    y = _squeezed_zeros_like(x, axis=1, keepdims=keepdims).to(torch.int32)
     grid = (ceil(shape[0] / tile), 1, 1)
     if len(shape) == 2:
         kernel = make_reduce_axis1_2d(reduce_op)
@@ -353,7 +353,7 @@ def test_reduce_argmaxmin(shape, tile, dtype, keepdims, reduce_op, torch_op):
         kernel = make_reduce_axis1_3d(reduce_op)
         args = (x, y, tile, shape[1], shape[2], keepdims)
     ct.launch(torch.cuda.current_stream(), grid, kernel, args)
-    ref_result = torch_op(x, dim=1, keepdim=keepdims).to(torch.int64)
+    ref_result = torch_op(x, dim=1, keepdim=keepdims).to(torch.int32)
     assert_equal(y, ref_result)
 
 
@@ -366,11 +366,11 @@ def test_reduce_argmaxmin_all_axes(shape, dtype, reduce_op, torch_op, keepdims):
     grid = (1, 1, 1)
     kernel = make_reduce_axisNone(reduce_op)
     if keepdims:
-        y = _squeezed_zeros_like(x, axis=None, keepdims=keepdims).to(torch.int64)
+        y = _squeezed_zeros_like(x, axis=None, keepdims=keepdims).to(torch.int32)
         ct.launch(torch.cuda.current_stream(), grid, kernel, (x, y, shape[0], shape[1], keepdims))
     else:
-        y = torch.zeros((1,) * len(shape), dtype=dtype, device="cuda").to(torch.int64)
+        y = torch.zeros((1,) * len(shape), dtype=dtype, device="cuda").to(torch.int32)
         ct.launch(torch.cuda.current_stream(), grid, kernel, (x, y, shape[0], shape[1], keepdims))
         y = y.squeeze()
-    ref_result = torch_op(x, dim=None, keepdim=keepdims).to(torch.int64)
+    ref_result = torch_op(x, dim=None, keepdim=keepdims).to(torch.int32)
     assert_equal(y, ref_result)
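
A short, hedged sketch of the dtype handshake the updated tests rely on (kernel launch elided; buffer allocation and cast mirror the changes above): torch's argmax/argmin return int64 by default, so the reference is cast to int32 to match the kernel's new output type before the comparison.

```python
import torch

# Hedged sketch mirroring the updated tests (launch details elided): the output
# buffer the kernel fills is allocated as int32, and the torch reference, which
# is int64 by default for argmax/argmin, is cast to int32 before comparison.
x = torch.randn(16, 32, device="cuda")
y = torch.zeros(16, dtype=torch.int32, device="cuda")   # kernel output buffer (now int32)
ref = torch.argmax(x, dim=1).to(torch.int32)            # cast reference to match
# ct.launch(...) would populate y here; afterwards: assert_equal(y, ref)
```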

0 commit comments
