Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a18b4bf

Browse files
authored
Merge pull request #18870 from aitikgupta/scalarmappable-set_array
Expand ScalarMappable.set_array to accept array-like inputs
2 parents 4c5eff9 + 0ec672f commit a18b4bf

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

lib/matplotlib/cm.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,21 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
361361

362362
def set_array(self, A):
363363
"""
364-
Set the image array from numpy array *A*.
364+
Set the image array from array-like *A*.
365365
366366
Parameters
367367
----------
368-
A : ndarray or None
368+
A : array-like or None
369369
"""
370+
if A is None:
371+
self._A = None
372+
return
373+
374+
A = cbook.safe_masked_invalid(A, copy=True)
375+
if not np.can_cast(A.dtype, float, "same_kind"):
376+
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
377+
"converted to float")
378+
370379
self._A = A
371380

372381
def get_array(self):

lib/matplotlib/image.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,8 +1346,7 @@ def make_image(self, renderer, magnification=1.0, unsampled=False):
13461346

13471347
def set_data(self, A):
13481348
"""Set the image array."""
1349-
cm.ScalarMappable.set_array(self,
1350-
cbook.safe_masked_invalid(A, copy=True))
1349+
cm.ScalarMappable.set_array(self, A)
13511350
self.stale = True
13521351

13531352

lib/matplotlib/tests/test_collections.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,22 @@ def test_collection_set_verts_array():
677677
assert np.array_equal(ap._codes, atp._codes)
678678

679679

680+
def test_collection_set_array():
681+
vals = [*range(10)]
682+
683+
# Test set_array with list
684+
c = Collection()
685+
c.set_array(vals)
686+
687+
# Test set_array with wrong dtype
688+
with pytest.raises(TypeError, match="^Image data of dtype"):
689+
c.set_array("wrong_input")
690+
691+
# Test if array kwarg is copied
692+
vals[5] = 45
693+
assert np.not_equal(vals, c.get_array()).any()
694+
695+
680696
def test_blended_collection_autolim():
681697
a = [1, 2, 4]
682698
height = .2

0 commit comments

Comments
 (0)