From 659fec86be28d99f63ee2f0a3ceda0b8873f2b7e Mon Sep 17 00:00:00 2001 From: saranti Date: Wed, 24 Jan 2024 19:51:12 +1100 Subject: [PATCH] add method to update arrow patch --- .../next_whats_new/update_arrow_patch.rst | 30 ++++++++++++++ lib/matplotlib/patches.py | 40 ++++++++++++++++--- lib/matplotlib/patches.pyi | 9 ++++- lib/matplotlib/tests/test_patches.py | 29 ++++++++++++++ 4 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 doc/users/next_whats_new/update_arrow_patch.rst diff --git a/doc/users/next_whats_new/update_arrow_patch.rst b/doc/users/next_whats_new/update_arrow_patch.rst new file mode 100644 index 000000000000..894090587b5d --- /dev/null +++ b/doc/users/next_whats_new/update_arrow_patch.rst @@ -0,0 +1,30 @@ +Update the position of arrow patch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Adds a setter method that allows the user to update the position of the +`.patches.Arrow` object without requiring a full re-draw. + +.. plot:: + :include-source: true + :alt: Example of changing the position of the arrow with the new ``set_data`` method. + + import matplotlib as mpl + import matplotlib.pyplot as plt + from matplotlib.patches import Arrow + import matplotlib.animation as animation + + fig, ax = plt.subplots() + ax.set_xlim(0, 10) + ax.set_ylim(0, 10) + + a = mpl.patches.Arrow(2, 0, 0, 10) + ax.add_patch(a) + + + # code for modifying the arrow + def update(i): + a.set_data(x=.5, dx=i, dy=6, width=2) + + + ani = animation.FuncAnimation(fig, update, frames=15, interval=90, blit=False) + + plt.show() diff --git a/lib/matplotlib/patches.py b/lib/matplotlib/patches.py index ef52a00b059b..fc9d2c88897d 100644 --- a/lib/matplotlib/patches.py +++ b/lib/matplotlib/patches.py @@ -1297,12 +1297,7 @@ def __init__(self, x, y, dx, dy, *, width=1.0, **kwargs): properties. """ super().__init__(**kwargs) - self._patch_transform = ( - transforms.Affine2D() - .scale(np.hypot(dx, dy), width) - .rotate(np.arctan2(dy, dx)) - .translate(x, y) - .frozen()) + self.set_data(x, y, dx, dy, width) def get_path(self): return self._path @@ -1310,6 +1305,39 @@ def get_path(self): def get_patch_transform(self): return self._patch_transform + def set_data(self, x=None, y=None, dx=None, dy=None, width=None): + """ + Set `.Arrow` x, y, dx, dy and width. + Values left as None will not be updated. + + Parameters + ---------- + x, y : float or None, default: None + The x and y coordinates of the arrow base. + + dx, dy : float or None, default: None + The length of the arrow along x and y direction. + + width : float or None, default: None + Width of full arrow tail. + """ + if x is not None: + self._x = x + if y is not None: + self._y = y + if dx is not None: + self._dx = dx + if dy is not None: + self._dy = dy + if width is not None: + self._width = width + self._patch_transform = ( + transforms.Affine2D() + .scale(np.hypot(self._dx, self._dy), self._width) + .rotate(np.arctan2(self._dy, self._dx)) + .translate(self._x, self._y) + .frozen()) + class FancyArrow(Polygon): """ diff --git a/lib/matplotlib/patches.pyi b/lib/matplotlib/patches.pyi index 287ea0f738ab..f6c9ddf75839 100644 --- a/lib/matplotlib/patches.pyi +++ b/lib/matplotlib/patches.pyi @@ -181,7 +181,14 @@ class Arrow(Patch): def __init__( self, x: float, y: float, dx: float, dy: float, *, width: float = ..., **kwargs ) -> None: ... - + def set_data( + self, + x: float | None = ..., + y: float | None = ..., + dx: float | None = ..., + dy: float | None = ..., + width: float | None = ..., + ) -> None: ... class FancyArrow(Polygon): def __init__( self, diff --git a/lib/matplotlib/tests/test_patches.py b/lib/matplotlib/tests/test_patches.py index b1af0abbc573..9530bcd19130 100644 --- a/lib/matplotlib/tests/test_patches.py +++ b/lib/matplotlib/tests/test_patches.py @@ -931,3 +931,32 @@ def test_modifying_arc(fig_test, fig_ref): fig_test.subplots().add_patch(arc2) arc2.set_width(.5) arc2.set_angle(20) + + +def test_arrow_set_data(): + fig, ax = plt.subplots() + arrow = mpl.patches.Arrow(2, 0, 0, 10) + expected1 = np.array( + [[1.9, 0.], + [2.1, -0.], + [2.1, 8.], + [2.3, 8.], + [2., 10.], + [1.7, 8.], + [1.9, 8.], + [1.9, 0.]] + ) + assert np.allclose(expected1, np.round(arrow.get_verts(), 2)) + + expected2 = np.array( + [[0.39, 0.04], + [0.61, -0.04], + [3.01, 6.36], + [3.24, 6.27], + [3.5, 8.], + [2.56, 6.53], + [2.79, 6.44], + [0.39, 0.04]] + ) + arrow.set_data(x=.5, dx=3, dy=8, width=1.2) + assert np.allclose(expected2, np.round(arrow.get_verts(), 2))