@@ -344,7 +344,7 @@ def test_reduce_flush_to_zero(shape, tile, dtype, reduce_op, tile_op, flush_to_z
344344@pytest .mark .parametrize ("reduce_op, torch_op" , argmaxmin_cases )
345345def test_reduce_argmaxmin (shape , tile , dtype , keepdims , reduce_op , torch_op ):
346346 x = make_tensor (shape , dtype = dtype , device = 'cuda' )
347- y = _squeezed_zeros_like (x , axis = 1 , keepdims = keepdims ).to (torch .int64 )
347+ y = _squeezed_zeros_like (x , axis = 1 , keepdims = keepdims ).to (torch .int32 )
348348 grid = (ceil (shape [0 ] / tile ), 1 , 1 )
349349 if len (shape ) == 2 :
350350 kernel = make_reduce_axis1_2d (reduce_op )
@@ -353,7 +353,7 @@ def test_reduce_argmaxmin(shape, tile, dtype, keepdims, reduce_op, torch_op):
353353 kernel = make_reduce_axis1_3d (reduce_op )
354354 args = (x , y , tile , shape [1 ], shape [2 ], keepdims )
355355 ct .launch (torch .cuda .current_stream (), grid , kernel , args )
356- ref_result = torch_op (x , dim = 1 , keepdim = keepdims ).to (torch .int64 )
356+ ref_result = torch_op (x , dim = 1 , keepdim = keepdims ).to (torch .int32 )
357357 assert_equal (y , ref_result )
358358
359359
@@ -366,11 +366,11 @@ def test_reduce_argmaxmin_all_axes(shape, dtype, reduce_op, torch_op, keepdims):
366366 grid = (1 , 1 , 1 )
367367 kernel = make_reduce_axisNone (reduce_op )
368368 if keepdims :
369- y = _squeezed_zeros_like (x , axis = None , keepdims = keepdims ).to (torch .int64 )
369+ y = _squeezed_zeros_like (x , axis = None , keepdims = keepdims ).to (torch .int32 )
370370 ct .launch (torch .cuda .current_stream (), grid , kernel , (x , y , shape [0 ], shape [1 ], keepdims ))
371371 else :
372- y = torch .zeros ((1 ,) * len (shape ), dtype = dtype , device = "cuda" ).to (torch .int64 )
372+ y = torch .zeros ((1 ,) * len (shape ), dtype = dtype , device = "cuda" ).to (torch .int32 )
373373 ct .launch (torch .cuda .current_stream (), grid , kernel , (x , y , shape [0 ], shape [1 ], keepdims ))
374374 y = y .squeeze ()
375- ref_result = torch_op (x , dim = None , keepdim = keepdims ).to (torch .int64 )
375+ ref_result = torch_op (x , dim = None , keepdim = keepdims ).to (torch .int32 )
376376 assert_equal (y , ref_result )
0 commit comments