diff --git a/doc/users/next_whats_new/rgba_pcolormesh.rst b/doc/users/next_whats_new/rgba_pcolormesh.rst new file mode 100644 index 000000000000..4088677867c0 --- /dev/null +++ b/doc/users/next_whats_new/rgba_pcolormesh.rst @@ -0,0 +1,16 @@ +``pcolormesh`` accepts RGB(A) colors +------------------------------------ + +The `~.Axes.pcolormesh` method can now handle explicit colors +specified with RGB(A) values. To specify colors, the array must be 3D +with a shape of ``(M, N, [3, 4])``. + +.. plot:: + :include-source: true + + import matplotlib.pyplot as plt + import numpy as np + + colors = np.linspace(0, 1, 90).reshape((5, 6, 3)) + plt.pcolormesh(colors) + plt.show() diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index cd0660470c2c..ecb6f5e2098b 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -5679,7 +5679,7 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs): if len(args) == 1: C = np.asanyarray(args[0]) - nrows, ncols = C.shape + nrows, ncols = C.shape[:2] if shading in ['gouraud', 'nearest']: X, Y = np.meshgrid(np.arange(ncols), np.arange(nrows)) else: @@ -5708,7 +5708,7 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs): X = X.data # strip mask as downstream doesn't like it... if isinstance(Y, np.ma.core.MaskedArray): Y = Y.data - nrows, ncols = C.shape + nrows, ncols = C.shape[:2] else: raise _api.nargs_error(funcname, takes="1 or 3", given=len(args)) @@ -6045,9 +6045,18 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, Parameters ---------- - C : 2D array-like - The color-mapped values. Color-mapping is controlled by *cmap*, - *norm*, *vmin*, and *vmax*. + C : array-like + The mesh data. Supported array shapes are: + + - (M, N) or M*N: a mesh with scalar data. The values are mapped to + colors using normalization and a colormap. See parameters *norm*, + *cmap*, *vmin*, *vmax*. + - (M, N, 3): an image with RGB values (0-1 float or 0-255 int). + - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), + i.e. including transparency. + + The first two dimensions (M, N) define the rows and columns of + the mesh data. X, Y : array-like, optional The coordinates of the corners of quadrilaterals of a pcolormesh:: @@ -6207,8 +6216,9 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None, X, Y, C, shading = self._pcolorargs('pcolormesh', *args, shading=shading, kwargs=kwargs) coords = np.stack([X, Y], axis=-1) - # convert to one dimensional array - C = C.ravel() + # convert to one dimensional array, except for 3D RGB(A) arrays + if C.ndim != 3: + C = C.ravel() kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap']) @@ -6384,14 +6394,10 @@ def pcolorfast(self, *args, alpha=None, norm=None, cmap=None, vmin=None, if style == "quadmesh": # data point in each cell is value at lower left corner coords = np.stack([x, y], axis=-1) - if np.ndim(C) == 2: - qm_kwargs = {"array": np.ma.ravel(C)} - elif np.ndim(C) == 3: - qm_kwargs = {"color": np.ma.reshape(C, (-1, C.shape[-1]))} - else: + if np.ndim(C) not in {2, 3}: raise ValueError("C must be 2D or 3D") collection = mcoll.QuadMesh( - coords, **qm_kwargs, + coords, array=C, alpha=alpha, cmap=cmap, norm=norm, antialiased=False, edgecolors="none") self.add_collection(collection, autolim=False) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index e57ac69bed51..3a8e3bf6266b 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -1955,7 +1955,16 @@ def set_array(self, A): Parameters ---------- - A : (M, N) array-like or M*N array-like + A : array-like + The mesh data. Supported array shapes are: + + - (M, N) or M*N: a mesh with scalar data. The values are mapped to + colors using normalization and a colormap. See parameters *norm*, + *cmap*, *vmin*, *vmax*. + - (M, N, 3): an image with RGB values (0-1 float or 0-255 int). + - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), + i.e. including transparency. + If the values are provided as a 2D grid, the shape must match the coordinates grid. If the values are 1D, they are reshaped to 2D. M, N follow from the coordinates grid, where the coordinates grid @@ -1976,11 +1985,19 @@ def set_array(self, A): if len(shape) == 1: if shape[0] != (h*w): faulty_data = True - elif shape != (h, w): - if np.prod(shape) == (h * w): + elif shape[:2] != (h, w): + if np.prod(shape[:2]) == (h * w): misshapen_data = True else: faulty_data = True + elif len(shape) == 3 and shape[2] not in {3, 4}: + # 3D data must be RGB(A) (h, w, [3,4]) + # the (h, w) check is taken care of above + raise ValueError( + f"For X ({width}) and Y ({height}) with " + f"{self._shading} shading, the expected shape of " + f"A with RGB(A) colors is ({h}, {w}, [3 or 4]), not " + f"{A.shape}") if misshapen_data: raise ValueError( diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index deb6c3fb2ba1..e70bfc71f9aa 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -1298,6 +1298,17 @@ def test_pcolormesh_alpha(): ax4.pcolormesh(Qx, Qy, Z, cmap=cmap, shading='gouraud', zorder=1) +@pytest.mark.parametrize("dims,alpha", [(3, 1), (4, 0.5)]) +@check_figures_equal(extensions=["png"]) +def test_pcolormesh_rgba(fig_test, fig_ref, dims, alpha): + ax = fig_test.subplots() + c = np.ones((5, 6, dims), dtype=float) / 2 + ax.pcolormesh(c) + + ax = fig_ref.subplots() + ax.pcolormesh(c[..., 0], cmap="gray", vmin=0, vmax=1, alpha=alpha) + + @image_comparison(['pcolormesh_datetime_axis.png'], style='mpl20') def test_pcolormesh_datetime_axis(): # Remove this line when this test image is regenerated. diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 0d1e621d84cb..5b0d64657b32 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -830,6 +830,24 @@ def test_quadmesh_set_array_validation(): r"are incompatible with X \(11\) and/or Y \(8\)"): coll.set_array(z.ravel()) + # RGB(A) tests + z = np.ones((9, 6, 3)) # RGB with wrong X/Y dims + with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 3\) " + r"are incompatible with X \(11\) and/or Y \(8\)"): + coll.set_array(z) + + z = np.ones((9, 6, 4)) # RGBA with wrong X/Y dims + with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 4\) " + r"are incompatible with X \(11\) and/or Y \(8\)"): + coll.set_array(z) + + z = np.ones((7, 10, 2)) # Right X/Y dims, bad 3rd dim + with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with " + r"flat shading, the expected shape of " + r"A with RGB\(A\) colors is \(7, 10, \[3 or 4\]\), " + r"not \(7, 10, 2\)"): + coll.set_array(z) + x = np.arange(10) y = np.arange(7) z = np.random.random((7, 10)) @@ -1048,6 +1066,9 @@ def test_array_wrong_dimensions(): pc = plt.pcolormesh(z) pc.set_array(z) # 2D is OK for Quadmesh pc.update_scalarmappable() + # 3D RGB is OK as well + z = np.arange(36).reshape(3, 4, 3) + pc.set_array(z) def test_get_segments():