File tree 1 file changed +27
-12
lines changed
1 file changed +27
-12
lines changed Original file line number Diff line number Diff line change @@ -333,14 +333,27 @@ def argsort_equal_checker(
333
333
rtol = None ,
334
334
exact_dtype = True ,
335
335
):
336
- self .assertEqual (
337
- ref_inputs [0 ][actual ],
338
- ref_inputs [0 ][correct ],
339
- atol = atol ,
340
- rtol = rtol ,
341
- equal_nan = True ,
342
- exact_dtype = exact_dtype ,
343
- )
336
+ isScalar = len (ref_inputs [0 ].shape ) == 0
337
+
338
+ if (not isScalar ):
339
+ self .assertEqual (
340
+ ref_inputs [0 ][actual ],
341
+ ref_inputs [0 ][correct ],
342
+ atol = atol ,
343
+ rtol = rtol ,
344
+ equal_nan = True ,
345
+ exact_dtype = exact_dtype ,
346
+ )
347
+ else :
348
+ # Both actual and correct should be a scalar (0)
349
+ self .assertEqual (
350
+ actual ,
351
+ correct ,
352
+ atol = atol ,
353
+ rtol = rtol ,
354
+ equal_nan = True ,
355
+ exact_dtype = exact_dtype ,
356
+ )
344
357
345
358
346
359
def sort_equal_checker (
@@ -363,12 +376,14 @@ def sort_equal_checker(
363
376
exact_dtype = exact_dtype ,
364
377
)
365
378
366
- self .assertEqual (
367
- ref_inputs [0 ][actual .indices ],
368
- ref_inputs [0 ][correct .indices ],
379
+ argsort_equal_checker (
380
+ self ,
381
+ ref_inputs ,
382
+ example_inputs ,
383
+ correct .indices ,
384
+ actual .indices ,
369
385
atol = atol ,
370
386
rtol = rtol ,
371
- equal_nan = True ,
372
387
exact_dtype = exact_dtype ,
373
388
)
374
389
You can’t perform that action at this time.
0 commit comments