diff --git a/lib/matplotlib/colorbar.py b/lib/matplotlib/colorbar.py index b1cd08d955a7..cc5c236f7d29 100644 --- a/lib/matplotlib/colorbar.py +++ b/lib/matplotlib/colorbar.py @@ -235,7 +235,7 @@ class ColorbarAxes(Axes): Users should not normally instantiate this class, but it is the class returned by ``cbar = fig.colorbar(im); cax = cbar.ax``. """ - def __init__(self, parent, userax=True): + def __init__(self, parent, userax=True, colorbar=None): """ Parameters ---------- @@ -243,6 +243,8 @@ def __init__(self, parent, userax=True): Axes that specifies the position of the colorbar. userax : boolean True if the user passed `.Figure.colorbar` the axes manually. + colorbar: ColorbarBase + colorbar that this axes is contained in. """ if userax: @@ -260,7 +262,6 @@ def __init__(self, parent, userax=True): parent.remove() else: outer_ax = parent - inner_ax = outer_ax.inset_axes([0, 0, 1, 1]) self.__dict__.update(inner_ax.__dict__) @@ -272,6 +273,7 @@ def __init__(self, parent, userax=True): self.outer_ax.tick_params = self.inner_ax.tick_params self.outer_ax.set_xticks = self.inner_ax.set_xticks self.outer_ax.set_yticks = self.inner_ax.set_yticks + self.colorbar = colorbar for attr in ["get_position", "set_position", "set_aspect"]: setattr(self, attr, getattr(self.outer_ax, attr)) if userax: @@ -285,6 +287,41 @@ def _set_inner_bounds(self, bounds): self.inner_ax._axes_locator = _TransformedBoundsLocator( bounds, self.outer_ax.transAxes) + def cla(self): + """ + Reset the colorbar axes to be empty. + """ + # need to low-level manipulate the stacks because we + # are just swapping places here. We don't need to + # set transforms etc... + + print('id', self.colorbar.mappable.colorbar_cid) + + if isinstance(self.colorbar.mappable.colorbar_cid, dict): + cid = self.colorbar.mappable.colorbar_cid[self.colorbar] + else: + cid = self.colorbar.mappable.colorbar_cid + self.colorbar.mappable.callbacksSM.disconnect(cid) + + self.inner_ax.cla() + self.outer_ax.cla() + self.figure._axstack.add(self) + self.figure._axstack.remove(self.outer_ax) + self.figure._localaxes.add(self) + self.figure._localaxes.remove(self.outer_ax) + + self.__dict__.update(self.outer_ax.__dict__) + self.inner_ax.remove() + del self.inner_ax + del self.outer_ax + + self.xaxis.set_visible(True) + self.yaxis.set_visible(True) + for spine in self.spines.values(): + spine.set_visible(True) + + self.set_facecolor(mpl.rcParams['axes.facecolor']) + class _ColorbarSpine(mspines.Spine): def __init__(self, axes): @@ -418,7 +455,7 @@ def __init__(self, ax, cmap=None, ['uniform', 'proportional'], spacing=spacing) # wrap the axes so that it can be positioned as an inset axes: - ax = ColorbarAxes(ax, userax=userax) + ax = ColorbarAxes(ax, userax=userax, colorbar=self) self.ax = ax ax.set(navigate=False) @@ -1195,9 +1232,17 @@ def __init__(self, ax, mappable, **kwargs): _add_disjoint_kwargs(kwargs, alpha=mappable.get_alpha()) super().__init__(ax, **kwargs) - mappable.colorbar = self - mappable.colorbar_cid = mappable.callbacksSM.connect( - 'changed', self.update_normal) + cid = mappable.callbacksSM.connect('changed', self.update_normal) + if mappable.colorbar is None: + mappable.colorbar = self + mappable.colorbar_cid = cid + elif not isinstance(mappable.colorbar, list): + old = mappable.colorbar_cid + mappable.colorbar_cid = {mappable.colorbar: old, self: cid} + mappable.colorbar = [mappable.colorbar, self] + else: + mappable.colorbar += [self] + mappable.colorbar_cid[self] = cid @_api.deprecated("3.3", alternative="update_normal") def on_mappable_changed(self, mappable): @@ -1251,6 +1296,10 @@ def update_normal(self, mappable): self._reset_locator_formatter_scale() self.draw_all() +# except AttributeError: +# # update_normal sometimes is called when it shouldn't be.. +# pass + if isinstance(self.mappable, contour.ContourSet): CS = self.mappable if not CS.filled: @@ -1306,7 +1355,11 @@ def remove(self): gridspec is restored. """ super().remove() - self.mappable.callbacksSM.disconnect(self.mappable.colorbar_cid) + if isinstance(self.mappable.colorbar_cid, dict): + cid = self.mappable.colorbar_cid[self] + else: + cid = self.mappable.colorbar_cid + self.mappable.callbacksSM.disconnect(cid) self.mappable.colorbar = None self.mappable.colorbar_cid = None diff --git a/lib/matplotlib/tests/test_colorbar.py b/lib/matplotlib/tests/test_colorbar.py index ac054755c4c1..be84611aca4c 100644 --- a/lib/matplotlib/tests/test_colorbar.py +++ b/lib/matplotlib/tests/test_colorbar.py @@ -759,6 +759,24 @@ def test_axes_handles_same_functions(fig_ref, fig_test): caxx.set_position([0.92, 0.1, 0.02, 0.7]) +@check_figures_equal(extensions=["png"]) +def test_colorbar_reuse_axes(fig_ref, fig_test): + ax = fig_ref.add_subplot() + pc = ax.imshow(np.arange(100).reshape(10, 10)) + cb = fig_ref.colorbar(pc) + cb2 = fig_ref.colorbar(pc, extend='both') + + ax = fig_test.add_subplot() + pc = ax.imshow(np.arange(100).reshape(10, 10)) + cb = fig_test.colorbar(pc, extend='both') + cb2 = fig_test.colorbar(pc) + # Clear and re-use the same colorbar axes + cb.ax.cla() + cb2.ax.cla() + cb = fig_test.colorbar(pc, cax=cb.ax) + cb2 = fig_test.colorbar(pc, cax=cb2.ax, extend='both') + + def test_inset_colorbar_layout(): fig, ax = plt.subplots(constrained_layout=True, figsize=(3, 6)) pc = ax.imshow(np.arange(100).reshape(10, 10))