Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5804110

Browse files
committed
ENH: Allow RGB(A) arrays for pcolormesh
Allow a user to set the array values to explicit colors with RGB(A) values in the 3rd dimension.
1 parent 0aac9f1 commit 5804110

File tree

5 files changed

+85
-10
lines changed

5 files changed

+85
-10
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
``pcolormesh`` accepts RGB(A) colors
2+
------------------------------------
3+
4+
The `~.Axes.pcolormesh` method can now handle explicit colors
5+
specified with RGB(A) values. To specify colors, the array must be 3D
6+
with a shape of ``(M, N, [3, 4])``.
7+
8+
.. plot::
9+
:include-source: true
10+
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
14+
colors = np.linspace(0, 1, 90).reshape((5, 6, 3))
15+
plt.pcolormesh(colors)
16+
plt.show()

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5679,7 +5679,7 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs):
56795679

56805680
if len(args) == 1:
56815681
C = np.asanyarray(args[0])
5682-
nrows, ncols = C.shape
5682+
nrows, ncols = C.shape[:2]
56835683
if shading in ['gouraud', 'nearest']:
56845684
X, Y = np.meshgrid(np.arange(ncols), np.arange(nrows))
56855685
else:
@@ -5708,7 +5708,7 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs):
57085708
X = X.data # strip mask as downstream doesn't like it...
57095709
if isinstance(Y, np.ma.core.MaskedArray):
57105710
Y = Y.data
5711-
nrows, ncols = C.shape
5711+
nrows, ncols = C.shape[:2]
57125712
else:
57135713
raise _api.nargs_error(funcname, takes="1 or 3", given=len(args))
57145714

@@ -6045,9 +6045,18 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None,
60456045
60466046
Parameters
60476047
----------
6048-
C : 2D array-like
6049-
The color-mapped values. Color-mapping is controlled by *cmap*,
6050-
*norm*, *vmin*, and *vmax*.
6048+
C : array-like
6049+
The mesh data. Supported array shapes are:
6050+
6051+
- (M, N) or M*N: a mesh with scalar data. The values are mapped to
6052+
colors using normalization and a colormap. See parameters *norm*,
6053+
*cmap*, *vmin*, *vmax*.
6054+
- (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
6055+
- (M, N, 4): an image with RGBA values (0-1 float or 0-255 int),
6056+
i.e. including transparency.
6057+
6058+
The first two dimensions (M, N) define the rows and columns of
6059+
the mesh data.
60516060
60526061
X, Y : array-like, optional
60536062
The coordinates of the corners of quadrilaterals of a pcolormesh::
@@ -6207,8 +6216,9 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None,
62076216
X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
62086217
shading=shading, kwargs=kwargs)
62096218
coords = np.stack([X, Y], axis=-1)
6210-
# convert to one dimensional array
6211-
C = C.ravel()
6219+
# convert to one dimensional array, except for 3D RGB(A) arrays
6220+
if C.ndim != 3:
6221+
C = C.ravel()
62126222

62136223
kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap'])
62146224

lib/matplotlib/collections.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,7 +1955,16 @@ def set_array(self, A):
19551955
19561956
Parameters
19571957
----------
1958-
A : (M, N) array-like or M*N array-like
1958+
A : array-like
1959+
The mesh data. Supported array shapes are:
1960+
1961+
- (M, N) or M*N: a mesh with scalar data. The values are mapped to
1962+
colors using normalization and a colormap. See parameters *norm*,
1963+
*cmap*, *vmin*, *vmax*.
1964+
- (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
1965+
- (M, N, 4): an image with RGBA values (0-1 float or 0-255 int),
1966+
i.e. including transparency.
1967+
19591968
If the values are provided as a 2D grid, the shape must match the
19601969
coordinates grid. If the values are 1D, they are reshaped to 2D.
19611970
M, N follow from the coordinates grid, where the coordinates grid
@@ -1976,11 +1985,19 @@ def set_array(self, A):
19761985
if len(shape) == 1:
19771986
if shape[0] != (h*w):
19781987
faulty_data = True
1979-
elif shape != (h, w):
1980-
if np.prod(shape) == (h * w):
1988+
elif shape[:2] != (h, w):
1989+
if np.prod(shape[:2]) == (h * w):
19811990
misshapen_data = True
19821991
else:
19831992
faulty_data = True
1993+
elif len(shape) == 3 and shape[2] not in {3, 4}:
1994+
# 3D data must be RGB(A) (h, w, [3,4])
1995+
# the (h, w) check is taken care of above
1996+
raise ValueError(
1997+
f"For X ({width}) and Y ({height}) with "
1998+
f"{self._shading} shading, the expected shape of "
1999+
f"A with RGB(A) colors is ({h}, {w}, [3 or 4]), not "
2000+
f"{A.shape}")
19842001

19852002
if misshapen_data:
19862003
raise ValueError(

lib/matplotlib/tests/test_axes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,17 @@ def test_pcolormesh_alpha():
12981298
ax4.pcolormesh(Qx, Qy, Z, cmap=cmap, shading='gouraud', zorder=1)
12991299

13001300

1301+
@pytest.mark.parametrize("dims,alpha", [(3, 1), (4, 0.5)])
1302+
@check_figures_equal(extensions=["png"])
1303+
def test_pcolormesh_rgba(fig_test, fig_ref, dims, alpha):
1304+
ax = fig_test.subplots()
1305+
c = np.ones((5, 6, dims), dtype=float) / 2
1306+
ax.pcolormesh(c)
1307+
1308+
ax = fig_ref.subplots()
1309+
ax.pcolormesh(c[..., 0], cmap="gray", vmin=0, vmax=1, alpha=alpha)
1310+
1311+
13011312
@image_comparison(['pcolormesh_datetime_axis.png'], style='mpl20')
13021313
def test_pcolormesh_datetime_axis():
13031314
# Remove this line when this test image is regenerated.

lib/matplotlib/tests/test_collections.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,24 @@ def test_quadmesh_set_array_validation():
830830
r"are incompatible with X \(11\) and/or Y \(8\)"):
831831
coll.set_array(z.ravel())
832832

833+
# RGB(A) tests
834+
z = np.ones((9, 6, 3)) # RGB with wrong X/Y dims
835+
with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 3\) "
836+
r"are incompatible with X \(11\) and/or Y \(8\)"):
837+
coll.set_array(z)
838+
839+
z = np.ones((9, 6, 4)) # RGBA with wrong X/Y dims
840+
with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 4\) "
841+
r"are incompatible with X \(11\) and/or Y \(8\)"):
842+
coll.set_array(z)
843+
844+
z = np.ones((7, 10, 2)) # Right X/Y dims, bad 3rd dim
845+
with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with "
846+
r"flat shading, the expected shape of "
847+
r"A with RGB\(A\) colors is \(7, 10, \[3 or 4\]\), "
848+
r"not \(7, 10, 2\)"):
849+
coll.set_array(z)
850+
833851
x = np.arange(10)
834852
y = np.arange(7)
835853
z = np.random.random((7, 10))
@@ -1048,6 +1066,9 @@ def test_array_wrong_dimensions():
10481066
pc = plt.pcolormesh(z)
10491067
pc.set_array(z) # 2D is OK for Quadmesh
10501068
pc.update_scalarmappable()
1069+
# 3D RGB is OK as well
1070+
z = np.arange(36).reshape(3, 4, 3)
1071+
pc.set_array(z)
10511072

10521073

10531074
def test_get_segments():

0 commit comments

Comments
 (0)