diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 18ef6b793c2b..da870a7f421e 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -1669,7 +1669,7 @@ class Grouper(object): False """ - def __init__(self, init=[]): + def __init__(self, init=()): mapping = self._mapping = {} for x in init: mapping[ref(x)] = [ref(x)] @@ -1721,6 +1721,14 @@ def joined(self, a, b): except KeyError: return False + def remove(self, a): + self.clean() + + mapping = self._mapping + seta = mapping.pop(ref(a), None) + if seta is not None: + seta.remove(ref(a)) + def __iter__(self): """ Iterate over each of the disjoint sets as a list. diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 0de748cf7a78..547df4e3ca17 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -916,7 +916,7 @@ def add_axes(self, *args, **kwargs): self._axstack.add(key, a) self.sca(a) - a._remove_method = lambda ax: self.delaxes(ax) + a._remove_method = self.__remove_ax self.stale = True a.stale_callback = _stale_figure_callback return a @@ -1006,11 +1006,37 @@ def add_subplot(self, *args, **kwargs): self._axstack.add(key, a) self.sca(a) - a._remove_method = lambda ax: self.delaxes(ax) + a._remove_method = self.__remove_ax self.stale = True a.stale_callback = _stale_figure_callback return a + def __remove_ax(self, ax): + def _reset_loc_form(axis): + axis.set_major_formatter(axis.get_major_formatter()) + axis.set_major_locator(axis.get_major_locator()) + axis.set_minor_formatter(axis.get_minor_formatter()) + axis.set_minor_locator(axis.get_minor_locator()) + + def _break_share_link(ax, grouper): + siblings = grouper.get_siblings(ax) + if len(siblings) > 1: + grouper.remove(ax) + for last_ax in siblings: + if ax is last_ax: + continue + return last_ax + return None + + self.delaxes(ax) + last_ax = _break_share_link(ax, ax._shared_y_axes) + if last_ax is not None: + _reset_loc_form(last_ax.yaxis) + + last_ax = _break_share_link(ax, ax._shared_x_axes) + if last_ax is not None: + _reset_loc_form(last_ax.xaxis) + def clf(self, keep_observers=False): """ Clear the figure. diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index a3a161d39f80..de2f51bf3a14 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -4086,11 +4086,57 @@ def test_shared_scale(): assert_equal(ax.get_yscale(), 'linear') assert_equal(ax.get_xscale(), 'linear') + @cleanup def test_violin_point_mass(): """Violin plot should handle point mass pdf gracefully.""" plt.violinplot(np.array([0, 0])) + +@cleanup +def test_remove_shared_axes(): + + def _helper_x(ax): + ax2 = ax.twinx() + ax2.remove() + ax.set_xlim(0, 15) + r = ax.xaxis.get_major_locator()() + assert r[-1] > 14 + + def _helper_y(ax): + ax2 = ax.twiny() + ax2.remove() + ax.set_ylim(0, 15) + r = ax.yaxis.get_major_locator()() + assert r[-1] > 14 + + # test all of the ways to get fig/ax sets + fig = plt.figure() + ax = fig.gca() + yield _helper_x, ax + yield _helper_y, ax + + fig, ax = plt.subplots() + yield _helper_x, ax + yield _helper_y, ax + + fig, ax_lst = plt.subplots(2, 2, sharex='all', sharey='all') + ax = ax_lst[0][0] + yield _helper_x, ax + yield _helper_y, ax + + fig = plt.figure() + ax = fig.add_axes([.1, .1, .8, .8]) + yield _helper_x, ax + yield _helper_y, ax + + fig, ax_lst = plt.subplots(2, 2, sharex='all', sharey='all') + ax = ax_lst[0][0] + orig_xlim = ax_lst[0][1].get_xlim() + ax.remove() + ax.set_xlim(0, 5) + assert assert_array_equal(ax_lst[0][1].get_xlim(), orig_xlim) + if __name__ == '__main__': import nose import sys diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 2b916b08566f..1b11fe026120 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -1,5 +1,7 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) +import itertools +from weakref import ref from matplotlib.externals import six @@ -376,3 +378,40 @@ def test_step_fails(): np.arange(12)) assert_raises(ValueError, cbook._step_validation, np.arange(12), np.arange(3)) + + +def test_grouper(): + class dummy(): + pass + a, b, c, d, e = objs = [dummy() for j in range(5)] + g = cbook.Grouper() + g.join(*objs) + assert set(list(g)[0]) == set(objs) + assert set(g.get_siblings(a)) == set(objs) + + for other in objs[1:]: + assert g.joined(a, other) + + g.remove(a) + for other in objs[1:]: + assert not g.joined(a, other) + + for A, B in itertools.product(objs[1:], objs[1:]): + assert g.joined(A, B) + + +def test_grouper_private(): + class dummy(): + pass + objs = [dummy() for j in range(5)] + g = cbook.Grouper() + g.join(*objs) + # reach in and touch the internals ! + mapping = g._mapping + + for o in objs: + assert ref(o) in mapping + + base_set = mapping[ref(objs[0])] + for o in objs[1:]: + assert mapping[ref(o)] is base_set