Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Factor out common checks for set_data in various Image subclasses. #26070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 38 additions & 62 deletions lib/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,50 +688,50 @@ def write_png(self, fname):
bytes=True, norm=True)
PIL.Image.fromarray(im).save(fname, format="png")

def set_data(self, A):
@staticmethod
def _normalize_image_array(A):
"""
Set the image array.

Note that this function does *not* update the normalization used.

Parameters
----------
A : array-like or `PIL.Image.Image`
Check validity of image-like input *A* and normalize it to a format suitable for
Image subclasses.
"""
if isinstance(A, PIL.Image.Image):
A = pil_to_array(A) # Needed e.g. to apply png palette.
self._A = cbook.safe_masked_invalid(A, copy=True)

if (self._A.dtype != np.uint8 and
not np.can_cast(self._A.dtype, float, "same_kind")):
raise TypeError(f"Image data of dtype {self._A.dtype} cannot be "
"converted to float")

if self._A.ndim == 3 and self._A.shape[-1] == 1:
# If just one dimension assume scalar and apply colormap
self._A = self._A[:, :, 0]

if not (self._A.ndim == 2
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
raise TypeError(f"Invalid shape {self._A.shape} for image data")

if self._A.ndim == 3:
A = cbook.safe_masked_invalid(A, copy=True)
if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"):
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
f"converted to float")
if A.ndim == 3 and A.shape[-1] == 1:
A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap.
if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]):
raise TypeError(f"Invalid shape {A.shape} for image data")
if A.ndim == 3:
# If the input data has values outside the valid range (after
# normalisation), we issue a warning and then clip X to the bounds
# - otherwise casting wraps extreme values, hiding outliers and
# making reliable interpretation impossible.
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
if self._A.min() < 0 or high < self._A.max():
high = 255 if np.issubdtype(A.dtype, np.integer) else 1
if A.min() < 0 or high < A.max():
_log.warning(
'Clipping input data to the valid range for imshow with '
'RGB data ([0..1] for floats or [0..255] for integers).'
)
self._A = np.clip(self._A, 0, high)
A = np.clip(A, 0, high)
# Cast unsupported integer types to uint8
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
np.integer):
self._A = self._A.astype(np.uint8)
if A.dtype != np.uint8 and np.issubdtype(A.dtype, np.integer):
A = A.astype(np.uint8)
return A

def set_data(self, A):
"""
Set the image array.

Note that this function does *not* update the normalization used.

Parameters
----------
A : array-like or `PIL.Image.Image`
"""
if isinstance(A, PIL.Image.Image):
A = pil_to_array(A) # Needed e.g. to apply png palette.
self._A = self._normalize_image_array(A)
self._imcache = None
self.stale = True

Expand Down Expand Up @@ -1149,23 +1149,15 @@ def set_data(self, x, y, A):
(M, N) `~numpy.ndarray` or masked array of values to be
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
"""
A = self._normalize_image_array(A)
x = np.array(x, np.float32)
y = np.array(y, np.float32)
A = cbook.safe_masked_invalid(A, copy=True)
if not (x.ndim == y.ndim == 1 and A.shape[0:2] == y.shape + x.shape):
if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape):
raise TypeError("Axes don't match array shape")
if A.ndim not in [2, 3]:
raise TypeError("Can only plot 2D or 3D data")
if A.ndim == 3 and A.shape[2] not in [1, 3, 4]:
raise TypeError("3D arrays must have three (RGB) "
"or four (RGBA) color components")
if A.ndim == 3 and A.shape[2] == 1:
A = A.squeeze(axis=-1)
self._A = A
self._Ax = x
self._Ay = y
self._imcache = None

self.stale = True

def set_array(self, *args):
Expand Down Expand Up @@ -1307,36 +1299,20 @@ def set_data(self, x, y, A):
- (M, N, 3): RGB array
- (M, N, 4): RGBA array
"""
A = cbook.safe_masked_invalid(A, copy=True)
if x is None:
x = np.arange(0, A.shape[1]+1, dtype=np.float64)
else:
x = np.array(x, np.float64).ravel()
if y is None:
y = np.arange(0, A.shape[0]+1, dtype=np.float64)
else:
y = np.array(y, np.float64).ravel()

if A.shape[:2] != (y.size-1, x.size-1):
A = self._normalize_image_array(A)
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()
if A.shape[:2] != (y.size - 1, x.size - 1):
raise ValueError(
"Axes don't match array shape. Got %s, expected %s." %
(A.shape[:2], (y.size - 1, x.size - 1)))
if A.ndim not in [2, 3]:
raise ValueError("A must be 2D or 3D")
if A.ndim == 3:
if A.shape[2] == 1:
A = A.squeeze(axis=-1)
elif A.shape[2] not in [3, 4]:
raise ValueError("3D arrays must have RGB or RGBA as last dim")

# For efficient cursor readout, ensure x and y are increasing.
if x[-1] < x[0]:
x = x[::-1]
A = A[:, ::-1]
if y[-1] < y[0]:
y = y[::-1]
A = A[::-1]

self._A = A
self._Ax = x
self._Ay = y
Expand Down