diff --git a/doc/api/colors_api.rst b/doc/api/colors_api.rst index 6b02f723d74d..3022512dd371 100644 --- a/doc/api/colors_api.rst +++ b/doc/api/colors_api.rst @@ -53,6 +53,7 @@ Multivariate Colormaps BivarColormap SegmentedBivarColormap BivarColormapFromImage + MultivarColormap Other classes ------------- diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index b46cbce39c58..a072867ff050 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -5685,7 +5685,7 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None, """ Display data as an image, i.e., on a 2D regular raster. - The input may either be actual RGB(A) data, or 2D scalar data, which + The input may either be actual RGB(A) data, or 2D data, which will be rendered as a pseudocolor image. For displaying a grayscale image, set up the colormapping using the parameters ``cmap='gray', vmin=0, vmax=255``. @@ -5706,24 +5706,25 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None, - (M, N): an image with scalar data. The values are mapped to colors using normalization and a colormap. See parameters *norm*, *cmap*, *vmin*, *vmax*. + - (K, M, N): multiple images with scalar data. Must be used + with a multivariate or bivariate colormap that supports K channels. - (M, N, 3): an image with RGB values (0-1 float or 0-255 int). - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), i.e. including transparency. - The first two dimensions (M, N) define the rows and columns of - the image. + The dimensions M and N are the number of rows and columns of the image. Out-of-range RGB(A) values are clipped. - %(cmap_doc)s + %(multi_cmap_doc)s This parameter is ignored if *X* is RGB(A). - %(norm_doc)s + %(multi_norm_doc)s This parameter is ignored if *X* is RGB(A). - %(vmin_vmax_doc)s + %(multi_vmin_vmax_doc)s This parameter is ignored if *X* is RGB(A). @@ -6062,9 +6063,10 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, Parameters ---------- - C : 2D array-like + C : 2D array-like or list of 2D array-like objects The color-mapped values. Color-mapping is controlled by *cmap*, - *norm*, *vmin*, and *vmax*. + *norm*, *vmin*, and *vmax*. Multiple arrays are only supported + when a bivariate or mulivariate colormap is used, see *cmap*. X, Y : array-like, optional The coordinates of the corners of quadrilaterals of a pcolormesh:: @@ -6111,11 +6113,11 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, See :doc:`/gallery/images_contours_and_fields/pcolormesh_grids` for more description. - %(cmap_doc)s + %(multi_cmap_doc)s - %(norm_doc)s + %(multi_norm_doc)s - %(vmin_vmax_doc)s + %(multi_vmin_vmax_doc)s %(colorizer_doc)s @@ -6190,6 +6192,26 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, if shading is None: shading = mpl.rcParams['pcolor.shading'] shading = shading.lower() + + mcolorizer.ColorizingArtist._check_exclusionary_keywords(colorizer, + vmin=vmin, + vmax=vmax) + + # we need to get the colorizer object to know the number of + # n_variates that should exist in the array, we therefore get the + # colorizer here. + colorizer_obj = mcolorizer.ColorizingArtist._get_colorizer(norm=norm, + cmap=cmap, + colorizer=colorizer) + if colorizer_obj.norm.n_input > 1: + # Valid call signatures for pcolor are with args = (C) or args = (X, Y, C) + # If provided, _pcolorargs will check that X, Y and C have the same shape. + # Before this check, we need to convert C from shape (K, N, M), where K is + # the number of variates, to (N, M) with a data type with K fields. + data = mcolorizer._ensure_multivariate_data(args[-1], + colorizer_obj.norm.n_input) + args = (*args[:-1], data) + X, Y, C, shading = self._pcolorargs('pcolor', *args, shading=shading, kwargs=kwargs) linewidths = (0.25,) @@ -6226,9 +6248,8 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None, coords = stack([X, Y], axis=-1) collection = mcoll.PolyQuadMesh( - coords, array=C, cmap=cmap, norm=norm, colorizer=colorizer, + coords, array=C, colorizer=colorizer_obj, alpha=alpha, **kwargs) - collection._check_exclusionary_keywords(colorizer, vmin=vmin, vmax=vmax) collection._scale_norm(norm, vmin, vmax) coords = coords.reshape(-1, 2) # flatten the grid structure; keep x, y @@ -6266,12 +6287,13 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, - (M, N) or M*N: a mesh with scalar data. The values are mapped to colors using normalization and a colormap. See parameters *norm*, *cmap*, *vmin*, *vmax*. + - (K, M, N): multiple layers with scalar data. Must be used with + a multivariate or bivariate colormap. See *cmap*. - (M, N, 3): an image with RGB values (0-1 float or 0-255 int). - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), i.e. including transparency. - The first two dimensions (M, N) define the rows and columns of - the mesh data. + Here M and N define the rows and columns of the mesh. X, Y : array-like, optional The coordinates of the corners of quadrilaterals of a pcolormesh:: @@ -6302,11 +6324,11 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, expanded as needed into the appropriate 2D arrays, making a rectangular grid. - %(cmap_doc)s + %(multi_cmap_doc)s - %(norm_doc)s + %(multi_norm_doc)s - %(vmin_vmax_doc)s + %(multi_vmin_vmax_doc)s %(colorizer_doc)s @@ -6429,6 +6451,24 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, shading = mpl._val_or_rc(shading, 'pcolor.shading').lower() kwargs.setdefault('edgecolors', 'none') + mcolorizer.ColorizingArtist._check_exclusionary_keywords(colorizer, + vmin=vmin, + vmax=vmax) + # we need to get the colorizer object to know the number of + # n_variates that should exist in the array, we therefore get the + # colorizer here. + colorizer_obj = mcolorizer.ColorizingArtist._get_colorizer(norm=norm, + cmap=cmap, + colorizer=colorizer) + if colorizer_obj.norm.n_input > 1: + # Valid call signatures for pcolor are with args = (C) or args = (X, Y, C) + # If provided, _pcolorargs will check that X, Y and C have the same shape. + # Before this check, we need to convert C from shape (K, N, M), where K is + # the number of variates, to (N, M) with a data type with K fields. + data = mcolorizer._ensure_multivariate_data(args[-1], + colorizer_obj.norm.n_input) + args = (*args[:-1], data) + X, Y, C, shading = self._pcolorargs('pcolormesh', *args, shading=shading, kwargs=kwargs) coords = np.stack([X, Y], axis=-1) @@ -6437,8 +6477,8 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, collection = mcoll.QuadMesh( coords, antialiased=antialiased, shading=shading, - array=C, cmap=cmap, norm=norm, colorizer=colorizer, alpha=alpha, **kwargs) - collection._check_exclusionary_keywords(colorizer, vmin=vmin, vmax=vmax) + array=C, colorizer=colorizer_obj, + alpha=alpha, **kwargs) collection._scale_norm(norm, vmin, vmax) coords = coords.reshape(-1, 2) # flatten the grid structure; keep x, y diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index d90921158ee5..683f62763cb8 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -690,7 +690,17 @@ def safe_masked_invalid(x, copy=False): try: xm = np.ma.masked_where(~(np.isfinite(x)), x, copy=False) except TypeError: - return x + if len(x.dtype.descr) == 1: + return x + else: + # in case of a dtype with multiple fields: + try: + mask = np.empty(x.shape, dtype=np.dtype('bool, '*len(x.dtype.descr))) + for dd, dm in zip(x.dtype.descr, mask.dtype.descr): + mask[dm[0]] = ~(np.isfinite(x[dd[0]])) + xm = np.ma.array(x, mask=mask, copy=False) + except TypeError: + return x return xm diff --git a/lib/matplotlib/cm.py b/lib/matplotlib/cm.py index 2697666b9573..d8c3dafcfe46 100644 --- a/lib/matplotlib/cm.py +++ b/lib/matplotlib/cm.py @@ -278,32 +278,3 @@ def get_cmap(name=None, lut=None): return _colormaps[name] else: return _colormaps[name].resampled(lut) - - -def _ensure_cmap(cmap): - """ - Ensure that we have a `.Colormap` object. - - For internal use to preserve type stability of errors. - - Parameters - ---------- - cmap : None, str, Colormap - - - if a `Colormap`, return it - - if a string, look it up in mpl.colormaps - - if None, look up the default color map in mpl.colormaps - - Returns - ------- - Colormap - - """ - if isinstance(cmap, colors.Colormap): - return cmap - cmap_name = mpl._val_or_rc(cmap, "image.cmap") - # use check_in_list to ensure type stability of the exception raised by - # the internal usage of this (ValueError vs KeyError) - if cmap_name not in _colormaps: - _api.check_in_list(sorted(_colormaps), cmap=cmap_name) - return mpl.colormaps[cmap_name] diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index ab1bd337d805..619e99ad7e2f 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -2611,10 +2611,16 @@ def _get_unmasked_polys(self): arr = self.get_array() if arr is not None: arr = np.ma.getmaskarray(arr) - if arr.ndim == 3: + if self.norm.n_input > 1: + # multivar case + for a in mcolors.MultiNorm._iterable_variates_in_data( + arr, self.norm.n_input): + mask |= np.any(a, axis=0) + elif arr.ndim == 3: # RGB(A) case mask |= np.any(arr, axis=-1) elif arr.ndim == 2: + # scalar case mask |= arr else: mask |= arr.reshape(self._coordinates[:-1, :-1, :].shape[:2]) diff --git a/lib/matplotlib/colorizer.py b/lib/matplotlib/colorizer.py index b4223f389804..e7adaf055944 100644 --- a/lib/matplotlib/colorizer.py +++ b/lib/matplotlib/colorizer.py @@ -24,7 +24,7 @@ import numpy as np from numpy import ma -from matplotlib import _api, colors, cbook, scale, artist +from matplotlib import _api, colors, cbook, scale, artist, cm import matplotlib as mpl mpl._docstring.interpd.register( @@ -90,19 +90,7 @@ def norm(self): @norm.setter def norm(self, norm): - _api.check_isinstance((colors.Normalize, str, None), norm=norm) - if norm is None: - norm = colors.Normalize() - elif isinstance(norm, str): - try: - scale_cls = scale._scale_mapping[norm] - except KeyError: - raise ValueError( - "Invalid norm str name; the following values are " - f"supported: {', '.join(scale._scale_mapping)}" - ) from None - norm = _auto_norm_from_scale(scale_cls)() - + norm = _ensure_norm(norm, n_variates=self.cmap.n_variates) if norm is self.norm: # We aren't updating anything return @@ -232,9 +220,14 @@ def _set_cmap(self, cmap): cmap : `.Colormap` or str or None """ # bury import to avoid circular imports - from matplotlib import cm in_init = self._cmap is None - self._cmap = cm._ensure_cmap(cmap) + cmap_obj = _ensure_cmap(cmap, accept_multivariate=True) + if not in_init: + if self.norm.n_output != cmap_obj.n_variates: + raise ValueError(f"The colormap {cmap} does not support " + f"{self.norm.n_output} variates as required by " + f"the {type(self.norm)} on this Colorizer.") + self._cmap = cmap_obj if not in_init: self.changed() # Things are not set up properly yet. @@ -255,31 +248,35 @@ def set_clim(self, vmin=None, vmax=None): vmin, vmax : float The limits. - The limits may also be passed as a tuple (*vmin*, *vmax*) as a - single positional argument. + For scalar data, the limits may also be passed as a + tuple (*vmin*, *vmax*) single positional argument. .. ACCEPTS: (vmin: float, vmax: float) """ + if self.norm.n_input == 1: + if vmax is None: + try: + vmin, vmax = vmin + except (TypeError, ValueError): + pass + # If the norm's limits are updated self.changed() will be called # through the callbacks attached to the norm, this causes an inconsistent # state, to prevent this blocked context manager is used - if vmax is None: - try: - vmin, vmax = vmin - except (TypeError, ValueError): - pass - orig_vmin_vmax = self.norm.vmin, self.norm.vmax # Blocked context manager prevents callbacks from being triggered # until both vmin and vmax are updated with self.norm.callbacks.blocked(signal='changed'): + # Tote that the @vmin/vmax.setter invokes + # colors._sanitize_extrema() to sanitize the input + # The input is not sanitized here because + # colors._sanitize_extrema() does not handle multivariate input. if vmin is not None: - self.norm.vmin = colors._sanitize_extrema(vmin) + self.norm.vmin = vmin if vmax is not None: - self.norm.vmax = colors._sanitize_extrema(vmax) + self.norm.vmax = vmax - # emit a update signal if the limits are changed if orig_vmin_vmax != (self.norm.vmin, self.norm.vmax): self.norm.callbacks.process('changed') @@ -467,39 +464,61 @@ def colorbar(self): def colorbar(self, colorbar): self._colorizer.colorbar = colorbar - def _format_cursor_data_override(self, data): - # This function overwrites Artist.format_cursor_data(). We cannot - # implement cm.ScalarMappable.format_cursor_data() directly, because - # most cm.ScalarMappable subclasses inherit from Artist first and from - # cm.ScalarMappable second, so Artist.format_cursor_data would always - # have precedence over cm.ScalarMappable.format_cursor_data. - - # Note if cm.ScalarMappable is depreciated, this functionality should be - # implemented as format_cursor_data() on ColorizingArtist. - n = self.cmap.N - if np.ma.getmask(data): - return "[]" - normed = self.norm(data) + @staticmethod + def _sig_digits_from_norm(norm, data, n): + # Determines the number of significant digits + # to use for a number given a norm, and n, where n is the + # number of colors in the colormap. + normed = norm(data) if np.isfinite(normed): - if isinstance(self.norm, colors.BoundaryNorm): + if isinstance(norm, colors.BoundaryNorm): # not an invertible normalization mapping - cur_idx = np.argmin(np.abs(self.norm.boundaries - data)) + cur_idx = np.argmin(np.abs(norm.boundaries - data)) neigh_idx = max(0, cur_idx - 1) # use max diff to prevent delta == 0 delta = np.diff( - self.norm.boundaries[neigh_idx:cur_idx + 2] + norm.boundaries[neigh_idx:cur_idx + 2] ).max() - elif self.norm.vmin == self.norm.vmax: + elif norm.vmin == norm.vmax: # singular norms, use delta of 10% of only value - delta = np.abs(self.norm.vmin * .1) + delta = np.abs(norm.vmin * .1) else: # Midpoints of neighboring color intervals. - neighbors = self.norm.inverse( + neighbors = norm.inverse( (int(normed * n) + np.array([0, 1])) / n) delta = abs(neighbors - data).max() g_sig_digits = cbook._g_sig_digits(data, delta) else: g_sig_digits = 3 # Consistent with default below. + return g_sig_digits + + def _format_cursor_data_override(self, data): + # This function overwrites Artist.format_cursor_data(). We cannot + # implement cm.ScalarMappable.format_cursor_data() directly, because + # most cm.ScalarMappable subclasses inherit from Artist first and from + # cm.ScalarMappable second, so Artist.format_cursor_data would always + # have precedence over cm.ScalarMappable.format_cursor_data. + + # Note if cm.ScalarMappable is depreciated, this functionality should be + # implemented as format_cursor_data() on ColorizingArtist. + if np.ma.getmask(data) or data is None: + return "[]" + if len(data.dtype.descr) > 1: + # We have multivariate data encoded as a data type with multiple fields + # NOTE: If any of the fields are masked, "[]" would be returned via + # the if statement above. + s_sig_digits_list = [] + if isinstance(self.cmap, colors.BivarColormap): + n_s = (self.cmap.N, self.cmap.M) + else: + n_s = [part.N for part in self.cmap] + os = [f"{d:-#.{self._sig_digits_from_norm(no, d, n)}g}" + for no, d, n in zip(self.norm.norms, data, n_s)] + return f"[{', '.join(os)}]" + + # scalar data + n = self.cmap.N + g_sig_digits = self._sig_digits_from_norm(self.norm, data, n) return f"[{data:-#.{g_sig_digits}g}]" @@ -563,10 +582,18 @@ def set_array(self, A): self._A = None return + A = _ensure_multivariate_data(A, self.norm.n_input) + A = cbook.safe_masked_invalid(A, copy=True) if not np.can_cast(A.dtype, float, "same_kind"): - raise TypeError(f"Image data of dtype {A.dtype} cannot be " - "converted to float") + if A.dtype.fields is None: + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to float") + else: + for key in A.dtype.fields: + if not np.can_cast(A[key].dtype, float, "same_kind"): + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to a sequence of floats") self._A = A if not self.norm.scaled(): @@ -615,6 +642,15 @@ def _get_colorizer(cmap, norm, colorizer): cmap : str or `~matplotlib.colors.Colormap`, default: :rc:`image.cmap` The Colormap instance or registered colormap name used to map scalar data to colors.""", + multi_cmap_doc="""\ +cmap : str, `~matplotlib.colors.Colormap`, `~matplotlib.colors.BivarColormap`\ + or `~matplotlib.colors.MultivarColormap`, default: :rc:`image.cmap` + The Colormap instance or registered colormap name used to map + data values to colors. + + Multivariate data is only accepted if a multivariate colormap + (`~matplotlib.colors.BivarColormap` or `~matplotlib.colors.MultivarColormap`) + is used.""", norm_doc="""\ norm : str or `~matplotlib.colors.Normalize`, optional The normalization method used to scale scalar data to the [0, 1] range @@ -629,6 +665,23 @@ def _get_colorizer(cmap, norm, colorizer): list of available scales, call `matplotlib.scale.get_scale_names()`. In that case, a suitable `.Normalize` subclass is dynamically generated and instantiated.""", + multi_norm_doc="""\ +norm : str, `~matplotlib.colors.Normalize` or list, optional + The normalization method used to scale data to the [0, 1] range + before mapping to colors using *cmap*. By default, a linear scaling is + used, mapping the lowest value to 0 and the highest to 1. + + This can be one of the following: + + - An instance of `.Normalize` or one of its subclasses + (see :ref:`colormapnorms`). + - A scale name, i.e. one of "linear", "log", "symlog", "logit", etc. For a + list of available scales, call `matplotlib.scale.get_scale_names()`. + In this case, a suitable `.Normalize` subclass is dynamically generated + and instantiated. + - A list of scale names or `.Normalize` objects matching the number of + variates in the colormap, for use with `~matplotlib.colors.BivarColormap` + or `~matplotlib.colors.MultivarColormap`, i.e. ``["linear", "log"]``.""", vmin_vmax_doc="""\ vmin, vmax : float, optional When using scalar data and no explicit *norm*, *vmin* and *vmax* define @@ -636,6 +689,17 @@ def _get_colorizer(cmap, norm, colorizer): the complete value range of the supplied data. It is an error to use *vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm* name together with *vmin*/*vmax* is acceptable).""", + multi_vmin_vmax_doc="""\ +vmin, vmax : float or list, optional + When using scalar data and no explicit *norm*, *vmin* and *vmax* define + the data range that the colormap covers. By default, the colormap covers + the complete value range of the supplied data. It is an error to use + *vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm* + name together with *vmin*/*vmax* is acceptable). + + A list of values (vmin or vmax) can be used to define independent limits + for each variate when using a `~matplotlib.colors.BivarColormap` or + `~matplotlib.colors.MultivarColormap`.""", ) @@ -701,3 +765,168 @@ def _auto_norm_from_scale(scale_cls): norm = colors.make_norm_from_scale(scale_cls)( colors.Normalize)() return type(norm) + + +def _ensure_norm(norm, n_variates=1): + if n_variates == 1: + _api.check_isinstance((colors.Normalize, str, None), norm=norm) + if norm is None: + norm = colors.Normalize() + elif isinstance(norm, str): + scale_cls = scale._get_scale_cls_from_str(norm) + norm = _auto_norm_from_scale(scale_cls)() + return norm + else: # n_variates > 1 + if not np.iterable(norm): + # include tuple in the list to improve error message + _api.check_isinstance((colors.Normalize, str, None, tuple), norm=norm) + if norm is None: + norm = colors.MultiNorm([None]*n_variates) + elif isinstance(norm, str): # single string + norm = colors.MultiNorm([norm]*n_variates) + else: # multiple string or objects + norm = colors.MultiNorm(norm) + if isinstance(norm, colors.Normalize) and norm.n_output == n_variates: + return norm + raise ValueError( + "Invalid norm for multivariate colormap with " + f"{n_variates} inputs." + ) + + +def _ensure_cmap(cmap, accept_multivariate=False): + """ + Ensure that we have a `.Colormap` object. + + For internal use to preserve type stability of errors. + + Parameters + ---------- + cmap : None, str, Colormap + + - if a `~matplotlib.colors.Colormap`, + `~matplotlib.colors.MultivarColormap` or + `~matplotlib.colors.BivarColormap`, + return it + - if a string, look it up in three corresponding databases + when not found: raise an error based on the expected shape + - if None, look up the default color map in mpl.colormaps + accept_multivariate : bool, default True + - if False, accept only Colormap, string in mpl.colormaps or None + + Returns + ------- + Colormap + + """ + if not accept_multivariate: + if isinstance(cmap, colors.Colormap): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + # use check_in_list to ensure type stability of the exception raised by + # the internal usage of this (ValueError vs KeyError) + if cmap_name not in mpl.colormaps: + _api.check_in_list(sorted(mpl.colormaps), cmap=cmap_name) + + if isinstance(cmap, (colors.Colormap, + colors.BivarColormap, + colors.MultivarColormap)): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + if cmap_name in mpl.colormaps: + return mpl.colormaps[cmap_name] + if cmap_name in mpl.multivar_colormaps: + return mpl.multivar_colormaps[cmap_name] + if cmap_name in mpl.bivar_colormaps: + return mpl.bivar_colormaps[cmap_name] + + # this error message is a variant of _api.check_in_list but gives + # additional hints as to how to access multivariate colormaps + + msg = f"{cmap!r} is not a valid value for cmap" + msg += "; supported values for scalar colormaps are " + msg += f"{', '.join(map(repr, sorted(mpl.colormaps)))}\n" + msg += "See matplotlib.bivar_colormaps() and" + msg += " matplotlib.multivar_colormaps() for" + msg += " bivariate and multivariate colormaps." + + raise ValueError(msg) + + if isinstance(cmap, colors.Colormap): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + # use check_in_list to ensure type stability of the exception raised by + # the internal usage of this (ValueError vs KeyError) + if cmap_name not in cm.colormaps: + _api.check_in_list(sorted(cm.colormaps), cmap=cmap_name) + return cm.colormaps[cmap_name] + + +def _ensure_multivariate_data(data, n_input): + """ + Ensure that the data has dtype with n_input. + Input data of shape (n_input, n, m) is converted to an array of shape + (n, m) with data type np.dtype(f'{data.dtype}, ' * n_input) + Complex data is returned as a view with dtype np.dtype('float64, float64') + or np.dtype('float32, float32') + If n_input is 1 and data is not of type np.ndarray (i.e. PIL.Image), + the data is returned unchanged. + If data is None, the function returns None + Parameters + ---------- + n_input : int + - number of variates in the data + data : np.ndarray, PIL.Image or None + Returns + ------- + np.ndarray, PIL.Image or None + """ + + if isinstance(data, np.ndarray): + if len(data.dtype.descr) == n_input: + # pass scalar data + # and already formatted data + return data + elif data.dtype in [np.complex64, np.complex128]: + # pass complex data + if data.dtype == np.complex128: + dt = np.dtype('float64, float64') + else: + dt = np.dtype('float32, float32') + reconstructed = np.ma.frombuffer(data.data, dtype=dt).reshape(data.shape) + if np.ma.is_masked(data): + for descriptor in dt.descr: + reconstructed[descriptor[0]][data.mask] = np.ma.masked + return reconstructed + + if n_input > 1 and len(data) == n_input: + # convert data from shape (n_input, n, m) + # to (n,m) with a new dtype + data = [np.ma.array(part, copy=False) for part in data] + dt = np.dtype(', '.join([f'{part.dtype}' for part in data])) + fields = [descriptor[0] for descriptor in dt.descr] + reconstructed = np.ma.empty(data[0].shape, dtype=dt) + for i, f in enumerate(fields): + if data[i].shape != reconstructed.shape: + raise ValueError("For multivariate data all variates must have same " + f"shape, not {data[0].shape} and {data[i].shape}") + reconstructed[f] = data[i] + if np.ma.is_masked(data[i]): + reconstructed[f][data[i].mask] = np.ma.masked + return reconstructed + + if data is None: + return data + + if n_input == 1: + # PIL.Image also gets passed here + return data + + elif n_input == 2: + raise ValueError("Invalid data entry for mutlivariate data. The data" + " must contain complex numbers, or have a first dimension 2," + " or be of a dtype with 2 fields") + else: + raise ValueError("Invalid data entry for mutlivariate data. The shape" + f" of the data must have a first dimension {n_input}" + f" or be of a dtype with {n_input} fields") diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index e3c3b39e8bb2..039534bedf25 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -1420,10 +1420,10 @@ def __init__(self, colormaps, combination_mode, name='multivariate colormap'): combination_mode: str, 'sRGB_add' or 'sRGB_sub' Describe how colormaps are combined in sRGB space - - If 'sRGB_add' -> Mixing produces brighter colors - `sRGB = sum(colors)` - - If 'sRGB_sub' -> Mixing produces darker colors - `sRGB = 1 - sum(1 - colors)` + - If 'sRGB_add': Mixing produces brighter colors + ``sRGB = sum(colors)`` + - If 'sRGB_sub': Mixing produces darker colors + ``sRGB = 1 - sum(1 - colors)`` name : str, optional The name of the colormap family. """ @@ -1598,12 +1598,12 @@ def with_extremes(self, *, bad=None, under=None, over=None): bad: :mpltype:`color`, default: None If Matplotlib color, the bad value is set accordingly in the copy - under tuple of :mpltype:`color`, default: None - If tuple, the `under` value of each component is set with the values + under: tuple of :mpltype:`color`, default: None + If tuple, the ``under`` value of each component is set with the values from the tuple. - over tuple of :mpltype:`color`, default: None - If tuple, the `over` value of each component is set with the values + over: tuple of :mpltype:`color`, default: None + If tuple, the ``over`` value of each component is set with the values from the tuple. Returns @@ -2320,6 +2320,16 @@ def __init__(self, vmin=None, vmax=None, clip=False): self._scale = None self.callbacks = cbook.CallbackRegistry(signals=["changed"]) + @property + def n_input(self): + # To be overridden by subclasses with multiple inputs + return 1 + + @property + def n_output(self): + # To be overridden by subclasses with multiple outputs + return 1 + @property def vmin(self): return self._vmin @@ -3219,6 +3229,232 @@ def inverse(self, value): return value +class MultiNorm(Normalize): + """ + A mixin class which contains multiple scalar norms + """ + + def __init__(self, norms, vmin=None, vmax=None, clip=False): + """ + Parameters + ---------- + norms : List of strings or `Normalize` objects + The constituent norms. The list must have a minimum length of 2. + vmin, vmax : float, None, or list of float or None + Limits of the constituent norms. + If a list, each each value is assigned to one of the constituent + norms. Single values are repeated to form a list of appropriate size. + + clip : bool or list of bools, default: False + Determines the behavior for mapping values outside the range + ``[vmin, vmax]`` for the constituent norms. + If a list, each each value is assigned to one of the constituent + norms. Single values are repeated to form a list of appropriate size. + + """ + + if isinstance(norms, str) or not np.iterable(norms): + raise ValueError("A MultiNorm must be assigned multiple norms") + norms = [n for n in norms] + for i, n in enumerate(norms): + if n is None: + norms[i] = Normalize() + elif isinstance(n, str): + scale_cls = scale._get_scale_cls_from_str(n) + norms[i] = mpl.colorizer._auto_norm_from_scale(scale_cls)() + + # Convert the list of norms to a tuple to make it immutable. + # If there is a use case for swapping a single norm, we can add support for + # that later + self._norms = tuple(norms) + + self.callbacks = cbook.CallbackRegistry(signals=["changed"]) + + self.vmin = vmin + self.vmax = vmax + self.clip = clip + + self._id_norms = [n.callbacks.connect('changed', + self._changed) for n in self._norms] + + @property + def n_input(self): + return len(self._norms) + + @property + def n_output(self): + return len(self._norms) + + @property + def norms(self): + return self._norms + + @property + def vmin(self): + return tuple(n.vmin for n in self._norms) + + @vmin.setter + def vmin(self, value): + if not np.iterable(value): + value = [value]*self.n_input + if len(value) != self.n_input: + raise ValueError(f"Invalid vmin for `MultiNorm` with {self.n_input}" + " inputs.") + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].vmin = v + self._changed() + + @property + def vmax(self): + return tuple(n.vmax for n in self._norms) + + @vmax.setter + def vmax(self, value): + if not np.iterable(value): + value = [value]*self.n_input + if len(value) != self.n_input: + raise ValueError(f"Invalid vmax for `MultiNorm` with {self.n_input}" + " inputs.") + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].vmax = v + self._changed() + + @property + def clip(self): + return tuple(n.clip for n in self._norms) + + @clip.setter + def clip(self, value): + if not np.iterable(value): + value = [value]*self.n_input + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].clip = v + self._changed() + + def _changed(self): + """ + Call this whenever the norm is changed to notify all the + callback listeners to the 'changed' signal. + """ + self.callbacks.process('changed') + + def __call__(self, value, clip=None): + """ + Normalize the data and return the normalized data. + Each variate in the input is assigned to the a constituent norm. + + Parameters + ---------- + value + Data to normalize. Must be of length `n_input` or have a data type with + `n_input` fields. + clip : List of bools or bool, optional + See the description of the parameter *clip* in Normalize. + If ``None``, defaults to ``self.clip`` (which defaults to + ``False``). + + Returns + ------- + Data + Normalized input values as a list of length `n_input` + + Notes + ----- + If not already initialized, ``self.vmin`` and ``self.vmax`` are + initialized using ``self.autoscale_None(value)``. + """ + if clip is None: + clip = self.clip + elif not np.iterable(clip): + clip = [clip]*self.n_input + + value = self._iterable_variates_in_data(value, self.n_input) + result = [n(v, clip=c) for n, v, c in zip(self.norms, value, clip)] + return result + + def inverse(self, value): + """ + Maps the normalized value (i.e., index in the colormap) back to image + data value. + + Parameters + ---------- + value + Normalized value. Must be of length `n_input` or have a data type with + `n_input` fields. + """ + value = self._iterable_variates_in_data(value, self.n_input) + result = [n.inverse(v) for n, v in zip(self.norms, value)] + return result + + def autoscale(self, A): + """ + For each constituent norm, Set *vmin*, *vmax* to min, max of the corresponding + variate in *A*. + """ + with self.callbacks.blocked(): + # Pause callbacks while we are updating so we only get + # a single update signal at the end + A = self._iterable_variates_in_data(A, self.n_input) + for n, a in zip(self.norms, A): + n.autoscale(a) + self._changed() + + def autoscale_None(self, A): + """ + If *vmin* or *vmax* are not set on any constituent norm, + use the min/max of the corresponding variate in *A* to set them. + + Parameters + ---------- + A + Data, must be of length `n_input` or be an np.ndarray type with + `n_input` fields. + """ + with self.callbacks.blocked(): + A = self._iterable_variates_in_data(A, self.n_input) + for n, a in zip(self.norms, A): + n.autoscale_None(a) + self._changed() + + def scaled(self): + """Return whether both *vmin* and *vmax* are set on all constitient norms""" + return all([(n.vmin is not None and n.vmax is not None) for n in self.norms]) + + @staticmethod + def _iterable_variates_in_data(data, n_input): + """ + Provides an iterable over the variates contained in the data. + + An input array with n_input fields is returned as a list of length n referencing + slices of the original array. + + Parameters + ---------- + data : np.ndarray, tuple or list + The input array. It must either be an array with n_input fields or have + a length (n_input) + + Returns + ------- + list of np.ndarray + + """ + if isinstance(data, np.ndarray) and data.dtype.fields is not None: + data = [data[descriptor[0]] for descriptor in data.dtype.descr] + if not len(data) == n_input: + raise ValueError("The input to this `MultiNorm` must be of shape " + f"({n_input}, ...), or have a data type with {n_input} " + "fields.") + return data + + def rgb_to_hsv(arr): """ Convert an array of float RGB values (in the range [0, 1]) to HSV values. diff --git a/lib/matplotlib/colors.pyi b/lib/matplotlib/colors.pyi index 3e761c949068..3f9e0c9d93e8 100644 --- a/lib/matplotlib/colors.pyi +++ b/lib/matplotlib/colors.pyi @@ -263,6 +263,10 @@ class Normalize: @vmax.setter def vmax(self, value: float | None) -> None: ... @property + def n_input(self) -> int: ... + @property + def n_output(self) -> int: ... + @property def clip(self) -> bool: ... @clip.setter def clip(self, value: bool) -> None: ... @@ -387,6 +391,34 @@ class BoundaryNorm(Normalize): class NoNorm(Normalize): ... +class MultiNorm(Normalize): + # Here "type: ignore[override]" is used for functions with a return type + # that differs from the function in the base class. + # i.e. where `MultiNorm` returns a tuple and Normalize returns a `float` etc. + def __init__( + self, + norms: ArrayLike, + vmin: ArrayLike | float | None = ..., + vmax: ArrayLike | float | None = ..., + clip: ArrayLike | bool = ... + ) -> None: ... + @property + def norms(self) -> tuple: ... + @property # type: ignore[override] + def vmin(self) -> tuple[float | None]: ... + @vmin.setter + def vmin(self, value: ArrayLike | float | None) -> None: ... + @property # type: ignore[override] + def vmax(self) -> tuple[float | None]: ... + @vmax.setter + def vmax(self, value: ArrayLike | float | None) -> None: ... + @property # type: ignore[override] + def clip(self) -> tuple[bool]: ... + @clip.setter + def clip(self, value: ArrayLike | bool) -> None: ... + def __call__(self, value: ArrayLike, clip: ArrayLike | bool | None = ...) -> list: ... # type: ignore[override] + def inverse(self, value: ArrayLike) -> list: ... # type: ignore[override] + def rgb_to_hsv(arr: ArrayLike) -> np.ndarray: ... def hsv_to_rgb(hsv: ArrayLike) -> np.ndarray: ... diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index bd1254c27fe1..904a19db44d3 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -641,20 +641,37 @@ def write_png(self, fname): PIL.Image.fromarray(im).save(fname, format="png") @staticmethod - def _normalize_image_array(A): + def _normalize_image_array(A, n_input=1): """ Check validity of image-like input *A* and normalize it to a format suitable for Image subclasses. """ + A = mcolorizer._ensure_multivariate_data(A, n_input) A = cbook.safe_masked_invalid(A, copy=True) - if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"): - raise TypeError(f"Image data of dtype {A.dtype} cannot be " - f"converted to float") + if n_input == 1: + if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"): + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to float") + else: + for key in A.dtype.fields: + if not np.can_cast(A[key].dtype, float, "same_kind"): + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to a sequence of floats") if A.ndim == 3 and A.shape[-1] == 1: A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap. if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]): + if A.ndim == 3 and A.shape[0] == 2: + raise TypeError(f"Invalid shape {A.shape} for image data." + " For multivariate data a valid colormap must be" + " explicitly declared, for example" + f" cmap='BiOrangeBlue' or cmap='2VarAddA'") + if A.ndim == 3 and A.shape[0] > 2 and A.shape[0] <= 8: + raise TypeError(f"Invalid shape {A.shape} for image data." + " For multivariate data a multivariate colormap" + " must be explicitly declared, for example" + f" cmap='{A.shape[0]}VarAddA'") raise TypeError(f"Invalid shape {A.shape} for image data") - if A.ndim == 3: + if A.ndim == 3 and n_input == 1: # If the input data has values outside the valid range (after # normalisation), we issue a warning and then clip X to the bounds # - otherwise casting wraps extreme values, hiding outliers and @@ -685,7 +702,7 @@ def set_data(self, A): """ if isinstance(A, PIL.Image.Image): A = pil_to_array(A) # Needed e.g. to apply png palette. - self._A = self._normalize_image_array(A) + self._A = self._normalize_image_array(A, self.norm.n_input) self._imcache = None self.stale = True @@ -753,6 +770,12 @@ def set_interpolation_stage(self, s): """ s = mpl._val_or_rc(s, 'image.interpolation_stage') _api.check_in_list(['data', 'rgba', 'auto'], s=s) + if self.norm.n_input > 1: + if s == 'data': + raise ValueError("'rgba' is the only interpolation stage" + " available for multivariate data.") + else: + s = 'rgba' self._interpolation_stage = s self.stale = True diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 44fbe5209c4d..e1e5884a0617 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -715,6 +715,35 @@ def get_scale_names(): return sorted(_scale_mapping) +def _get_scale_cls_from_str(scale_as_str): + """ + Returns the scale class from a string. + + Used in the creation of norms from a string to ensure a reasonable error + in the case where an invalid string is used. This cannot use + `_api.check_getitem()`, because the norm keyword accepts arguments + other than strings. + + Parameters + ---------- + scale_as_str : string + A string corresponding to a scale + + Returns + ------- + A subclass of ScaleBase. + + """ + try: + scale_cls = _scale_mapping[scale_as_str] + except KeyError: + raise ValueError( + "Invalid norm str name; the following values are " + f"supported: {', '.join(_scale_mapping)}" + ) from None + return scale_cls + + def scale_factory(scale, axis, **kwargs): """ Return a scale class by name. diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py index ece8bebf8192..725fff7b23fd 100644 --- a/lib/matplotlib/streamplot.py +++ b/lib/matplotlib/streamplot.py @@ -6,7 +6,7 @@ import numpy as np import matplotlib as mpl -from matplotlib import _api, cm, patches +from matplotlib import _api, colorizer, patches import matplotlib.colors as mcolors import matplotlib.collections as mcollections import matplotlib.lines as mlines @@ -228,7 +228,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if use_multicolor_lines: if norm is None: norm = mcolors.Normalize(color.min(), color.max()) - cmap = cm._ensure_cmap(cmap) + cmap = colorizer._ensure_cmap(cmap) streamlines = [] arrows = [] diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivar_cmap_from_image.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivar_cmap_from_image.png new file mode 100644 index 000000000000..a69968132949 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivar_cmap_from_image.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_call.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_call.png new file mode 100644 index 000000000000..e350c6c5c180 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_call.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_shapes.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_shapes.png new file mode 100644 index 000000000000..0bb37e451772 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_cmap_shapes.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_masked_data.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_masked_data.png new file mode 100644 index 000000000000..982348635143 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_masked_data.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_visualizations.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_visualizations.png new file mode 100644 index 000000000000..949688a678ea Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/bivariate_visualizations.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivar_cmap_call.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivar_cmap_call.png new file mode 100644 index 000000000000..3941300a3a29 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivar_cmap_call.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_figimage.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_figimage.png new file mode 100644 index 000000000000..934414cde061 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_figimage.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_alpha.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_alpha.png new file mode 100644 index 000000000000..81697387500b Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_alpha.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_complex_data.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_complex_data.png new file mode 100644 index 000000000000..f47f17653c18 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_complex_data.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_norm.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_norm.png new file mode 100644 index 000000000000..f4b02be22436 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_imshow_norm.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_alpha.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_alpha.png new file mode 100644 index 000000000000..98d18d4a9de8 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_alpha.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_norm.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_norm.png new file mode 100644 index 000000000000..d9e6fb277e0d Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_pcolormesh_norm.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_visualizations.png b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_visualizations.png new file mode 100644 index 000000000000..8f35a5590ce5 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_multivariate_axes/multivariate_visualizations.png differ diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 8d0f3467f045..4c295a557ccc 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -1828,3 +1828,44 @@ def test_LinearSegmentedColormap_from_list_value_color_tuple(): cmap([value for value, _ in value_color_tuples]), to_rgba_array([color for _, color in value_color_tuples]), ) + + +def test_multi_norm(): + # tests for mcolors.MultiNorm + + # test wrong input + with pytest.raises(ValueError, + match="A MultiNorm must be assigned multiple norms"): + mcolors.MultiNorm("bad_norm_name") + with pytest.raises(ValueError, + match="Invalid norm str name"): + mcolors.MultiNorm(["bad_norm_name"]) + + # test get vmin, vmax + norm = mpl.colors.MultiNorm(['linear', 'log']) + norm.vmin = 1 + norm.vmax = 2 + assert norm.vmin[0] == 1 + assert norm.vmin[1] == 1 + assert norm.vmax[0] == 2 + assert norm.vmax[1] == 2 + + # test call with clip + assert_array_equal(norm([3, 3], clip=False), [2.0, 1.584962500721156]) + assert_array_equal(norm([3, 3], clip=True), [1.0, 1.0]) + assert_array_equal(norm([3, 3], clip=[True, False]), [1.0, 1.584962500721156]) + norm.clip = False + assert_array_equal(norm([3, 3]), [2.0, 1.584962500721156]) + norm.clip = True + assert_array_equal(norm([3, 3]), [1.0, 1.0]) + norm.clip = [True, False] + assert_array_equal(norm([3, 3]), [1.0, 1.584962500721156]) + norm.clip = True + + # test inverse + assert_array_almost_equal(norm.inverse([0.5, 0.5849625007211562]), [1.5, 1.5]) + + # test autoscale + norm.autoscale([[0, 1, 2, 3], [0.1, 1, 2, 3]]) + assert_array_equal(norm.vmin, [0, 0.1]) + assert_array_equal(norm.vmax, [3, 3]) diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index 0e9f3fb37fbd..1e1b9ed244c8 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -14,7 +14,7 @@ import matplotlib as mpl from matplotlib import ( - colors, image as mimage, patches, pyplot as plt, style, rcParams) + cbook, colors, image as mimage, patches, pyplot as plt, style, rcParams) from matplotlib.image import (AxesImage, BboxImage, FigureImage, NonUniformImage, PcolorImage) from matplotlib.testing.decorators import check_figures_equal, image_comparison @@ -1130,8 +1130,14 @@ def test_image_cursor_formatting(): data = np.ma.masked_array([0], mask=[False]) assert im.format_cursor_data(data) == '[0]' - data = np.nan - assert im.format_cursor_data(data) == '[nan]' + # This used to test + # > data = np.nan + # > assert im.format_cursor_data(data) == '[nan]' + # However, a value of nan will be masked by `cbook.safe_masked_invalid(data)` + # called by `image._ImageBase._normalize_image_array(data)` + # The test is therefore changed to: + data = cbook.safe_masked_invalid(np.array(np.nan)) + assert im.format_cursor_data(data) == '[]' @check_figures_equal(extensions=['png', 'pdf', 'svg']) diff --git a/lib/matplotlib/tests/test_multivariate_axes.py b/lib/matplotlib/tests/test_multivariate_axes.py new file mode 100644 index 000000000000..4fd0f0de6295 --- /dev/null +++ b/lib/matplotlib/tests/test_multivariate_axes.py @@ -0,0 +1,688 @@ +import numpy as np +from numpy.testing import assert_array_equal, assert_allclose +import matplotlib.pyplot as plt +from matplotlib.testing.decorators import (image_comparison, + remove_ticks_and_titles) +import matplotlib.colors as mcolors +from matplotlib import cbook +import matplotlib as mpl +import pytest +import re + + +@image_comparison(["bivariate_visualizations.png"]) +def test_bivariate_visualizations(): + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + + fig, axes = plt.subplots(1, 7, figsize=(12, 2)) + + axes[0].imshow((x_0, x_1), cmap='BiPeak', interpolation='nearest') + axes[1].matshow((x_0, x_1), cmap='BiPeak') + axes[2].pcolor((x_0, x_1), cmap='BiPeak') + axes[3].pcolormesh((x_0, x_1), cmap='BiPeak') + + x = np.arange(5) + y = np.arange(5) + X, Y = np.meshgrid(x, y) + axes[4].pcolor(X, Y, (x_0, x_1), cmap='BiPeak') + axes[5].pcolormesh(X, Y, (x_0, x_1), cmap='BiPeak') + + patches = [ + mpl.patches.Wedge((.3, .7), .1, 0, 360), # Full circle + mpl.patches.Wedge((.7, .8), .2, 0, 360, width=0.05), # Full ring + mpl.patches.Wedge((.8, .3), .2, 0, 45), # Full sector + mpl.patches.Wedge((.8, .3), .2, 22.5, 90, width=0.10), # Ring sector + ] + colors_0 = np.arange(len(patches)) // 2 + colors_1 = np.arange(len(patches)) % 2 + p = mpl.collections.PatchCollection(patches, cmap='BiPeak', alpha=0.5) + p.set_array((colors_0, colors_1)) + axes[6].add_collection(p) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_visualizations.png"]) +def test_multivariate_visualizations(): + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x_2 = np.arange(25, dtype='float32').reshape(5, 5) % 6 + + fig, axes = plt.subplots(1, 7, figsize=(12, 2)) + + axes[0].imshow((x_0, x_1, x_2), cmap='3VarAddA', interpolation='nearest') + axes[1].matshow((x_0, x_1, x_2), cmap='3VarAddA') + axes[2].pcolor((x_0, x_1, x_2), cmap='3VarAddA') + axes[3].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA') + + x = np.arange(5) + y = np.arange(5) + X, Y = np.meshgrid(x, y) + axes[4].pcolor(X, Y, (x_0, x_1, x_2), cmap='3VarAddA') + axes[5].pcolormesh(X, Y, (x_0, x_1, x_2), cmap='3VarAddA') + + patches = [ + mpl.patches.Wedge((.3, .7), .1, 0, 360), # Full circle + mpl.patches.Wedge((.7, .8), .2, 0, 360, width=0.05), # Full ring + mpl.patches.Wedge((.8, .3), .2, 0, 45), # Full sector + mpl.patches.Wedge((.8, .3), .2, 22.5, 90, width=0.10), # Ring sector + ] + colors_0 = np.arange(len(patches)) // 2 + colors_1 = np.arange(len(patches)) % 2 + colors_2 = np.arange(len(patches)) % 3 + p = mpl.collections.PatchCollection(patches, cmap='3VarAddA', alpha=0.5) + p.set_array((colors_0, colors_1, colors_2)) + axes[6].add_collection(p) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_pcolormesh_alpha.png"]) +def test_multivariate_pcolormesh_alpha(): + """ + Check that the the alpha keyword works for pcolormesh + This test covers all plotting modes that use the same pipeline + (inherit from Collection). + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x_2 = np.arange(25, dtype='float32').reshape(5, 5) % 6 + + fig, axes = plt.subplots(2, 3) + + axes[0, 0].pcolormesh(x_1, alpha=0.5) + axes[0, 1].pcolormesh((x_0, x_1), cmap='BiPeak', alpha=0.5) + axes[0, 2].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', alpha=0.5) + + al = np.arange(25, dtype='float32').reshape(5, 5)[::-1].T % 6 / 5 + + axes[1, 0].pcolormesh(x_1, alpha=al) + axes[1, 1].pcolormesh((x_0, x_1), cmap='BiPeak', alpha=al) + axes[1, 2].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', alpha=al) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_pcolormesh_norm.png"]) +def test_multivariate_pcolormesh_norm(): + """ + Test vmin, vmax and norm + Norm is checked via a LogNorm, as this converts + A LogNorm converts the input to a masked array, masking for X <= 0 + By using a LogNorm, this functionality is also tested. + This test covers all plotting modes that use the same pipeline + (inherit from Collection). + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x_2 = np.arange(25, dtype='float32').reshape(5, 5) % 6 + + fig, axes = plt.subplots(3, 5) + + axes[0, 0].pcolormesh(x_1) + axes[0, 1].pcolormesh((x_0, x_1), cmap='BiPeak') + axes[0, 2].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA') + axes[0, 3].pcolormesh((x_0, x_1), cmap='BiPeak') + axes[0, 4].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA') + + vmin = 1 + vmax = 3 + axes[1, 0].pcolormesh(x_1, vmin=vmin, vmax=vmax) + axes[1, 1].pcolormesh((x_0, x_1), cmap='BiPeak', vmin=vmin, vmax=vmax) + axes[1, 2].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', + vmin=vmin, vmax=vmax) + axes[1, 3].pcolormesh((x_0, x_1), cmap='BiPeak', + vmin=(None, vmin), vmax=(None, vmax)) + axes[1, 4].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', + vmin=(None, vmin, None), vmax=(None, vmax, None)) + + n = mcolors.LogNorm(vmin=1, vmax=5) + axes[2, 0].pcolormesh(x_1, norm=n) + axes[2, 1].pcolormesh((x_0, x_1), cmap='BiPeak', norm=(n, n)) + axes[2, 2].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', norm=(n, n, n)) + axes[2, 3].pcolormesh((x_0, x_1), cmap='BiPeak', norm=(None, n)) + axes[2, 4].pcolormesh((x_0, x_1, x_2), cmap='3VarAddA', + norm=(None, n, None)) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_imshow_alpha.png"]) +def test_multivariate_imshow_alpha(): + """ + Check that the the alpha keyword works for pcolormesh + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x_2 = np.arange(25, dtype='float32').reshape(5, 5) % 6 + + fig, axes = plt.subplots(2, 3) + + # interpolation='nearest' to reduce size of baseline image + axes[0, 0].imshow(x_1, interpolation='nearest', alpha=0.5) + axes[0, 1].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', alpha=0.5) + axes[0, 2].imshow((x_0, x_1, x_2), interpolation='nearest', + cmap='3VarAddA', alpha=0.5) + + al = np.arange(25, dtype='float32').reshape(5, 5)[::-1].T % 6 / 5 + + axes[1, 0].imshow(x_1, interpolation='nearest', alpha=al) + axes[1, 1].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', alpha=al) + axes[1, 2].imshow((x_0, x_1, x_2), interpolation='nearest', + cmap='3VarAddA', alpha=al) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_imshow_norm.png"]) +def test_multivariate_imshow_norm(): + """ + Test vmin, vmax and norm + Norm is checked via a LogNorm. + A LogNorm converts the input to a masked array, masking for X <= 0 + By using a LogNorm, this functionality is also tested. + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x_2 = np.arange(25, dtype='float32').reshape(5, 5) % 6 + + fig, axes = plt.subplots(3, 5, dpi=10) + + # interpolation='nearest' to reduce size of baseline image and + # removes ambiguity when using masked array (from LogNorm) + axes[0, 0].imshow(x_1, interpolation='nearest') + axes[0, 1].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak') + axes[0, 2].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA') + axes[0, 3].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak') + axes[0, 4].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA') + + vmin = 1 + vmax = 3 + axes[1, 0].imshow(x_1, interpolation='nearest', vmin=vmin, vmax=vmax) + axes[1, 1].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', + vmin=vmin, vmax=vmax) + axes[1, 2].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA', + vmin=vmin, vmax=vmax) + axes[1, 3].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', + vmin=(None, vmin), vmax=(None, vmax)) + axes[1, 4].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA', + vmin=(None, vmin, None), vmax=(None, vmax, None)) + + n = mcolors.LogNorm(vmin=1, vmax=5) + axes[2, 0].imshow(x_1, interpolation='nearest', norm=n) + axes[2, 1].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', norm=(n, n)) + axes[2, 2].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA', + norm=(n, n, n)) + axes[2, 3].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak', + norm=(None, n)) + axes[2, 4].imshow((x_0, x_1, x_2), interpolation='nearest', cmap='3VarAddA', + norm=(None, n, None)) + + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_imshow_complex_data.png"]) +def test_multivariate_imshow_complex_data(): + """ + Tests plotting using complex numbers + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + x = np.zeros((5, 5), dtype='complex128') + x.real[:, :] = x_0 + x.imag[:, :] = x_1 + + fig, axes = plt.subplots(1, 3, figsize=(6, 2)) + axes[0].imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak') + axes[1].imshow(x, interpolation='nearest', cmap='BiPeak') + axes[2].imshow(x.astype('complex64'), interpolation='nearest', cmap='BiPeak') + + remove_ticks_and_titles(fig) + + +@image_comparison(["bivariate_masked_data.png"]) +def test_bivariate_masked_data(): + """ + Tests the masked arrays with multivariate data + + Note that this uses interpolation_stage='rgba' + because interpolation_stage='data' is not yet (dec. 2024) + implemented for multivariate data. + """ + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(25, dtype='float32').reshape(5, 5).T % 5 + + x_0 = np.ma.masked_where(x_0 > 3, x_0) + + fig, axes = plt.subplots(1, 4, figsize=(8, 2)) + axes[0].imshow(x_0, interpolation_stage='rgba') + # one of two arrays masked + axes[1].imshow((x_0, x_1), cmap='BiPeak', interpolation_stage='rgba') + # a single masked array + x = np.ma.array((x_0, x_1)) + axes[2].imshow(x, cmap='BiPeak', interpolation_stage='rgba') + # a single masked array with a complex dtype + m = np.ma.zeros((5, 5), dtype='complex128') + m.real[:, :] = x_0 + m.imag[:, :] = x_1 + m.mask = x_0.mask + axes[3].imshow(m, cmap='BiPeak', interpolation='nearest') + + remove_ticks_and_titles(fig) + + +def test_bivariate_inconsistent_shape(): + x_0 = np.arange(25, dtype='float32').reshape(5, 5) % 5 + x_1 = np.arange(30, dtype='float32').reshape(6, 5).T % 5 + + fig, ax = plt.subplots(1) + with pytest.raises(ValueError, match="For multivariate data all"): + ax.imshow((x_0, x_1), interpolation='nearest', cmap='BiPeak') + + +@image_comparison(["bivariate_cmap_shapes.png"]) +def test_bivariate_cmap_shapes(): + x_0 = np.arange(100, dtype='float32').reshape(10, 10) % 10 + x_1 = np.arange(100, dtype='float32').reshape(10, 10).T % 10 + + fig, axes = plt.subplots(1, 4, figsize=(10, 2)) + + # shape = square + axes[0].imshow((x_0, x_1), cmap='BiPeak', vmin=1, vmax=8, interpolation='nearest') + # shape = cone + axes[1].imshow((x_0, x_1), cmap='BiCone', vmin=0.5, vmax=8.5, + interpolation='nearest') + + # shape = ignore + cmap = mpl.bivar_colormaps['BiCone'] + cmap = cmap.with_extremes(shape='ignore') + axes[2].imshow((x_0, x_1), cmap=cmap, vmin=1, vmax=8, interpolation='nearest') + + # shape = circleignore + cmap = mpl.bivar_colormaps['BiCone'] + cmap = cmap.with_extremes(shape='circleignore') + axes[3].imshow((x_0, x_1), cmap=cmap, vmin=0.5, vmax=8.5, interpolation='nearest') + remove_ticks_and_titles(fig) + + +@image_comparison(["multivariate_figimage.png"]) +def test_multivariate_figimage(): + fig = plt.figure(figsize=(2, 2), dpi=100) + x, y = np.ix_(np.arange(100) / 100.0, np.arange(100) / 100) + z = np.sin(x**2 + y**2 - x*y) + c = np.sin(20*x**2 + 50*y**2) + img = np.stack((z, c)) + + fig.figimage(img, xo=0, yo=0, origin='lower', cmap='BiPeak') + fig.figimage(img[:, ::-1, :], xo=0, yo=100, origin='lower', cmap='BiPeak') + fig.figimage(img[:, :, ::-1], xo=100, yo=0, origin='lower', cmap='BiPeak') + fig.figimage(img[:, ::-1, ::-1], xo=100, yo=100, origin='lower', cmap='BiPeak') + + +@image_comparison(["multivar_cmap_call.png"]) +def test_multivar_cmap_call(): + """ + This evaluates manual calls to a bivariate colormap + The figure exists because implementing an image comparison + is easier than anumeraical comparisons for mulitdimensional arrays + """ + x_0 = np.arange(100, dtype='float32').reshape(10, 10) % 10 + x_1 = np.arange(100, dtype='float32').reshape(10, 10).T % 10 + + fig, axes = plt.subplots(1, 5, figsize=(10, 2)) + + cmap = mpl.multivar_colormaps['2VarAddA'] + + # call with 0D + axes[0].scatter(3, 6, c=[cmap((0.5, 0.5))]) + + # call with 1D + im = cmap((x_0[0]/9, x_1[::-1, 0]/9)) + axes[0].scatter(np.arange(10), np.arange(10), c=im) + + # call with 2D + cmap = mpl.multivar_colormaps['2VarSubA'] + im = cmap((x_0/9, x_1/9)) + axes[1].imshow(im, interpolation='nearest') + + # call with 3D array + im = cmap(((x_0/9, x_0/9), + (x_1/9, x_1/9))) + axes[2].imshow(im.reshape((20, 10, 4)), interpolation='nearest') + + # call with constant alpha, and data of type int + im = cmap((x_0.astype('int')*25, x_1.astype('int')*25), alpha=0.5) + axes[3].imshow(im, interpolation='nearest') + + # call with variable alpha + im = cmap((x_0/9, x_1/9), alpha=(x_0/9)**2, bytes=True) + axes[4].imshow(im, interpolation='nearest') + + # call with wrong shape + with pytest.raises(ValueError, match="must have a first dimension 2"): + cmap((0, 0, 0)) + + # call with wrong shape of alpha + with pytest.raises(ValueError, match=r"shape \(2,\) does not match " + r"that of X\[0\] \(\)"): + cmap((0, 0), alpha=(1, 1)) + remove_ticks_and_titles(fig) + + +def test_multivar_set_array(): + # compliments test_collections.test_collection_set_array() + vals = np.arange(24).reshape((2, 3, 4)) + c = mpl.collections.Collection(cmap="BiOrangeBlue") + + c.set_array(vals) + vals = np.empty((2, 3, 4), dtype='object') + with pytest.raises(TypeError, match="^Image data of dtype"): + # Test set_array with wrong dtype + c.set_array(vals) + + +def test_correct_error_on_multivar_colormap_in_streamplot(): + w = 3 + Y, X = np.mgrid[-w:w:100j, -w:w:100j] + U = -1 - X**2 + Y + V = 1 + X - Y**2 + speed = np.sqrt(U**2 + V**2) + + fig, ax = plt.subplots(1, 1) + # Varying color along a streamline + strm = ax.streamplot(X, Y, U, V, color=U, linewidth=2) + with pytest.raises(ValueError, match=r"'BiOrangeBlue' is not a "): + strm = ax.streamplot(X, Y, U, V, color=U, linewidth=2, cmap='BiOrangeBlue') + + +def test_multivar_cmap_call_tuple(): + cmap = mpl.multivar_colormaps['2VarAddA'] + assert_array_equal(cmap((0.0, 0.0)), (0, 0, 0, 1)) + assert_array_equal(cmap((1.0, 1.0)), (1, 1, 1, 1)) + assert_allclose(cmap((0.0, 0.0), alpha=0.1), (0, 0, 0, 0.1), atol=0.1) + + cmap = mpl.multivar_colormaps['2VarSubA'] + assert_array_equal(cmap((0.0, 0.0)), (1, 1, 1, 1)) + assert_allclose(cmap((1.0, 1.0)), (0, 0, 0, 1), atol=0.1) + + +@image_comparison(["bivariate_cmap_call.png"]) +def test_bivariate_cmap_call(): + """ + This evaluates manual calls to a bivariate colormap + The figure exists because implementing an image comparison + is easier than anumeraical comparisons for mulitdimensional arrays + """ + x_0 = np.arange(100, dtype='float32').reshape(10, 10) % 10 + x_1 = np.arange(100, dtype='float32').reshape(10, 10).T % 10 + + fig, axes = plt.subplots(1, 5, figsize=(10, 2)) + + cmap = mpl.bivar_colormaps['BiCone'] + + # call with 1D + im = cmap((x_0[0]/9, x_1[::-1, 0]/9)) + axes[0].scatter(np.arange(10), np.arange(10), c=im) + + # call with 2D + im = cmap((x_0/9, x_1/9)) + axes[1].imshow(im, interpolation='nearest') + + # call with 3D array + im = cmap(((x_0/9, x_0/9), + (x_1/9, x_1/9))) + axes[2].imshow(im.reshape((20, 10, 4)), interpolation='nearest') + + cmap = mpl.bivar_colormaps['BiOrangeBlue'] + # call with constant alpha, and data of type int + im = cmap((x_0.astype('int')*25, x_1.astype('int')*25), alpha=0.5) + axes[3].imshow(im, interpolation='nearest') + + # call with variable alpha + im = cmap((x_0/9, x_1/9), alpha=(x_0/9)**2, bytes=True) + axes[4].imshow(im, interpolation='nearest') + + # call with wrong shape + with pytest.raises(ValueError, match="must have a first dimension 2"): + cmap((0, 0, 0)) + + # call with wrong shape of alpha + with pytest.raises(ValueError, match=r"shape \(2,\) does not match " + r"that of X\[0\] \(\)"): + cmap((0, 0), alpha=(1, 1)) + + remove_ticks_and_titles(fig) + + +def test_bivar_cmap_call_tuple(): + cmap = mpl.bivar_colormaps['BiOrangeBlue'] + assert_allclose(cmap((1.0, 1.0)), (1, 1, 1, 1), atol=0.01) + assert_allclose(cmap((0.0, 0.0)), (0, 0, 0, 1), atol=0.1) + assert_allclose(cmap((0.0, 0.0), alpha=0.1), (0, 0, 0, 0.1), atol=0.1) + + +@image_comparison(["bivar_cmap_from_image.png"]) +def test_bivar_cmap_from_image(): + """ + This tests the creation and use of a bivariate colormap + generated from an image + """ + # create bivariate colormap + im = np.ones((10, 12, 4)) + im[:, :, 0] = np.arange(10)[:, np.newaxis]/10 + im[:, :, 1] = np.arange(12)[np.newaxis, :]/12 + fig, axes = plt.subplots(1, 2) + axes[0].imshow(im, interpolation='nearest') + + # use bivariate colormap + data_0 = np.arange(12).reshape((3, 4)) + data_1 = np.arange(12).reshape((4, 3)).T + cmap = mpl.colors.BivarColormapFromImage(im) + axes[1].imshow((data_0, data_1), cmap=cmap, + interpolation='nearest') + + remove_ticks_and_titles(fig) + + +def test_wrong_multivar_clim_shape(): + fig, ax = plt.subplots() + im = np.arange(24).reshape((2, 3, 4)) + with pytest.raises(ValueError, match="Invalid vmin for `MultiNorm` with"): + ax.imshow(im, cmap='BiPeak', vmin=(None, None, None)) + with pytest.raises(ValueError, match="Invalid vmax for `MultiNorm` with"): + ax.imshow(im, cmap='BiPeak', vmax=(None, None, None)) + + +def test_wrong_multivar_norm_shape(): + fig, ax = plt.subplots() + im = np.arange(24).reshape((2, 3, 4)) + with pytest.raises(ValueError, match="for multivariate colormap with 2"): + ax.imshow(im, cmap='BiPeak', norm=(None, None, None)) + + +def test_wrong_multivar_data_shape(): + fig, ax = plt.subplots() + im = np.arange(12).reshape((1, 3, 4)) + with pytest.raises(ValueError, match="The data must contain complex numbers, or"): + ax.imshow(im, cmap='BiPeak') + im = np.arange(12).reshape((3, 4)) + with pytest.raises(ValueError, match="The data must contain complex numbers, or"): + ax.imshow(im, cmap='BiPeak') + + +def test_missing_multivar_cmap_imshow(): + fig, ax = plt.subplots() + im = np.arange(200).reshape((2, 10, 10)) + with pytest.raises(TypeError, + match=("a valid colormap must be explicitly declared" + + ", for example cmap='BiOrangeBlue'")): + ax.imshow(im) + im = np.arange(300).reshape((3, 10, 10)) + with pytest.raises(TypeError, + match=("multivariate colormap must be explicitly declared" + + ", for example cmap='3VarAddA")): + ax.imshow(im) + im = np.arange(1000).reshape((10, 10, 10)) + with pytest.raises(TypeError, + match=re.escape("Invalid shape (10, 10, 10) for image data")): + ax.imshow(im) + + +def test_setting_A_on_ColorizingArtist(): + # correct use + vm = mpl.colorizer.ColorizingArtist(mpl.colorizer.Colorizer(cmap='3VarAddA')) + data = np.arange(3*25).reshape((3, 5, 5)) + vm.set_array(data) + + # attempting to set wrong shape of data + with pytest.raises(ValueError, match=re.escape( + " must have a first dimension 3 or be of" + )): + data = np.arange(2*25).reshape((2, 5, 5)) + vm.set_array(data) + + +def test_setting_cmap_on_Colorizer(): + # correct use + colorizer = mpl.colorizer.Colorizer(cmap='BiOrangeBlue') + colorizer.cmap = 'BiPeak' + # incorrect use + with pytest.raises(ValueError, match=re.escape( + "The colormap viridis does not support 2" + )): + colorizer.cmap = 'viridis' + # incorrect use + colorizer = mpl.colorizer.Colorizer() + with pytest.raises(ValueError, match=re.escape( + "The colormap BiOrangeBlue does not support 1" + )): + colorizer.cmap = 'BiOrangeBlue' + + +def test_setting_norm_on_Colorizer(): + # correct use + vm = mpl.colorizer.Colorizer(cmap='3VarAddA') + vm.norm = 'linear' + vm.norm = ['linear', 'log', 'asinh'] + vm.norm = (None, None, None) + + # attempting to set wrong shape of norm + with pytest.raises(ValueError, match=re.escape( + "Invalid norm for multivariate colormap with 3 inputs." + )): + vm.norm = (None, None) + + +def test_setting_clim_on_ScalarMappable(): + # correct use + vm = mpl.colorizer.Colorizer(cmap='3VarAddA') + vm.set_clim(0, 1) + vm.set_clim([0, 0, 0], [1, 2, 3]) + # attempting to set wrong shape of vmin/vmax + with pytest.raises(ValueError, match=re.escape( + "Invalid vmin for `MultiNorm` with 3 inputs")): + vm.set_clim(vmin=[0, 0]) + + +def test_bivar_eq(): + """ + Tests equality between bivariate colormaps + """ + cmap_0 = mpl.bivar_colormaps['BiOrangeBlue'] + + cmap_1 = mpl.bivar_colormaps['BiOrangeBlue'] + assert (cmap_0 == cmap_1) is True + + cmap_1 = mpl.multivar_colormaps['2VarAddA'] + assert (cmap_0 == cmap_1) is False + + cmap_1 = mpl.bivar_colormaps['BiPeak'] + assert (cmap_0 == cmap_1) is False + + cmap_1 = mpl.bivar_colormaps['BiOrangeBlue'] + cmap_1 = cmap_1.with_extremes(bad='k') + assert (cmap_0 == cmap_1) is False + + cmap_1 = mpl.bivar_colormaps['BiOrangeBlue'] + cmap_1 = cmap_1.with_extremes(outside='k') + assert (cmap_0 == cmap_1) is False + + cmap_1 = mpl.bivar_colormaps['BiOrangeBlue'] + cmap_1 = cmap_1.with_extremes(shape='circle') + assert (cmap_0 == cmap_1) is False + + +def test_cmap_error(): + data = np.ones((2, 4, 4)) + for call in (plt.imshow, plt.pcolor, plt.pcolormesh): + with pytest.raises(ValueError, match=re.escape( + "See matplotlib.bivar_colormaps() and matplotlib.multivar_colormaps()" + " for bivariate and multivariate colormaps." + )): + call(data, cmap='not_a_cmap') + + with pytest.raises(ValueError, match=re.escape( + "See matplotlib.bivar_colormaps() and matplotlib.multivar_colormaps()" + " for bivariate and multivariate colormaps." + )): + mpl.collections.PatchCollection([], cmap='not_a_cmap') + + +def test_artist_format_cursor_data_multivar(): + + X = np.zeros((4, 3)) + X[0, 0] = 0.9 + X[0, 1] = 0.99 + X[0, 2] = 0.999 + X[1, 0] = -1 + X[1, 1] = 0 + X[1, 2] = 1 + X[2, 0] = 0.09 + X[2, 1] = 0.009 + X[2, 2] = 0.0009 + X[3, 0] = np.nan + + Y = np.arange(np.prod(X.shape)).reshape(X.shape) + + labels_list = [ + "[0.9, 0.00]", + "[1., 1.00]", + "[1., 2.00]", + "[-1.0, 3.00]", + "[0.0, 4.00]", + "[1.0, 5.00]", + "[0.09, 6.00]", + "[0.009, 7.00]", + "[0.0009, 8.00]", + "[]", + ] + + pos = [[0, 0], [1, 0], [2, 0], + [0, 1], [1, 1], [2, 1], + [0, 2], [1, 2], [2, 2], + [3, 0]] + + from matplotlib.backend_bases import MouseEvent + + for cmap in ['BiOrangeBlue', '2VarAddA']: + fig, ax = plt.subplots() + norm = mpl.colors.BoundaryNorm(np.linspace(-1, 1, 20), 256) + data = (X, Y) + im = ax.imshow(data, cmap=cmap, norm=(norm, None)) + + for v, text in zip(pos, labels_list): + xdisp, ydisp = ax.transData.transform(v) + event = MouseEvent('motion_notify_event', fig.canvas, xdisp, ydisp) + assert im.format_cursor_data(im.get_cursor_data(event)) == text + + +def test_multivariate_safe_masked_invalid(): + dt = np.dtype('float32, float32').newbyteorder('>') + x = np.zeros(2, dtype=dt) + x['f0'][0] = np.nan + xm = cbook.safe_masked_invalid(x) + assert (xm['f0'].mask == (True, False)).all() + assert (xm['f1'].mask == (False, False)).all() + assert '