Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4223aea

Browse files
authored
support numpy int dtypes in shapes (keras-team#18850)
* support numpy int dtypes in shapes * add test
1 parent 494a99a commit 4223aea

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

keras/backend/common/variables.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ def standardize_dtype(dtype):
389389
"torch" in str(dtype) or "jax.numpy" in str(dtype)
390390
):
391391
dtype = str(dtype).split(".")[-1]
392+
elif hasattr(dtype, "__name__"):
393+
dtype = dtype.__name__
392394

393395
if dtype not in ALLOWED_DTYPES:
394396
raise ValueError(f"Invalid dtype: {dtype}")
@@ -414,10 +416,10 @@ def standardize_shape(shape):
414416
if config.backend() == "jax" and str(e) == "b":
415417
# JAX2TF tracing represents `None` dimensions as `b`
416418
continue
417-
if not isinstance(e, int):
419+
if not is_int_dtype(type(e)):
418420
raise ValueError(
419421
f"Cannot convert '{shape}' to a shape. "
420-
f"Found invalid entry '{e}'. "
422+
f"Found invalid entry '{e}' of type '{type(e)}'. "
421423
)
422424
if e < 0:
423425
raise ValueError(

keras/layers/core/input_layer_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,12 @@ def test_call_method(self):
131131
layer = InputLayer(shape=(32,))
132132
output = layer.call()
133133
self.assertIsNone(output)
134+
135+
def test_numpy_shape(self):
136+
# non-python int type shapes should be ok
137+
InputLayer(shape=(np.int64(32),))
138+
# float should still raise
139+
with self.assertRaisesRegex(
140+
ValueError, "Cannot convert"
141+
):
142+
InputLayer(shape=(np.float64(32),))

0 commit comments

Comments
 (0)