diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 1da7603..49fad82 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -942,6 +942,15 @@ def __setitem__( self._validate_index(key, op="setitem") # Indexing self._array with array_api_strict arrays can be erroneous np_key = key._array if isinstance(key, Array) else key + + # sanitize the value + other = value + if isinstance(value, (bool, int, float, complex)): + other = self._promote_scalar(value) + dt = _result_type(self.dtype, other.dtype) + if dt != self.dtype: + raise TypeError(f"mismatched dtypes: {self.dtype = } and {other.dtype = }") + self._array.__setitem__(np_key, asarray(value)._array) def __sub__(self, other: Array | complex, /) -> Array: diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index de52f4c..e3c16f4 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -18,12 +18,14 @@ _integer_or_boolean_dtypes, _real_numeric_dtypes, _numeric_dtypes, + uint8, int8, int16, int32, int64, uint64, float64, + complex128, bool as bool_, ) from .._flags import set_array_api_strict_flags @@ -193,6 +195,29 @@ def test_indexing_arrays_different_devices(): a[idx1, idx2] +def test_setitem_invalid_promotions(): + # Check that violating these two raises: + # Setting array values must not affect the data type of self, and + # Behavior must otherwise follow Type Promotion Rules. + a = asarray([1, 2, 3]) + with pytest.raises(TypeError): + a[0] = 3.5 + + with pytest.raises(TypeError): + a[0] = asarray(3.5) + + a = asarray([1, 2, 3], dtype=uint8) + with pytest.raises(TypeError): + a[0] = asarray(42, dtype=uint64) + + a = asarray([1, 2, 3], dtype=float64) + with pytest.raises(TypeError): + a[0] = 3.5j + + with pytest.raises(TypeError): + a[0] = asarray(3.5j, dtype=complex128) + + def test_promoted_scalar_inherits_device(): device1 = Device("device1") x = asarray([1., 2, 3], device=device1)