Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 72885cc

Browse files
authored
Merge pull request #26070 from anntzer/isdc
Factor out common checks for set_data in various Image subclasses.
2 parents b49cd20 + b3417a2 commit 72885cc

File tree

1 file changed

+38
-62
lines changed

1 file changed

+38
-62
lines changed

lib/matplotlib/image.py

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -688,50 +688,50 @@ def write_png(self, fname):
688688
bytes=True, norm=True)
689689
PIL.Image.fromarray(im).save(fname, format="png")
690690

691-
def set_data(self, A):
691+
@staticmethod
692+
def _normalize_image_array(A):
692693
"""
693-
Set the image array.
694-
695-
Note that this function does *not* update the normalization used.
696-
697-
Parameters
698-
----------
699-
A : array-like or `PIL.Image.Image`
694+
Check validity of image-like input *A* and normalize it to a format suitable for
695+
Image subclasses.
700696
"""
701-
if isinstance(A, PIL.Image.Image):
702-
A = pil_to_array(A) # Needed e.g. to apply png palette.
703-
self._A = cbook.safe_masked_invalid(A, copy=True)
704-
705-
if (self._A.dtype != np.uint8 and
706-
not np.can_cast(self._A.dtype, float, "same_kind")):
707-
raise TypeError(f"Image data of dtype {self._A.dtype} cannot be "
708-
"converted to float")
709-
710-
if self._A.ndim == 3 and self._A.shape[-1] == 1:
711-
# If just one dimension assume scalar and apply colormap
712-
self._A = self._A[:, :, 0]
713-
714-
if not (self._A.ndim == 2
715-
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
716-
raise TypeError(f"Invalid shape {self._A.shape} for image data")
717-
718-
if self._A.ndim == 3:
697+
A = cbook.safe_masked_invalid(A, copy=True)
698+
if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"):
699+
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
700+
f"converted to float")
701+
if A.ndim == 3 and A.shape[-1] == 1:
702+
A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap.
703+
if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]):
704+
raise TypeError(f"Invalid shape {A.shape} for image data")
705+
if A.ndim == 3:
719706
# If the input data has values outside the valid range (after
720707
# normalisation), we issue a warning and then clip X to the bounds
721708
# - otherwise casting wraps extreme values, hiding outliers and
722709
# making reliable interpretation impossible.
723-
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
724-
if self._A.min() < 0 or high < self._A.max():
710+
high = 255 if np.issubdtype(A.dtype, np.integer) else 1
711+
if A.min() < 0 or high < A.max():
725712
_log.warning(
726713
'Clipping input data to the valid range for imshow with '
727714
'RGB data ([0..1] for floats or [0..255] for integers).'
728715
)
729-
self._A = np.clip(self._A, 0, high)
716+
A = np.clip(A, 0, high)
730717
# Cast unsupported integer types to uint8
731-
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
732-
np.integer):
733-
self._A = self._A.astype(np.uint8)
718+
if A.dtype != np.uint8 and np.issubdtype(A.dtype, np.integer):
719+
A = A.astype(np.uint8)
720+
return A
734721

722+
def set_data(self, A):
723+
"""
724+
Set the image array.
725+
726+
Note that this function does *not* update the normalization used.
727+
728+
Parameters
729+
----------
730+
A : array-like or `PIL.Image.Image`
731+
"""
732+
if isinstance(A, PIL.Image.Image):
733+
A = pil_to_array(A) # Needed e.g. to apply png palette.
734+
self._A = self._normalize_image_array(A)
735735
self._imcache = None
736736
self.stale = True
737737

@@ -1149,23 +1149,15 @@ def set_data(self, x, y, A):
11491149
(M, N) `~numpy.ndarray` or masked array of values to be
11501150
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11511151
"""
1152+
A = self._normalize_image_array(A)
11521153
x = np.array(x, np.float32)
11531154
y = np.array(y, np.float32)
1154-
A = cbook.safe_masked_invalid(A, copy=True)
1155-
if not (x.ndim == y.ndim == 1 and A.shape[0:2] == y.shape + x.shape):
1155+
if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape):
11561156
raise TypeError("Axes don't match array shape")
1157-
if A.ndim not in [2, 3]:
1158-
raise TypeError("Can only plot 2D or 3D data")
1159-
if A.ndim == 3 and A.shape[2] not in [1, 3, 4]:
1160-
raise TypeError("3D arrays must have three (RGB) "
1161-
"or four (RGBA) color components")
1162-
if A.ndim == 3 and A.shape[2] == 1:
1163-
A = A.squeeze(axis=-1)
11641157
self._A = A
11651158
self._Ax = x
11661159
self._Ay = y
11671160
self._imcache = None
1168-
11691161
self.stale = True
11701162

11711163
def set_array(self, *args):
@@ -1307,36 +1299,20 @@ def set_data(self, x, y, A):
13071299
- (M, N, 3): RGB array
13081300
- (M, N, 4): RGBA array
13091301
"""
1310-
A = cbook.safe_masked_invalid(A, copy=True)
1311-
if x is None:
1312-
x = np.arange(0, A.shape[1]+1, dtype=np.float64)
1313-
else:
1314-
x = np.array(x, np.float64).ravel()
1315-
if y is None:
1316-
y = np.arange(0, A.shape[0]+1, dtype=np.float64)
1317-
else:
1318-
y = np.array(y, np.float64).ravel()
1319-
1320-
if A.shape[:2] != (y.size-1, x.size-1):
1302+
A = self._normalize_image_array(A)
1303+
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
1304+
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()
1305+
if A.shape[:2] != (y.size - 1, x.size - 1):
13211306
raise ValueError(
13221307
"Axes don't match array shape. Got %s, expected %s." %
13231308
(A.shape[:2], (y.size - 1, x.size - 1)))
1324-
if A.ndim not in [2, 3]:
1325-
raise ValueError("A must be 2D or 3D")
1326-
if A.ndim == 3:
1327-
if A.shape[2] == 1:
1328-
A = A.squeeze(axis=-1)
1329-
elif A.shape[2] not in [3, 4]:
1330-
raise ValueError("3D arrays must have RGB or RGBA as last dim")
1331-
13321309
# For efficient cursor readout, ensure x and y are increasing.
13331310
if x[-1] < x[0]:
13341311
x = x[::-1]
13351312
A = A[:, ::-1]
13361313
if y[-1] < y[0]:
13371314
y = y[::-1]
13381315
A = A[::-1]
1339-
13401316
self._A = A
13411317
self._Ax = x
13421318
self._Ay = y

0 commit comments

Comments
 (0)