Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 418ddb8

Browse files
committed
Factor out common checks for set_data in various Image subclasses.
1 parent 5f25d20 commit 418ddb8

File tree

1 file changed

+35
-62
lines changed

1 file changed

+35
-62
lines changed

lib/matplotlib/image.py

Lines changed: 35 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,34 @@ def write_png(self, fname):
688688
bytes=True, norm=True)
689689
PIL.Image.fromarray(im).save(fname, format="png")
690690

691+
@staticmethod
692+
def _normalize_image_array(A):
693+
# Common checks and typecasts on A for various Image subclasses.
694+
A = cbook.safe_masked_invalid(A, copy=True)
695+
if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"):
696+
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
697+
f"converted to float")
698+
if A.ndim == 3 and A.shape[-1] == 1:
699+
A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap.
700+
if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]):
701+
raise TypeError(f"Invalid shape {A.shape} for image data")
702+
if A.ndim == 3:
703+
# If the input data has values outside the valid range (after
704+
# normalisation), we issue a warning and then clip X to the bounds
705+
# - otherwise casting wraps extreme values, hiding outliers and
706+
# making reliable interpretation impossible.
707+
high = 255 if np.issubdtype(A.dtype, np.integer) else 1
708+
if A.min() < 0 or high < A.max():
709+
_log.warning(
710+
'Clipping input data to the valid range for imshow with '
711+
'RGB data ([0..1] for floats or [0..255] for integers).'
712+
)
713+
A = np.clip(A, 0, high)
714+
# Cast unsupported integer types to uint8
715+
if A.dtype != np.uint8 and np.issubdtype(A.dtype, np.integer):
716+
A = A.astype(np.uint8)
717+
return A
718+
691719
def set_data(self, A):
692720
"""
693721
Set the image array.
@@ -700,38 +728,7 @@ def set_data(self, A):
700728
"""
701729
if isinstance(A, PIL.Image.Image):
702730
A = pil_to_array(A) # Needed e.g. to apply png palette.
703-
self._A = cbook.safe_masked_invalid(A, copy=True)
704-
705-
if (self._A.dtype != np.uint8 and
706-
not np.can_cast(self._A.dtype, float, "same_kind")):
707-
raise TypeError(f"Image data of dtype {self._A.dtype} cannot be "
708-
"converted to float")
709-
710-
if self._A.ndim == 3 and self._A.shape[-1] == 1:
711-
# If just one dimension assume scalar and apply colormap
712-
self._A = self._A[:, :, 0]
713-
714-
if not (self._A.ndim == 2
715-
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
716-
raise TypeError(f"Invalid shape {self._A.shape} for image data")
717-
718-
if self._A.ndim == 3:
719-
# If the input data has values outside the valid range (after
720-
# normalisation), we issue a warning and then clip X to the bounds
721-
# - otherwise casting wraps extreme values, hiding outliers and
722-
# making reliable interpretation impossible.
723-
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
724-
if self._A.min() < 0 or high < self._A.max():
725-
_log.warning(
726-
'Clipping input data to the valid range for imshow with '
727-
'RGB data ([0..1] for floats or [0..255] for integers).'
728-
)
729-
self._A = np.clip(self._A, 0, high)
730-
# Cast unsupported integer types to uint8
731-
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
732-
np.integer):
733-
self._A = self._A.astype(np.uint8)
734-
731+
self._A = self._normalize_image_array(A)
735732
self._imcache = None
736733
self.stale = True
737734

@@ -1149,23 +1146,15 @@ def set_data(self, x, y, A):
11491146
(M, N) `~numpy.ndarray` or masked array of values to be
11501147
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11511148
"""
1149+
A = self._normalize_image_array(A)
11521150
x = np.array(x, np.float32)
11531151
y = np.array(y, np.float32)
1154-
A = cbook.safe_masked_invalid(A, copy=True)
1155-
if not (x.ndim == y.ndim == 1 and A.shape[0:2] == y.shape + x.shape):
1152+
if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape):
11561153
raise TypeError("Axes don't match array shape")
1157-
if A.ndim not in [2, 3]:
1158-
raise TypeError("Can only plot 2D or 3D data")
1159-
if A.ndim == 3 and A.shape[2] not in [1, 3, 4]:
1160-
raise TypeError("3D arrays must have three (RGB) "
1161-
"or four (RGBA) color components")
1162-
if A.ndim == 3 and A.shape[2] == 1:
1163-
A = A.squeeze(axis=-1)
11641154
self._A = A
11651155
self._Ax = x
11661156
self._Ay = y
11671157
self._imcache = None
1168-
11691158
self.stale = True
11701159

11711160
def set_array(self, *args):
@@ -1307,36 +1296,20 @@ def set_data(self, x, y, A):
13071296
- (M, N, 3): RGB array
13081297
- (M, N, 4): RGBA array
13091298
"""
1310-
A = cbook.safe_masked_invalid(A, copy=True)
1311-
if x is None:
1312-
x = np.arange(0, A.shape[1]+1, dtype=np.float64)
1313-
else:
1314-
x = np.array(x, np.float64).ravel()
1315-
if y is None:
1316-
y = np.arange(0, A.shape[0]+1, dtype=np.float64)
1317-
else:
1318-
y = np.array(y, np.float64).ravel()
1319-
1320-
if A.shape[:2] != (y.size-1, x.size-1):
1299+
A = self._normalize_image_array(A)
1300+
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
1301+
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()
1302+
if A.shape[:2] != (y.size - 1, x.size - 1):
13211303
raise ValueError(
13221304
"Axes don't match array shape. Got %s, expected %s." %
13231305
(A.shape[:2], (y.size - 1, x.size - 1)))
1324-
if A.ndim not in [2, 3]:
1325-
raise ValueError("A must be 2D or 3D")
1326-
if A.ndim == 3:
1327-
if A.shape[2] == 1:
1328-
A = A.squeeze(axis=-1)
1329-
elif A.shape[2] not in [3, 4]:
1330-
raise ValueError("3D arrays must have RGB or RGBA as last dim")
1331-
13321306
# For efficient cursor readout, ensure x and y are increasing.
13331307
if x[-1] < x[0]:
13341308
x = x[::-1]
13351309
A = A[:, ::-1]
13361310
if y[-1] < y[0]:
13371311
y = y[::-1]
13381312
A = A[::-1]
1339-
13401313
self._A = A
13411314
self._Ax = x
13421315
self._Ay = y

0 commit comments

Comments
 (0)