diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 679f368bae30..ea3f848f6055 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -752,16 +752,13 @@ def __init__(self, name, N=256, *, bad=None, under=None, over=None): #: `matplotlib.colorbar.Colorbar` constructor. self.colorbar_extend = False - def __call__(self, X, alpha=None, bytes=False): + def __call__(self, X, alpha=None, bytes=False, by_index='auto'): r""" Parameters ---------- X : float or int or array-like - The data value(s) to convert to RGBA. - For floats, *X* should be in the interval ``[0.0, 1.0]`` to - return the RGBA values ``X*100`` percent along the Colormap line. - For integers, *X* should be in the interval ``[0, Colormap.N)`` to - return RGBA values *indexed* from the Colormap with index ``X``. + The data value(s) to convert to RGBA. The interpretation (normalized + values or index values) depends on *by_index*. alpha : float or array-like or None Alpha must be a scalar between 0 and 1, a sequence of such floats with shape matching X, or None. @@ -769,27 +766,35 @@ def __call__(self, X, alpha=None, bytes=False): If False (default), the returned RGBA values will be floats in the interval ``[0, 1]`` otherwise they will be `numpy.uint8`\s in the interval ``[0, 255]``. + by_index: bool or 'auto', default: 'auto' + How the input *X* is interpreted: + + - If True, *X* is treated as an array of indices into the Colormap + lookup table, i.e. the range ``[0, Colormap.N)`` covers the colormap. + - If False, *X* is treated normalized values, i.e. the range + ``[0.0, 1.0]`` covers the colormap. + - If 'auto', the type of *X* is used to determine the interpretation. + float inputs are treated like ``by_index=False``, integer inputs + are treated like ``by_index=True``. Returns ------- Tuple of RGBA values if X is scalar, otherwise an array of RGBA values with a shape of ``X.shape + (4, )``. """ - rgba, mask = self._get_rgba_and_mask(X, alpha=alpha, bytes=bytes) + rgba, mask = self._get_rgba_and_mask(X, alpha=alpha, bytes=bytes, + by_index=by_index) if not np.iterable(X): rgba = tuple(rgba) return rgba - def _get_rgba_and_mask(self, X, alpha=None, bytes=False): + def _get_rgba_and_mask(self, X, alpha=None, bytes=False, by_index='auto'): r""" Parameters ---------- X : float or int or array-like - The data value(s) to convert to RGBA. - For floats, *X* should be in the interval ``[0.0, 1.0]`` to - return the RGBA values ``X*100`` percent along the Colormap line. - For integers, *X* should be in the interval ``[0, Colormap.N)`` to - return RGBA values *indexed* from the Colormap with index ``X``. + The data value(s) to convert to RGBA. The interpretation (normalized + values or index values) depends on *by_index*. alpha : float or array-like or None Alpha must be a scalar between 0 and 1, a sequence of such floats with shape matching X, or None. @@ -797,6 +802,16 @@ def _get_rgba_and_mask(self, X, alpha=None, bytes=False): If False (default), the returned RGBA values will be floats in the interval ``[0, 1]`` otherwise they will be `numpy.uint8`\s in the interval ``[0, 255]``. + by_index: bool or 'auto', default: 'auto' + How the input *X* is interpreted: + + - If True, *X* is treated as an array of indices into the Colormap + lookup table, i.e. the range ``[0, Colormap.N)`` covers the colormap. + - If False, *X* is treated normalized values, i.e. the range + ``[0.0, 1.0]`` covers the colormap. + - If 'auto', the type of *X* is used to determine the interpretation. + float inputs are treated like ``by_index=False``, integer inputs + are treated like ``by_index=True``. Returns ------- @@ -812,7 +827,7 @@ def _get_rgba_and_mask(self, X, alpha=None, bytes=False): if not xa.dtype.isnative: # Native byteorder is faster. xa = xa.byteswap().view(xa.dtype.newbyteorder()) - if xa.dtype.kind == "f": + if by_index is False or (by_index == 'auto' and xa.dtype.kind == "f"): xa *= self.N # xa == 1 (== N after multiplication) is not out of range. xa[xa == self.N] = self.N - 1 diff --git a/lib/matplotlib/colors.pyi b/lib/matplotlib/colors.pyi index 07bf01b8f995..1a4a5f61159e 100644 --- a/lib/matplotlib/colors.pyi +++ b/lib/matplotlib/colors.pyi @@ -80,7 +80,11 @@ class Colormap: ) -> None: ... @overload def __call__( - self, X: Sequence[float] | np.ndarray, alpha: ArrayLike | None = ..., bytes: bool = ... + self, + X: Sequence[float] | np.ndarray, + alpha: ArrayLike | None = ..., + bytes: bool = ..., + by_index: bool | Literal['auto'] = ..., ) -> np.ndarray: ... @overload def __call__( diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 42f364848b66..1777d64f7eca 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -203,6 +203,22 @@ def test_colormap_invalid(): assert_array_equal(cmap(np.nan), [0., 0., 0., 0.]) +def test_colormap_by_index(): + cmap = mpl.colormaps["plasma"] + N = cmap.N + assert_array_equal(cmap([0., 0.5, 1.]), cmap([0., 0.5, 1.], by_index=False)) + # auto-detection based on input type by default + assert_array_equal(cmap([0., 0.5, 1.]), cmap([0, N//2, N])) + # by_index=True forces floats as index interpretation + assert_array_equal(cmap([0., N/2, float(N)], by_index=True), cmap([0, N//2, N])) + + cmap = mpl.colormaps["plasma"].with_extremes(over='r', under='b', bad='g') + + assert cmap(1) == cmap(1/N) + assert cmap(1., by_index=True) == cmap(1/N) + assert cmap(1, by_index=False) == cmap(1.) + + def test_colormap_return_types(): """ Make sure that tuples are returned for scalar input and