diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index d90921158ee5..683f62763cb8 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -690,7 +690,17 @@ def safe_masked_invalid(x, copy=False): try: xm = np.ma.masked_where(~(np.isfinite(x)), x, copy=False) except TypeError: - return x + if len(x.dtype.descr) == 1: + return x + else: + # in case of a dtype with multiple fields: + try: + mask = np.empty(x.shape, dtype=np.dtype('bool, '*len(x.dtype.descr))) + for dd, dm in zip(x.dtype.descr, mask.dtype.descr): + mask[dm[0]] = ~(np.isfinite(x[dd[0]])) + xm = np.ma.array(x, mask=mask, copy=False) + except TypeError: + return x return xm diff --git a/lib/matplotlib/cm.py b/lib/matplotlib/cm.py index 2697666b9573..d8c3dafcfe46 100644 --- a/lib/matplotlib/cm.py +++ b/lib/matplotlib/cm.py @@ -278,32 +278,3 @@ def get_cmap(name=None, lut=None): return _colormaps[name] else: return _colormaps[name].resampled(lut) - - -def _ensure_cmap(cmap): - """ - Ensure that we have a `.Colormap` object. - - For internal use to preserve type stability of errors. - - Parameters - ---------- - cmap : None, str, Colormap - - - if a `Colormap`, return it - - if a string, look it up in mpl.colormaps - - if None, look up the default color map in mpl.colormaps - - Returns - ------- - Colormap - - """ - if isinstance(cmap, colors.Colormap): - return cmap - cmap_name = mpl._val_or_rc(cmap, "image.cmap") - # use check_in_list to ensure type stability of the exception raised by - # the internal usage of this (ValueError vs KeyError) - if cmap_name not in _colormaps: - _api.check_in_list(sorted(_colormaps), cmap=cmap_name) - return mpl.colormaps[cmap_name] diff --git a/lib/matplotlib/colorizer.py b/lib/matplotlib/colorizer.py index b4223f389804..560ed647450e 100644 --- a/lib/matplotlib/colorizer.py +++ b/lib/matplotlib/colorizer.py @@ -24,7 +24,7 @@ import numpy as np from numpy import ma -from matplotlib import _api, colors, cbook, scale, artist +from matplotlib import _api, colors, cbook, artist, cm import matplotlib as mpl mpl._docstring.interpd.register( @@ -78,7 +78,7 @@ def _scale_norm(self, norm, vmin, vmax, A): raise ValueError( "Passing a Normalize instance simultaneously with " "vmin/vmax is not supported. Please pass vmin/vmax " - "directly to the norm when creating it.") + "directly to the norm when creating it") # always resolve the autoscaling so we have concrete limits # rather than deferring to draw time. @@ -90,19 +90,7 @@ def norm(self): @norm.setter def norm(self, norm): - _api.check_isinstance((colors.Normalize, str, None), norm=norm) - if norm is None: - norm = colors.Normalize() - elif isinstance(norm, str): - try: - scale_cls = scale._scale_mapping[norm] - except KeyError: - raise ValueError( - "Invalid norm str name; the following values are " - f"supported: {', '.join(scale._scale_mapping)}" - ) from None - norm = _auto_norm_from_scale(scale_cls)() - + norm = _ensure_norm(norm, n_variates=self.cmap.n_variates) if norm is self.norm: # We aren't updating anything return @@ -186,7 +174,7 @@ def _pass_image_data(x, alpha=None, bytes=False, norm=True): if norm and (xx.max() > 1 or xx.min() < 0): raise ValueError("Floating point image RGB values " - "must be in the 0..1 range.") + "must be in the 0..1 range") if bytes: xx = (xx * 255).astype(np.uint8) elif xx.dtype == np.uint8: @@ -232,9 +220,14 @@ def _set_cmap(self, cmap): cmap : `.Colormap` or str or None """ # bury import to avoid circular imports - from matplotlib import cm in_init = self._cmap is None - self._cmap = cm._ensure_cmap(cmap) + cmap_obj = _ensure_cmap(cmap, accept_multivariate=True) + if not in_init: + if self.norm.n_variables != cmap_obj.n_variates: + raise ValueError(f"The colormap {cmap} does not support " + f"{self.norm.n_variables} variates as required by " + f"the {type(self.norm)} on this Colorizer") + self._cmap = cmap_obj if not in_init: self.changed() # Things are not set up properly yet. @@ -255,31 +248,33 @@ def set_clim(self, vmin=None, vmax=None): vmin, vmax : float The limits. - The limits may also be passed as a tuple (*vmin*, *vmax*) as a - single positional argument. + For scalar data, the limits may also be passed as a + tuple (*vmin*, *vmax*) single positional argument. .. ACCEPTS: (vmin: float, vmax: float) """ + if self.norm.n_variables == 1: + if vmax is None: + try: + vmin, vmax = vmin + except (TypeError, ValueError): + pass + # If the norm's limits are updated self.changed() will be called # through the callbacks attached to the norm, this causes an inconsistent # state, to prevent this blocked context manager is used - if vmax is None: - try: - vmin, vmax = vmin - except (TypeError, ValueError): - pass - orig_vmin_vmax = self.norm.vmin, self.norm.vmax # Blocked context manager prevents callbacks from being triggered # until both vmin and vmax are updated with self.norm.callbacks.blocked(signal='changed'): + # Since the @vmin/vmax.setter invokes colors._sanitize_extrema() + # to sanitize the input, the input is not sanitized here if vmin is not None: - self.norm.vmin = colors._sanitize_extrema(vmin) + self.norm.vmin = vmin if vmax is not None: - self.norm.vmax = colors._sanitize_extrema(vmax) + self.norm.vmax = vmax - # emit a update signal if the limits are changed if orig_vmin_vmax != (self.norm.vmin, self.norm.vmax): self.norm.callbacks.process('changed') @@ -476,31 +471,53 @@ def _format_cursor_data_override(self, data): # Note if cm.ScalarMappable is depreciated, this functionality should be # implemented as format_cursor_data() on ColorizingArtist. - n = self.cmap.N - if np.ma.getmask(data): + if np.ma.getmask(data) or data is None: return "[]" - normed = self.norm(data) + if len(data.dtype.descr) > 1: + # We have multivariate data encoded as a data type with multiple fields + # NOTE: If any of the fields are masked, "[]" would be returned via + # the if statement above. + s_sig_digits_list = [] + if isinstance(self.cmap, colors.BivarColormap): + n_s = (self.cmap.N, self.cmap.M) + else: + n_s = [part.N for part in self.cmap] + os = [f"{d:-#.{self._sig_digits_from_norm(no, d, n)}g}" + for no, d, n in zip(self.norm.norms, data, n_s)] + return f"[{', '.join(os)}]" + + # scalar data + n = self.cmap.N + g_sig_digits = self._sig_digits_from_norm(self.norm, data, n) + return f"[{data:-#.{g_sig_digits}g}]" + + @staticmethod + def _sig_digits_from_norm(norm, data, n): + # Determines the number of significant digits + # to use for a number given a norm, and n, where n is the + # number of colors in the colormap. + normed = norm(data) if np.isfinite(normed): - if isinstance(self.norm, colors.BoundaryNorm): + if isinstance(norm, colors.BoundaryNorm): # not an invertible normalization mapping - cur_idx = np.argmin(np.abs(self.norm.boundaries - data)) + cur_idx = np.argmin(np.abs(norm.boundaries - data)) neigh_idx = max(0, cur_idx - 1) # use max diff to prevent delta == 0 delta = np.diff( - self.norm.boundaries[neigh_idx:cur_idx + 2] + norm.boundaries[neigh_idx:cur_idx + 2] ).max() - elif self.norm.vmin == self.norm.vmax: + elif norm.vmin == norm.vmax: # singular norms, use delta of 10% of only value - delta = np.abs(self.norm.vmin * .1) + delta = np.abs(norm.vmin * .1) else: # Midpoints of neighboring color intervals. - neighbors = self.norm.inverse( + neighbors = norm.inverse( (int(normed * n) + np.array([0, 1])) / n) delta = abs(neighbors - data).max() g_sig_digits = cbook._g_sig_digits(data, delta) else: g_sig_digits = 3 # Consistent with default below. - return f"[{data:-#.{g_sig_digits}g}]" + return g_sig_digits class _ScalarMappable(_ColorizerInterface): @@ -563,10 +580,18 @@ def set_array(self, A): self._A = None return + A = _ensure_multivariate_data(A, self.norm.n_variables) + A = cbook.safe_masked_invalid(A, copy=True) if not np.can_cast(A.dtype, float, "same_kind"): - raise TypeError(f"Image data of dtype {A.dtype} cannot be " - "converted to float") + if A.dtype.fields is None: + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to float") + else: + for key in A.dtype.fields: + if not np.can_cast(A[key].dtype, float, "same_kind"): + raise TypeError(f"Image data of dtype {A.dtype} cannot be " + f"converted to a sequence of floats") self._A = A if not self.norm.scaled(): @@ -615,6 +640,15 @@ def _get_colorizer(cmap, norm, colorizer): cmap : str or `~matplotlib.colors.Colormap`, default: :rc:`image.cmap` The Colormap instance or registered colormap name used to map scalar data to colors.""", + multi_cmap_doc="""\ +cmap : str, `~matplotlib.colors.Colormap`, `~matplotlib.colors.BivarColormap`\ + or `~matplotlib.colors.MultivarColormap`, default: :rc:`image.cmap` + The Colormap instance or registered colormap name used to map + data values to colors. + + Multivariate data is only accepted if a multivariate colormap + (`~matplotlib.colors.BivarColormap` or `~matplotlib.colors.MultivarColormap`) + is used.""", norm_doc="""\ norm : str or `~matplotlib.colors.Normalize`, optional The normalization method used to scale scalar data to the [0, 1] range @@ -629,6 +663,23 @@ def _get_colorizer(cmap, norm, colorizer): list of available scales, call `matplotlib.scale.get_scale_names()`. In that case, a suitable `.Normalize` subclass is dynamically generated and instantiated.""", + multi_norm_doc="""\ +norm : str, `~matplotlib.colors.Normalize` or list, optional + The normalization method used to scale data to the [0, 1] range + before mapping to colors using *cmap*. By default, a linear scaling is + used, mapping the lowest value to 0 and the highest to 1. + + This can be one of the following: + + - An instance of `.Normalize` or one of its subclasses + (see :ref:`colormapnorms`). + - A scale name, i.e. one of "linear", "log", "symlog", "logit", etc. For a + list of available scales, call `matplotlib.scale.get_scale_names()`. + In this case, a suitable `.Normalize` subclass is dynamically generated + and instantiated. + - A list of scale names or `.Normalize` objects matching the number of + variates in the colormap, for use with `~matplotlib.colors.BivarColormap` + or `~matplotlib.colors.MultivarColormap`, i.e. ``["linear", "log"]``.""", vmin_vmax_doc="""\ vmin, vmax : float, optional When using scalar data and no explicit *norm*, *vmin* and *vmax* define @@ -636,6 +687,17 @@ def _get_colorizer(cmap, norm, colorizer): the complete value range of the supplied data. It is an error to use *vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm* name together with *vmin*/*vmax* is acceptable).""", + multi_vmin_vmax_doc="""\ +vmin, vmax : float or list, optional + When using scalar data and no explicit *norm*, *vmin* and *vmax* define + the data range that the colormap covers. By default, the colormap covers + the complete value range of the supplied data. It is an error to use + *vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm* + name together with *vmin*/*vmax* is acceptable). + + A list of values (vmin or vmax) can be used to define independent limits + for each variate when using a `~matplotlib.colors.BivarColormap` or + `~matplotlib.colors.MultivarColormap`.""", ) @@ -701,3 +763,167 @@ def _auto_norm_from_scale(scale_cls): norm = colors.make_norm_from_scale(scale_cls)( colors.Normalize)() return type(norm) + + +def _ensure_norm(norm, n_variates=1): + if n_variates == 1: + _api.check_isinstance((colors.Normalize, str, None), norm=norm) + if norm is None: + norm = colors.Normalize() + elif isinstance(norm, str): + scale_cls = colors._get_scale_cls_from_str(norm) + norm = _auto_norm_from_scale(scale_cls)() + return norm + else: # n_variates > 1 + if not np.iterable(norm): + # include tuple in the list to improve error message + _api.check_isinstance((colors.Normalize, str, None, tuple), norm=norm) + if norm is None: + norm = colors.MultiNorm([None]*n_variates) + elif isinstance(norm, str): # single string + norm = colors.MultiNorm([norm]*n_variates) + else: # multiple string or objects + norm = colors.MultiNorm(norm) + if isinstance(norm, colors.Normalize) and norm.n_variables == n_variates: + return norm + raise ValueError( + "Invalid norm for multivariate colormap with " + f"{n_variates} inputs" + ) + + +def _ensure_cmap(cmap, accept_multivariate=False): + """ + Ensure that we have a `.Colormap` object. + + For internal use to preserve type stability of errors. + + Parameters + ---------- + cmap : None, str, Colormap + + - if a `~matplotlib.colors.Colormap`, + `~matplotlib.colors.MultivarColormap` or + `~matplotlib.colors.BivarColormap`, + return it + - if a string, look it up in three corresponding databases + when not found: raise an error based on the expected shape + - if None, look up the default color map in mpl.colormaps + accept_multivariate : bool, default True + - if False, accept only Colormap, string in mpl.colormaps or None + + Returns + ------- + Colormap + + """ + if not accept_multivariate: + if isinstance(cmap, colors.Colormap): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + # use check_in_list to ensure type stability of the exception raised by + # the internal usage of this (ValueError vs KeyError) + if cmap_name not in mpl.colormaps: + _api.check_in_list(sorted(mpl.colormaps), cmap=cmap_name) + + if isinstance(cmap, (colors.Colormap, + colors.BivarColormap, + colors.MultivarColormap)): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + if cmap_name in mpl.colormaps: + return mpl.colormaps[cmap_name] + if cmap_name in mpl.multivar_colormaps: + return mpl.multivar_colormaps[cmap_name] + if cmap_name in mpl.bivar_colormaps: + return mpl.bivar_colormaps[cmap_name] + + # this error message is a variant of _api.check_in_list but gives + # additional hints as to how to access multivariate colormaps + raise ValueError(f"{cmap!r} is not a valid value for cmap" + "; supported values for scalar colormaps are " + f"{', '.join(map(repr, sorted(mpl.colormaps)))}\n" + "See `matplotlib.bivar_colormaps()` and" + " `matplotlib.multivar_colormaps()` for" + " bivariate and multivariate colormaps") + + if isinstance(cmap, colors.Colormap): + return cmap + cmap_name = cmap if cmap is not None else mpl.rcParams["image.cmap"] + # use check_in_list to ensure type stability of the exception raised by + # the internal usage of this (ValueError vs KeyError) + if cmap_name not in cm.colormaps: + _api.check_in_list(sorted(cm.colormaps), cmap=cmap_name) + return cm.colormaps[cmap_name] + + +def _ensure_multivariate_data(data, n_variables): + """ + Ensure that the data has dtype with n_variables. + Input data of shape (n_variables, n, m) is converted to an array of shape + (n, m) with data type np.dtype(f'{data.dtype}, ' * n_variables) + Complex data is returned as a view with dtype np.dtype('float64, float64') + or np.dtype('float32, float32') + If n_variables is 1 and data is not of type np.ndarray (i.e. PIL.Image), + the data is returned unchanged. + If data is None, the function returns None + + Parameters + ---------- + n_variables : int + - number of variates in the data + data : np.ndarray, PIL.Image or None + + Returns + ------- + np.ndarray, PIL.Image or None + """ + + if isinstance(data, np.ndarray): + if len(data.dtype.descr) == n_variables: + # pass scalar data + # and already formatted data + return data + elif data.dtype in [np.complex64, np.complex128]: + # pass complex data + if data.dtype == np.complex128: + dt = np.dtype('float64, float64') + else: + dt = np.dtype('float32, float32') + reconstructed = np.ma.frombuffer(data.data, dtype=dt).reshape(data.shape) + if np.ma.is_masked(data): + for descriptor in dt.descr: + reconstructed[descriptor[0]][data.mask] = np.ma.masked + return reconstructed + + if n_variables > 1 and len(data) == n_variables: + # convert data from shape (n_variables, n, m) + # to (n,m) with a new dtype + data = [np.ma.array(part, copy=False) for part in data] + dt = np.dtype(', '.join([f'{part.dtype}' for part in data])) + fields = [descriptor[0] for descriptor in dt.descr] + reconstructed = np.ma.empty(data[0].shape, dtype=dt) + for i, f in enumerate(fields): + if data[i].shape != reconstructed.shape: + raise ValueError("For multivariate data all variates must have same " + f"shape, not {data[0].shape} and {data[i].shape}") + reconstructed[f] = data[i] + if np.ma.is_masked(data[i]): + reconstructed[f][data[i].mask] = np.ma.masked + return reconstructed + + if data is None: + return data + + if n_variables == 1: + # PIL.Image also gets passed here + return data + + elif n_variables == 2: + raise ValueError("Invalid data entry for multivariate data. The data" + " must contain complex numbers, or have a first dimension 2," + " or be of a dtype with 2 fields") + else: + raise ValueError("Invalid data entry for multivariate data. The shape" + f" of the data must have a first dimension {n_variables}" + f" or be of a dtype with {n_variables} fields") diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index e3c3b39e8bb2..d5ca9c959e29 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -1420,10 +1420,10 @@ def __init__(self, colormaps, combination_mode, name='multivariate colormap'): combination_mode: str, 'sRGB_add' or 'sRGB_sub' Describe how colormaps are combined in sRGB space - - If 'sRGB_add' -> Mixing produces brighter colors - `sRGB = sum(colors)` - - If 'sRGB_sub' -> Mixing produces darker colors - `sRGB = 1 - sum(1 - colors)` + - If 'sRGB_add': Mixing produces brighter colors + ``sRGB = sum(colors)`` + - If 'sRGB_sub': Mixing produces darker colors + ``sRGB = 1 - sum(1 - colors)`` name : str, optional The name of the colormap family. """ @@ -1595,15 +1595,15 @@ def with_extremes(self, *, bad=None, under=None, over=None): Parameters ---------- - bad: :mpltype:`color`, default: None + bad : :mpltype:`color`, default: None If Matplotlib color, the bad value is set accordingly in the copy - under tuple of :mpltype:`color`, default: None - If tuple, the `under` value of each component is set with the values + under : tuple of :mpltype:`color`, default: None + If tuple, the ``under`` value of each component is set with the values from the tuple. - over tuple of :mpltype:`color`, default: None - If tuple, the `over` value of each component is set with the values + over : tuple of :mpltype:`color`, default: None + If tuple, the ``over`` value of each component is set with the values from the tuple. Returns @@ -2320,6 +2320,11 @@ def __init__(self, vmin=None, vmax=None, clip=False): self._scale = None self.callbacks = cbook.CallbackRegistry(signals=["changed"]) + @property + def n_variables(self): + # To be overridden by subclasses with multiple inputs + return 1 + @property def vmin(self): return self._vmin @@ -3219,6 +3224,224 @@ def inverse(self, value): return value +class MultiNorm(Normalize): + """ + A mixin class which contains multiple scalar norms + """ + + def __init__(self, norms, vmin=None, vmax=None, clip=False): + """ + Parameters + ---------- + norms : List of strings or `Normalize` objects + The constituent norms. The list must have a minimum length of 2. + vmin, vmax : float, None, or list of float or None + Limits of the constituent norms. + If a list, each value is assigned to each of the constituent + norms. Single values are repeated to form a list of appropriate size. + + clip : bool or list of bools, default: False + Determines the behavior for mapping values outside the range + ``[vmin, vmax]`` for the constituent norms. + If a list, each value is assigned to each of the constituent + norms. Single values are repeated to form a list of appropriate size. + + """ + + if isinstance(norms, str) or not np.iterable(norms): + raise ValueError("A MultiNorm must be assigned multiple norms") + + norms = [*norms] + for i, n in enumerate(norms): + if n is None: + norms[i] = Normalize() + elif isinstance(n, str): + scale_cls = _get_scale_cls_from_str(n) + norms[i] = mpl.colorizer._auto_norm_from_scale(scale_cls)() + elif not isinstance(n, Normalize): + raise ValueError( + "MultiNorm must be assigned multiple norms, where each norm " + f"is of type `None` `str`, or `Normalize`, not {type(n)}") + + # Convert the list of norms to a tuple to make it immutable. + # If there is a use case for swapping a single norm, we can add support for + # that later + self._norms = tuple(norms) + + self.callbacks = cbook.CallbackRegistry(signals=["changed"]) + + self.vmin = vmin + self.vmax = vmax + self.clip = clip + + for n in self._norms: + n.callbacks.connect('changed', self._changed) + + @property + def n_variables(self): + return len(self._norms) + + @property + def norms(self): + return self._norms + + @property + def vmin(self): + return tuple(n.vmin for n in self._norms) + + @vmin.setter + def vmin(self, value): + value = np.broadcast_to(value, self.n_variables) + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].vmin = v + self._changed() + + @property + def vmax(self): + return tuple(n.vmax for n in self._norms) + + @vmax.setter + def vmax(self, value): + value = np.broadcast_to(value, self.n_variables) + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].vmax = v + self._changed() + + @property + def clip(self): + return tuple(n.clip for n in self._norms) + + @clip.setter + def clip(self, value): + value = np.broadcast_to(value, self.n_variables) + with self.callbacks.blocked(): + for i, v in enumerate(value): + if v is not None: + self.norms[i].clip = v + self._changed() + + def _changed(self): + """ + Call this whenever the norm is changed to notify all the + callback listeners to the 'changed' signal. + """ + self.callbacks.process('changed') + + def __call__(self, value, clip=None): + """ + Normalize the data and return the normalized data. + + Each variate in the input is assigned to the constituent norm. + + Parameters + ---------- + value + Data to normalize. Must be of length `n_variables` or have a data type with + `n_variables` fields. + clip : list of bools or bool, optional + See the description of the parameter *clip* in Normalize. + If ``None``, defaults to ``self.clip`` (which defaults to + ``False``). + + Returns + ------- + Data + Normalized input values as a list of length `n_variables` + + Notes + ----- + If not already initialized, ``self.vmin`` and ``self.vmax`` are + initialized using ``self.autoscale_None(value)``. + """ + if clip is None: + clip = self.clip + elif not np.iterable(clip): + clip = [clip]*self.n_variables + + value = self._iterable_variates_in_data(value, self.n_variables) + result = [n(v, clip=c) for n, v, c in zip(self.norms, value, clip)] + return result + + def inverse(self, value): + """ + Map the normalized value (i.e., index in the colormap) back to image data value. + + Parameters + ---------- + value + Normalized value. Must be of length `n_variables` or have a data type with + `n_variables` fields. + """ + value = self._iterable_variates_in_data(value, self.n_variables) + result = [n.inverse(v) for n, v in zip(self.norms, value)] + return result + + def autoscale(self, A): + """ + For each constituent norm, Set *vmin*, *vmax* to min, max of the corresponding + variate in *A*. + """ + with self.callbacks.blocked(): + # Pause callbacks while we are updating so we only get + # a single update signal at the end + A = self._iterable_variates_in_data(A, self.n_variables) + for n, a in zip(self.norms, A): + n.autoscale(a) + self._changed() + + def autoscale_None(self, A): + """ + If *vmin* or *vmax* are not set on any constituent norm, + use the min/max of the corresponding variate in *A* to set them. + + Parameters + ---------- + A + Data, must be of length `n_variables` or be an np.ndarray type with + `n_variables` fields. + """ + with self.callbacks.blocked(): + A = self._iterable_variates_in_data(A, self.n_variables) + for n, a in zip(self.norms, A): + n.autoscale_None(a) + self._changed() + + def scaled(self): + """Return whether both *vmin* and *vmax* are set on all constituent norms""" + return all([(n.vmin is not None and n.vmax is not None) for n in self.norms]) + + @staticmethod + def _iterable_variates_in_data(data, n_variables): + """ + Provides an iterable over the variates contained in the data. + + An input array with `n_variables` fields is returned as a list of length n + referencing slices of the original array. + + Parameters + ---------- + data : np.ndarray, tuple or list + The input array. It must either be an array with n_variables fields or have + a length (n_variables) + + Returns + ------- + list of np.ndarray + + """ + if isinstance(data, np.ndarray) and data.dtype.fields is not None: + data = [data[descriptor[0]] for descriptor in data.dtype.descr] + if len(data) != n_variables: + raise ValueError("The input to this `MultiNorm` must be of shape " + f"({n_variables}, ...), or have a data type with " + f"{n_variables} fields.") + return data + + def rgb_to_hsv(arr): """ Convert an array of float RGB values (in the range [0, 1]) to HSV values. @@ -3856,3 +4079,34 @@ def from_levels_and_colors(levels, colors, extend='neither'): norm = BoundaryNorm(levels, ncolors=n_data_colors) return cmap, norm + + +def _get_scale_cls_from_str(scale_as_str): + """ + Returns the scale class from a string. + + Used in the creation of norms from a string to ensure a reasonable error + in the case where an invalid string is used. This would normally use + `_api.check_getitem()`, which would produce the error + > 'not_a_norm' is not a valid value for norm; supported values are + > 'linear', 'log', 'symlog', 'asinh', 'logit', 'function', 'functionlog' + which is misleading because the norm keyword also accepts `Normalize` objects. + + Parameters + ---------- + scale_as_str : string + A string corresponding to a scale + + Returns + ------- + A subclass of ScaleBase. + + """ + try: + scale_cls = scale._scale_mapping[scale_as_str] + except KeyError: + raise ValueError( + "Invalid norm str name; the following values are " + f"supported: {', '.join(scale._scale_mapping)}" + ) from None + return scale_cls diff --git a/lib/matplotlib/colors.pyi b/lib/matplotlib/colors.pyi index 3e761c949068..c7233e7da6fb 100644 --- a/lib/matplotlib/colors.pyi +++ b/lib/matplotlib/colors.pyi @@ -263,6 +263,8 @@ class Normalize: @vmax.setter def vmax(self, value: float | None) -> None: ... @property + def n_variables(self) -> int: ... + @property def clip(self) -> bool: ... @clip.setter def clip(self, value: bool) -> None: ... @@ -387,6 +389,34 @@ class BoundaryNorm(Normalize): class NoNorm(Normalize): ... +class MultiNorm(Normalize): + # Here "type: ignore[override]" is used for functions with a return type + # that differs from the function in the base class. + # i.e. where `MultiNorm` returns a tuple and Normalize returns a `float` etc. + def __init__( + self, + norms: ArrayLike, + vmin: ArrayLike | float | None = ..., + vmax: ArrayLike | float | None = ..., + clip: ArrayLike | bool = ... + ) -> None: ... + @property + def norms(self) -> tuple[Normalize, ...]: ... + @property # type: ignore[override] + def vmin(self) -> tuple[float | None, ...]: ... + @vmin.setter + def vmin(self, value: ArrayLike | float | None) -> None: ... + @property # type: ignore[override] + def vmax(self) -> tuple[float | None, ...]: ... + @vmax.setter + def vmax(self, value: ArrayLike | float | None) -> None: ... + @property # type: ignore[override] + def clip(self) -> tuple[bool, ...]: ... + @clip.setter + def clip(self, value: ArrayLike | bool) -> None: ... + def __call__(self, value: ArrayLike, clip: ArrayLike | bool | None = ...) -> list: ... # type: ignore[override] + def inverse(self, value: ArrayLike) -> list: ... # type: ignore[override] + def rgb_to_hsv(arr: ArrayLike) -> np.ndarray: ... def hsv_to_rgb(hsv: ArrayLike) -> np.ndarray: ... diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py index ece8bebf8192..725fff7b23fd 100644 --- a/lib/matplotlib/streamplot.py +++ b/lib/matplotlib/streamplot.py @@ -6,7 +6,7 @@ import numpy as np import matplotlib as mpl -from matplotlib import _api, cm, patches +from matplotlib import _api, colorizer, patches import matplotlib.colors as mcolors import matplotlib.collections as mcollections import matplotlib.lines as mlines @@ -228,7 +228,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if use_multicolor_lines: if norm is None: norm = mcolors.Normalize(color.min(), color.max()) - cmap = cm._ensure_cmap(cmap) + cmap = colorizer._ensure_cmap(cmap) streamlines = [] arrows = [] diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 8d0f3467f045..39583574b04f 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -1828,3 +1828,48 @@ def test_LinearSegmentedColormap_from_list_value_color_tuple(): cmap([value for value, _ in value_color_tuples]), to_rgba_array([color for _, color in value_color_tuples]), ) + + +def test_multi_norm(): + # tests for mcolors.MultiNorm + + # test wrong input + with pytest.raises(ValueError, + match="A MultiNorm must be assigned multiple norms"): + mcolors.MultiNorm("bad_norm_name") + with pytest.raises(ValueError, + match="Invalid norm str name"): + mcolors.MultiNorm(["bad_norm_name"]) + with pytest.raises(ValueError, + match="MultiNorm must be assigned multiple norms, " + "where each norm is of type `None`"): + mcolors.MultiNorm([4]) + + # test get vmin, vmax + norm = mpl.colors.MultiNorm(['linear', 'log']) + norm.vmin = 1 + norm.vmax = 2 + assert norm.vmin[0] == 1 + assert norm.vmin[1] == 1 + assert norm.vmax[0] == 2 + assert norm.vmax[1] == 2 + + # test call with clip + assert_array_equal(norm([3, 3], clip=False), [2.0, 1.584962500721156]) + assert_array_equal(norm([3, 3], clip=True), [1.0, 1.0]) + assert_array_equal(norm([3, 3], clip=[True, False]), [1.0, 1.584962500721156]) + norm.clip = False + assert_array_equal(norm([3, 3]), [2.0, 1.584962500721156]) + norm.clip = True + assert_array_equal(norm([3, 3]), [1.0, 1.0]) + norm.clip = [True, False] + assert_array_equal(norm([3, 3]), [1.0, 1.584962500721156]) + norm.clip = True + + # test inverse + assert_array_almost_equal(norm.inverse([0.5, 0.5849625007211562]), [1.5, 1.5]) + + # test autoscale + norm.autoscale([[0, 1, 2, 3], [0.1, 1, 2, 3]]) + assert_array_equal(norm.vmin, [0, 0.1]) + assert_array_equal(norm.vmax, [3, 3]) diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index 0e9f3fb37fbd..1e1b9ed244c8 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -14,7 +14,7 @@ import matplotlib as mpl from matplotlib import ( - colors, image as mimage, patches, pyplot as plt, style, rcParams) + cbook, colors, image as mimage, patches, pyplot as plt, style, rcParams) from matplotlib.image import (AxesImage, BboxImage, FigureImage, NonUniformImage, PcolorImage) from matplotlib.testing.decorators import check_figures_equal, image_comparison @@ -1130,8 +1130,14 @@ def test_image_cursor_formatting(): data = np.ma.masked_array([0], mask=[False]) assert im.format_cursor_data(data) == '[0]' - data = np.nan - assert im.format_cursor_data(data) == '[nan]' + # This used to test + # > data = np.nan + # > assert im.format_cursor_data(data) == '[nan]' + # However, a value of nan will be masked by `cbook.safe_masked_invalid(data)` + # called by `image._ImageBase._normalize_image_array(data)` + # The test is therefore changed to: + data = cbook.safe_masked_invalid(np.array(np.nan)) + assert im.format_cursor_data(data) == '[]' @check_figures_equal(extensions=['png', 'pdf', 'svg'])