Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 98fbc70

Browse files
committed
Fixed equality checking for scalars
1 parent b9edb36 commit 98fbc70

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

test/inductor/test_torchinductor_opinfo.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,27 @@ def argsort_equal_checker(
333333
rtol=None,
334334
exact_dtype=True,
335335
):
336-
self.assertEqual(
337-
ref_inputs[0][actual],
338-
ref_inputs[0][correct],
339-
atol=atol,
340-
rtol=rtol,
341-
equal_nan=True,
342-
exact_dtype=exact_dtype,
343-
)
336+
isScalar = len(ref_inputs[0].shape) == 0
337+
338+
if (not isScalar):
339+
self.assertEqual(
340+
ref_inputs[0][actual],
341+
ref_inputs[0][correct],
342+
atol=atol,
343+
rtol=rtol,
344+
equal_nan=True,
345+
exact_dtype=exact_dtype,
346+
)
347+
else:
348+
# Both actual and correct should be a scalar (0)
349+
self.assertEqual(
350+
actual,
351+
correct,
352+
atol=atol,
353+
rtol=rtol,
354+
equal_nan=True,
355+
exact_dtype=exact_dtype,
356+
)
344357

345358

346359
def sort_equal_checker(
@@ -363,12 +376,14 @@ def sort_equal_checker(
363376
exact_dtype=exact_dtype,
364377
)
365378

366-
self.assertEqual(
367-
ref_inputs[0][actual.indices],
368-
ref_inputs[0][correct.indices],
379+
argsort_equal_checker(
380+
self,
381+
ref_inputs,
382+
example_inputs,
383+
correct.indices,
384+
actual.indices,
369385
atol=atol,
370386
rtol=rtol,
371-
equal_nan=True,
372387
exact_dtype=exact_dtype,
373388
)
374389

0 commit comments

Comments
 (0)