File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -389,6 +389,8 @@ def standardize_dtype(dtype):
389
389
"torch" in str (dtype ) or "jax.numpy" in str (dtype )
390
390
):
391
391
dtype = str (dtype ).split ("." )[- 1 ]
392
+ elif hasattr (dtype , "__name__" ):
393
+ dtype = dtype .__name__
392
394
393
395
if dtype not in ALLOWED_DTYPES :
394
396
raise ValueError (f"Invalid dtype: { dtype } " )
@@ -414,10 +416,10 @@ def standardize_shape(shape):
414
416
if config .backend () == "jax" and str (e ) == "b" :
415
417
# JAX2TF tracing represents `None` dimensions as `b`
416
418
continue
417
- if not isinstance ( e , int ):
419
+ if not is_int_dtype ( type ( e ) ):
418
420
raise ValueError (
419
421
f"Cannot convert '{ shape } ' to a shape. "
420
- f"Found invalid entry '{ e } '. "
422
+ f"Found invalid entry '{ e } ' of type ' { type ( e ) } ' . "
421
423
)
422
424
if e < 0 :
423
425
raise ValueError (
Original file line number Diff line number Diff line change @@ -131,3 +131,12 @@ def test_call_method(self):
131
131
layer = InputLayer (shape = (32 ,))
132
132
output = layer .call ()
133
133
self .assertIsNone (output )
134
+
135
+ def test_numpy_shape (self ):
136
+ # non-python int type shapes should be ok
137
+ InputLayer (shape = (np .int64 (32 ),))
138
+ # float should still raise
139
+ with self .assertRaisesRegex (
140
+ ValueError , "Cannot convert"
141
+ ):
142
+ InputLayer (shape = (np .float64 (32 ),))
You can’t perform that action at this time.
0 commit comments