Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d08f8ea

Browse files
committed
Quadmesh validates set_array dimensions
1 parent 9c530bc commit d08f8ea

File tree

2 files changed

+90
-3
lines changed

2 files changed

+90
-3
lines changed

lib/matplotlib/collections.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2024,14 +2024,15 @@ def __init__(self, *args, **kwargs):
20242024
kwargs.setdefault("pickradius", 0)
20252025
# end of signature deprecation code
20262026

2027-
super().__init__(**kwargs)
20282027
_api.check_shape((None, None, 2), coordinates=coords)
20292028
self._coordinates = coords
20302029
self._antialiased = antialiased
20312030
self._shading = shading
2032-
20332031
self._bbox = transforms.Bbox.unit()
20342032
self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2))
2033+
# super init delayed after own init because array kwarg requires
2034+
# self._coordinates and self._shading
2035+
super().__init__(**kwargs)
20352036

20362037
# Only needed during signature deprecation
20372038
__init__.__signature__ = inspect.signature(
@@ -2047,6 +2048,53 @@ def set_paths(self):
20472048
self._paths = self._convert_mesh_to_paths(self._coordinates)
20482049
self.stale = True
20492050

2051+
def set_array(self, A):
2052+
"""
2053+
Set the data values.
2054+
2055+
Parameters
2056+
----------
2057+
A : (M, N) array-like or M*N array-like
2058+
If the values are provided as a 2D grid, the shape must match the
2059+
coordinates grid. If the values are 1D, they are reshaped to 2D.
2060+
M, N follow from the coordinates grid, where the coordinates grid
2061+
shape is (M, N) for 'gouraud' *shading* and (M+1, N+1) for 'flat'
2062+
shading.
2063+
"""
2064+
height, width = self._coordinates.shape[0:-1]
2065+
misshapen_data = False
2066+
faulty_data = False
2067+
2068+
if self._shading == 'flat':
2069+
h, w = height-1, width-1
2070+
else:
2071+
h, w = height, width
2072+
2073+
if A is not None:
2074+
shape = np.shape(A)
2075+
if len(shape) == 1:
2076+
if shape[0] != (h*w):
2077+
faulty_data = True
2078+
elif shape != (h, w):
2079+
if np.prod(shape) == (h * w):
2080+
misshapen_data = True
2081+
else:
2082+
faulty_data = True
2083+
2084+
if misshapen_data:
2085+
_api.warn_deprecated(
2086+
"3.5", message=f"For X ({width}) and Y ({height}) "
2087+
f"with {self._shading} shading, the expected shape of "
2088+
f"A is ({h}, {w}). Passing A ({A.shape}) is deprecated "
2089+
"since %(since)s and will become an error %(removal)s.")
2090+
2091+
if faulty_data:
2092+
raise TypeError(
2093+
f"Dimensions of A {A.shape} are incompatible with "
2094+
f"X ({width}) and/or Y ({height})")
2095+
2096+
return super().set_array(A)
2097+
20502098
def get_datalim(self, transData):
20512099
return (self.get_transform() - transData).transform_bbox(self._bbox)
20522100

lib/matplotlib/tests/test_collections.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def test_quadmesh_deprecated_signature(
734734
X += 0.2 * Y
735735
coords = np.stack([X, Y], axis=-1)
736736
assert coords.shape == (3, 4, 2)
737-
C = np.linspace(0, 2, 12).reshape(3, 4)
737+
C = np.linspace(0, 2, 6).reshape(2, 3)
738738

739739
ax = fig_test.add_subplot()
740740
ax.set(xlim=(0, 5), ylim=(0, 4))
@@ -789,6 +789,32 @@ def test_quadmesh_deprecated_positional(fig_test, fig_ref):
789789
ax.add_collection(qmesh)
790790

791791

792+
def test_quadmesh_set_array_validation():
793+
x = np.arange(11)
794+
y = np.arange(8)
795+
z = np.random.random((7, 10))
796+
fig, ax = plt.subplots()
797+
coll = ax.pcolormesh(x, y, z)
798+
799+
# Test deprecated warning when faulty shape is passed.
800+
with pytest.warns(MatplotlibDeprecationWarning):
801+
coll.set_array(z.reshape(10, 7))
802+
803+
z = np.arange(54).reshape((6, 9))
804+
with pytest.raises(TypeError, match=r"Dimensions of A \(6, 9\) "
805+
r"are incompatible with X \(11\) and/or Y \(8\)"):
806+
coll.set_array(z)
807+
with pytest.raises(TypeError, match=r"Dimensions of A \(54,\) "
808+
r"are incompatible with X \(11\) and/or Y \(8\)"):
809+
coll.set_array(z.ravel())
810+
811+
x = np.arange(10)
812+
y = np.arange(7)
813+
z = np.random.random((7, 10))
814+
fig, ax = plt.subplots()
815+
coll = ax.pcolormesh(x, y, z, shading='gouraud')
816+
817+
792818
def test_quadmesh_get_coordinates():
793819
x = [0, 1, 2]
794820
y = [2, 4, 6]
@@ -817,6 +843,19 @@ def test_quadmesh_set_array():
817843
fig.canvas.draw()
818844
assert np.array_equal(coll.get_array(), np.ones(9))
819845

846+
z = np.arange(16).reshape((4, 4))
847+
fig, ax = plt.subplots()
848+
coll = ax.pcolormesh(x, y, np.ones(z.shape), shading='gouraud')
849+
# Test that the collection is able to update with a 2d array
850+
coll.set_array(z)
851+
fig.canvas.draw()
852+
assert np.array_equal(coll.get_array(), z)
853+
854+
# Check that pre-flattened arrays work too
855+
coll.set_array(np.ones(16))
856+
fig.canvas.draw()
857+
assert np.array_equal(coll.get_array(), np.ones(16))
858+
820859

821860
def test_quadmesh_vmin_vmax():
822861
# test when vmin/vmax on the norm changes, the quadmesh gets updated

0 commit comments

Comments
 (0)