diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 3796d9bbe508..740c59bd85df 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -2237,7 +2237,8 @@ def add_child_axes(self, ax): ax.stale_callback = martist._stale_axes_callback self.child_axes.append(ax) - ax._remove_method = self.child_axes.remove + ax._remove_method = functools.partial( + self.figure._remove_axes, owners=[self.child_axes]) self.stale = True return ax diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index ce263c3d8d1c..4361ef655c81 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -936,11 +936,25 @@ def delaxes(self, ax): """ Remove the `~.axes.Axes` *ax* from the figure; update the current Axes. """ + self._remove_axes(ax, owners=[self._axstack, self._localaxes]) + + def _remove_axes(self, ax, owners): + """ + Common helper for removal of standard axes (via delaxes) and of child axes. + + Parameters + ---------- + ax : `~.AxesBase` + The Axes to remove. + owners + List of objects (list or _AxesStack) "owning" the axes, from which the Axes + will be remove()d. + """ + for owner in owners: + owner.remove(ax) - self._axstack.remove(ax) self._axobservers.process("_axes_change_event", self) self.stale = True - self._localaxes.remove(ax) self.canvas.release_mouse(ax) for name in ax._axis_names: # Break link between any shared axes diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 0fcb2eb26cbb..c78be0ee9cbe 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -8679,6 +8679,14 @@ def test_cla_clears_children_axes_and_fig(): assert art.figure is None +def test_child_axes_removal(): + fig, ax = plt.subplots() + marginal = ax.inset_axes([1, 0, .1, 1], sharey=ax) + marginal_twin = marginal.twinx() + marginal.remove() + ax.set(xlim=(-1, 1), ylim=(10, 20)) + + def test_scatter_color_repr_error(): def get_next_color():