Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cd9d95d

Browse files
Fix array type problems (issue 211), except for scalar broadcasting
1 parent b76036e commit cd9d95d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

h5py/_hl/dataset.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,18 +391,20 @@ def __setitem__(self, args, val):
391391

392392
# Generally we try to avoid converting the arrays on the Python
393393
# side. However, for compound literals this is unavoidable.
394-
if self.dtype.kind == "O" or (
395-
self.dtype.kind == 'V' and \
396-
(not isinstance(val, numpy.ndarray) or val.dtype.kind != 'V') ):
394+
if self.dtype.kind == "O" or \
395+
(self.dtype.kind == 'V' and \
396+
(not isinstance(val, numpy.ndarray) or val.dtype.kind != 'V') and \
397+
(self.dtype.subdtype == None)):
397398
val = numpy.asarray(val, dtype=self.dtype, order='C')
398399
else:
399400
val = numpy.asarray(val, order='C')
400401

401402
# Check for array dtype compatibility and convert
402403
if self.dtype.subdtype is not None:
403404
shp = self.dtype.subdtype[1]
404-
if val.shape[-len(shp):] != shp:
405-
raise TypeError("Can't broadcast to array dimension %s" % (shp,))
405+
valshp = val.shape[-len(shp):]
406+
if valshp != shp: # Last dimension has to match
407+
raise TypeError("When writing to array types, last N dimensions have to match (got %s, but should be %s)" % (valshp, shp,))
406408
mtype = h5t.py_create(numpy.dtype((val.dtype, shp)))
407409
mshape = val.shape[0:len(val.shape)-len(shp)]
408410
else:
@@ -418,7 +420,7 @@ def __setitem__(self, args, val):
418420
# Broadcast scalars if necessary.
419421
if (mshape == () and selection.mshape != ()):
420422
if self.dtype.subdtype is not None:
421-
raise NotImplementedError("Scalar broadcasting is not supported for array dtypes")
423+
raise TypeError("Scalar broadcasting is not supported for array dtypes")
422424
val2 = numpy.empty(selection.mshape[-1], dtype=val.dtype)
423425
val2[...] = val
424426
val = val2

0 commit comments

Comments
 (0)