From 0022611818b8ed3d192424768c54111dc3c65009 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 1 Nov 2016 17:36:03 -0700 Subject: [PATCH] Switch to a private, simpler AxesStack. The current implementation of AxesStack subclasses cbook.Stack, which requires hashable keys, which leads to additional complexity on the caller's side (`_make_key`). Instead, switch to using two lists (keys and axes) and relying on `list.index`, which makes the implementation much simpler. Also make the new class private and deprecate the previous one. --- lib/matplotlib/figure.py | 108 +++++++++++++++---------- lib/matplotlib/projections/__init__.py | 6 +- 2 files changed, 67 insertions(+), 47 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 01c1a1329c4a..cf661fab5c47 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -72,6 +72,7 @@ class AxesStack(Stack): """ def __init__(self): + cbook.warn_deprecated("2.1") Stack.__init__(self) self._ind = 0 @@ -158,6 +159,62 @@ def __contains__(self, a): return a in self.as_list() +class _AxesStack(object): + """Lightweight stack that tracks Axes in a Figure. + """ + + def __init__(self): + # We maintain a list of (creation_index, key, axes) tuples. + # We do not use an OrderedDict because 1. the keys may not be hashable + # and 2. we need to directly find a pair corresponding to an axes (i.e. + # we'd really need a two-way dict). + self._items = [] + self._created = 0 + + def as_list(self): + """Copy of the list of axes, in the order of insertion. + """ + return [ax for _, _, ax in sorted(self._items)] + + def get(self, key): + """Find the axes corresponding to a key; defaults to `None`. + """ + return next((ax for _, k, ax in self._items if k == key), None) + + def current_key_axes(self): + """Return the topmost `(key, axes)` pair, or `(None, None)` if empty. + """ + _, key, ax = (self._items or [(None, None, None)])[-1] + return key, ax + + def add(self, key, ax): + """Append a `(key, axes)` pair, unless the axes are already present. + """ + # Skipping existing Axes is needed to support calling `add_axes` with + # an already existing Axes. + if not any(a == ax for _, _, a in self._items): + self._items.append((self._created, key, ax)) + self._created += 1 + + def bubble(self, ax): + """Move an axes and its corresponding key to the top. + """ + idx, = (idx for idx, (_, _, a) in enumerate(self._items) if a == ax) + self._items.append(self._items[idx]) + del self._items[idx] + + def remove(self, ax): + """Remove an axes and its corresponding key. + """ + idx, = (idx for idx, (_, _, a) in enumerate(self._items) if a == ax) + del self._items[idx] + + def clear(self): + """Clear the stack. + """ + del self._items[:] + + class SubplotParams(object): """ A class to hold the parameters for a subplot @@ -358,7 +415,7 @@ def __init__(self, self.subplotpars = subplotpars self.set_tight_layout(tight_layout) - self._axstack = AxesStack() # track all figure axes and current axes + self._axstack = _AxesStack() # track all figure axes and current axes self.clf() self._cachedRenderer = None @@ -410,10 +467,8 @@ def show(self, warn=True): "matplotlib is currently using a non-GUI backend, " "so cannot show the figure") - def _get_axes(self): - return self._axstack.as_list() - - axes = property(fget=_get_axes, doc="Read-only: list of axes in Figure") + axes = property(lambda self: self._axstack.as_list(), + doc="Read-only: list of axes in Figure") def _get_dpi(self): return self._dpi @@ -835,36 +890,6 @@ def delaxes(self, a): func(self) self.stale = True - def _make_key(self, *args, **kwargs): - 'make a hashable key out of args and kwargs' - - def fixitems(items): - #items may have arrays and lists in them, so convert them - # to tuples for the key - ret = [] - for k, v in items: - # some objects can define __getitem__ without being - # iterable and in those cases the conversion to tuples - # will fail. So instead of using the iterable(v) function - # we simply try and convert to a tuple, and proceed if not. - try: - v = tuple(v) - except Exception: - pass - ret.append((k, v)) - return tuple(ret) - - def fixlist(args): - ret = [] - for a in args: - if iterable(a): - a = tuple(a) - ret.append(a) - return tuple(ret) - - key = fixlist(args), fixitems(six.iteritems(kwargs)) - return key - def add_axes(self, *args, **kwargs): """ Add an axes at position *rect* [*left*, *bottom*, *width*, @@ -929,9 +954,9 @@ def add_axes(self, *args, **kwargs): # shortcut the projection "key" modifications later on, if an axes # with the exact args/kwargs exists, return it immediately. - key = self._make_key(*args, **kwargs) + key = (args, kwargs) ax = self._axstack.get(key) - if ax is not None: + if ax: self.sca(ax) return ax @@ -951,7 +976,7 @@ def add_axes(self, *args, **kwargs): # check that an axes of this type doesn't already exist, if it # does, set it as active and return it ax = self._axstack.get(key) - if ax is not None and isinstance(ax, projection_class): + if isinstance(ax, projection_class): self.sca(ax) return ax @@ -1037,15 +1062,14 @@ def add_subplot(self, *args, **kwargs): raise ValueError(msg) # make a key for the subplot (which includes the axes object id # in the hash) - key = self._make_key(*args, **kwargs) + key = (args, kwargs) else: projection_class, kwargs, key = process_projection_requirements( self, *args, **kwargs) # try to find the axes with this key in the stack ax = self._axstack.get(key) - - if ax is not None: + if ax: if isinstance(ax, projection_class): # the axes already existed, so set it as active & return self.sca(ax) @@ -1614,7 +1638,7 @@ def _gci(self): do not use elsewhere. """ # Look first for an image in the current Axes: - cax = self._axstack.current_key_axes()[1] + ckey, cax = self._axstack.current_key_axes() if cax is None: return None im = cax._gci() diff --git a/lib/matplotlib/projections/__init__.py b/lib/matplotlib/projections/__init__.py index 1e423420b0b6..5e5ffcaf2e66 100644 --- a/lib/matplotlib/projections/__init__.py +++ b/lib/matplotlib/projections/__init__.py @@ -96,11 +96,7 @@ def process_projection_requirements(figure, *args, **kwargs): raise TypeError('projection must be a string, None or implement a ' '_as_mpl_axes method. Got %r' % projection) - # Make the key without projection kwargs, this is used as a unique - # lookup for axes instances - key = figure._make_key(*args, **kwargs) - - return projection_class, kwargs, key + return projection_class, kwargs, (args, kwargs) def get_projection_names():