From b3417a2a24d82ce770ad95c5a71f808f072860f2 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 3 Jun 2023 16:48:13 +0200 Subject: [PATCH] Factor out common checks for set_data in various Image subclasses. --- lib/matplotlib/image.py | 100 +++++++++++++++------------------------- 1 file changed, 38 insertions(+), 62 deletions(-) diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 135934a244a8..c8d6f89a0621 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -688,50 +688,50 @@ def write_png(self, fname): bytes=True, norm=True) PIL.Image.fromarray(im).save(fname, format="png") - def set_data(self, A): + @staticmethod + def _normalize_image_array(A): """ - Set the image array. - - Note that this function does *not* update the normalization used. - - Parameters - ---------- - A : array-like or `PIL.Image.Image` + Check validity of image-like input *A* and normalize it to a format suitable for + Image subclasses. """ - if isinstance(A, PIL.Image.Image): - A = pil_to_array(A) # Needed e.g. to apply png palette. - self._A = cbook.safe_masked_invalid(A, copy=True) - - if (self._A.dtype != np.uint8 and - not np.can_cast(self._A.dtype, float, "same_kind")): - raise TypeError(f"Image data of dtype {self._A.dtype} cannot be " - "converted to float") - - if self._A.ndim == 3 and self._A.shape[-1] == 1: - # If just one dimension assume scalar and apply colormap - self._A = self._A[:, :, 0] - - if not (self._A.ndim == 2 - or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]): - raise TypeError(f"Invalid shape {self._A.shape} for image data") - - if self._A.ndim == 3: + A = cbook.safe_masked_invalid(A, copy=True) + if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"): + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to float") + if A.ndim == 3 and A.shape[-1] == 1: + A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap. + if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]): + raise TypeError(f"Invalid shape {A.shape} for image data") + if A.ndim == 3: # If the input data has values outside the valid range (after # normalisation), we issue a warning and then clip X to the bounds # - otherwise casting wraps extreme values, hiding outliers and # making reliable interpretation impossible. - high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1 - if self._A.min() < 0 or high < self._A.max(): + high = 255 if np.issubdtype(A.dtype, np.integer) else 1 + if A.min() < 0 or high < A.max(): _log.warning( 'Clipping input data to the valid range for imshow with ' 'RGB data ([0..1] for floats or [0..255] for integers).' ) - self._A = np.clip(self._A, 0, high) + A = np.clip(A, 0, high) # Cast unsupported integer types to uint8 - if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype, - np.integer): - self._A = self._A.astype(np.uint8) + if A.dtype != np.uint8 and np.issubdtype(A.dtype, np.integer): + A = A.astype(np.uint8) + return A + def set_data(self, A): + """ + Set the image array. + + Note that this function does *not* update the normalization used. + + Parameters + ---------- + A : array-like or `PIL.Image.Image` + """ + if isinstance(A, PIL.Image.Image): + A = pil_to_array(A) # Needed e.g. to apply png palette. + self._A = self._normalize_image_array(A) self._imcache = None self.stale = True @@ -1149,23 +1149,15 @@ def set_data(self, x, y, A): (M, N) `~numpy.ndarray` or masked array of values to be colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array. """ + A = self._normalize_image_array(A) x = np.array(x, np.float32) y = np.array(y, np.float32) - A = cbook.safe_masked_invalid(A, copy=True) - if not (x.ndim == y.ndim == 1 and A.shape[0:2] == y.shape + x.shape): + if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape): raise TypeError("Axes don't match array shape") - if A.ndim not in [2, 3]: - raise TypeError("Can only plot 2D or 3D data") - if A.ndim == 3 and A.shape[2] not in [1, 3, 4]: - raise TypeError("3D arrays must have three (RGB) " - "or four (RGBA) color components") - if A.ndim == 3 and A.shape[2] == 1: - A = A.squeeze(axis=-1) self._A = A self._Ax = x self._Ay = y self._imcache = None - self.stale = True def set_array(self, *args): @@ -1307,28 +1299,13 @@ def set_data(self, x, y, A): - (M, N, 3): RGB array - (M, N, 4): RGBA array """ - A = cbook.safe_masked_invalid(A, copy=True) - if x is None: - x = np.arange(0, A.shape[1]+1, dtype=np.float64) - else: - x = np.array(x, np.float64).ravel() - if y is None: - y = np.arange(0, A.shape[0]+1, dtype=np.float64) - else: - y = np.array(y, np.float64).ravel() - - if A.shape[:2] != (y.size-1, x.size-1): + A = self._normalize_image_array(A) + x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel() + y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel() + if A.shape[:2] != (y.size - 1, x.size - 1): raise ValueError( "Axes don't match array shape. Got %s, expected %s." % (A.shape[:2], (y.size - 1, x.size - 1))) - if A.ndim not in [2, 3]: - raise ValueError("A must be 2D or 3D") - if A.ndim == 3: - if A.shape[2] == 1: - A = A.squeeze(axis=-1) - elif A.shape[2] not in [3, 4]: - raise ValueError("3D arrays must have RGB or RGBA as last dim") - # For efficient cursor readout, ensure x and y are increasing. if x[-1] < x[0]: x = x[::-1] @@ -1336,7 +1313,6 @@ def set_data(self, x, y, A): if y[-1] < y[0]: y = y[::-1] A = A[::-1] - self._A = A self._Ax = x self._Ay = y