diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index ee9058e9f52a..c2becd70733a 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -2024,14 +2024,15 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("pickradius", 0) # end of signature deprecation code - super().__init__(**kwargs) _api.check_shape((None, None, 2), coordinates=coords) self._coordinates = coords self._antialiased = antialiased self._shading = shading - self._bbox = transforms.Bbox.unit() self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2)) + # super init delayed after own init because array kwarg requires + # self._coordinates and self._shading + super().__init__(**kwargs) # Only needed during signature deprecation __init__.__signature__ = inspect.signature( @@ -2047,6 +2048,53 @@ def set_paths(self): self._paths = self._convert_mesh_to_paths(self._coordinates) self.stale = True + def set_array(self, A): + """ + Set the data values. + + Parameters + ---------- + A : (M, N) array-like or M*N array-like + If the values are provided as a 2D grid, the shape must match the + coordinates grid. If the values are 1D, they are reshaped to 2D. + M, N follow from the coordinates grid, where the coordinates grid + shape is (M, N) for 'gouraud' *shading* and (M+1, N+1) for 'flat' + shading. + """ + height, width = self._coordinates.shape[0:-1] + misshapen_data = False + faulty_data = False + + if self._shading == 'flat': + h, w = height-1, width-1 + else: + h, w = height, width + + if A is not None: + shape = np.shape(A) + if len(shape) == 1: + if shape[0] != (h*w): + faulty_data = True + elif shape != (h, w): + if np.prod(shape) == (h * w): + misshapen_data = True + else: + faulty_data = True + + if misshapen_data: + _api.warn_deprecated( + "3.5", message=f"For X ({width}) and Y ({height}) " + f"with {self._shading} shading, the expected shape of " + f"A is ({h}, {w}). Passing A ({A.shape}) is deprecated " + "since %(since)s and will become an error %(removal)s.") + + if faulty_data: + raise TypeError( + f"Dimensions of A {A.shape} are incompatible with " + f"X ({width}) and/or Y ({height})") + + return super().set_array(A) + def get_datalim(self, transData): return (self.get_transform() - transData).transform_bbox(self._bbox) diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 3a12f2fc3fc1..4d8df2dc6a81 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -734,7 +734,7 @@ def test_quadmesh_deprecated_signature( X += 0.2 * Y coords = np.stack([X, Y], axis=-1) assert coords.shape == (3, 4, 2) - C = np.linspace(0, 2, 12).reshape(3, 4) + C = np.linspace(0, 2, 6).reshape(2, 3) ax = fig_test.add_subplot() ax.set(xlim=(0, 5), ylim=(0, 4)) @@ -789,6 +789,32 @@ def test_quadmesh_deprecated_positional(fig_test, fig_ref): ax.add_collection(qmesh) +def test_quadmesh_set_array_validation(): + x = np.arange(11) + y = np.arange(8) + z = np.random.random((7, 10)) + fig, ax = plt.subplots() + coll = ax.pcolormesh(x, y, z) + + # Test deprecated warning when faulty shape is passed. + with pytest.warns(MatplotlibDeprecationWarning): + coll.set_array(z.reshape(10, 7)) + + z = np.arange(54).reshape((6, 9)) + with pytest.raises(TypeError, match=r"Dimensions of A \(6, 9\) " + r"are incompatible with X \(11\) and/or Y \(8\)"): + coll.set_array(z) + with pytest.raises(TypeError, match=r"Dimensions of A \(54,\) " + r"are incompatible with X \(11\) and/or Y \(8\)"): + coll.set_array(z.ravel()) + + x = np.arange(10) + y = np.arange(7) + z = np.random.random((7, 10)) + fig, ax = plt.subplots() + coll = ax.pcolormesh(x, y, z, shading='gouraud') + + def test_quadmesh_get_coordinates(): x = [0, 1, 2] y = [2, 4, 6] @@ -817,6 +843,19 @@ def test_quadmesh_set_array(): fig.canvas.draw() assert np.array_equal(coll.get_array(), np.ones(9)) + z = np.arange(16).reshape((4, 4)) + fig, ax = plt.subplots() + coll = ax.pcolormesh(x, y, np.ones(z.shape), shading='gouraud') + # Test that the collection is able to update with a 2d array + coll.set_array(z) + fig.canvas.draw() + assert np.array_equal(coll.get_array(), z) + + # Check that pre-flattened arrays work too + coll.set_array(np.ones(16)) + fig.canvas.draw() + assert np.array_equal(coll.get_array(), np.ones(16)) + def test_quadmesh_vmin_vmax(): # test when vmin/vmax on the norm changes, the quadmesh gets updated