diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index c0f66c9f27cd..814df14d9ab4 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -361,9 +361,14 @@ class _BoundMethodProxy(object): Minor bugfixes by Michael Droettboom ''' def __init__(self, cb): + self._hash = hash(cb) + self._destroy_callbacks = [] try: try: - self.inst = ref(cb.im_self) + if six.PY3: + self.inst = ref(cb.__self__, self._destroy) + else: + self.inst = ref(cb.im_self, self._destroy) except TypeError: self.inst = None if six.PY3: @@ -377,6 +382,16 @@ def __init__(self, cb): self.func = cb self.klass = None + def add_destroy_callback(self, callback): + self._destroy_callbacks.append(_BoundMethodProxy(callback)) + + def _destroy(self, wk): + for callback in self._destroy_callbacks: + try: + callback(self) + except ReferenceError: + pass + def __getstate__(self): d = self.__dict__.copy() # de-weak reference inst @@ -433,6 +448,9 @@ def __ne__(self, other): ''' return not self.__eq__(other) + def __hash__(self): + return self._hash + class CallbackRegistry(object): """ @@ -492,17 +510,32 @@ def connect(self, s, func): func will be called """ self._func_cid_map.setdefault(s, WeakKeyDictionary()) - if func in self._func_cid_map[s]: - return self._func_cid_map[s][func] + # Note proxy not needed in python 3. + # TODO rewrite this when support for python2.x gets dropped. + proxy = _BoundMethodProxy(func) + if proxy in self._func_cid_map[s]: + return self._func_cid_map[s][proxy] + proxy.add_destroy_callback(self._remove_proxy) self._cid += 1 cid = self._cid - self._func_cid_map[s][func] = cid + self._func_cid_map[s][proxy] = cid self.callbacks.setdefault(s, dict()) - proxy = _BoundMethodProxy(func) self.callbacks[s][cid] = proxy return cid + def _remove_proxy(self, proxy): + for signal, proxies in list(six.iteritems(self._func_cid_map)): + try: + del self.callbacks[signal][proxies[proxy]] + except KeyError: + pass + + if len(self.callbacks[signal]) == 0: + del self.callbacks[signal] + del self._func_cid_map[signal] + + def disconnect(self, cid): """ disconnect the callback registered with callback id *cid* @@ -513,7 +546,7 @@ def disconnect(self, cid): except KeyError: continue else: - for category, functions in list( + for signal, functions in list( six.iteritems(self._func_cid_map)): for function, value in list(six.iteritems(functions)): if value == cid: @@ -527,11 +560,10 @@ def process(self, s, *args, **kwargs): """ if s in self.callbacks: for cid, proxy in list(six.iteritems(self.callbacks[s])): - # Clean out dead references - if proxy.inst is not None and proxy.inst() is None: - del self.callbacks[s][cid] - else: + try: proxy(*args, **kwargs) + except ReferenceError: + self._remove_proxy(proxy) class Scheduler(threading.Thread): diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 416fa0c74020..0965a3d18066 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -8,7 +8,7 @@ import numpy as np from numpy.testing.utils import (assert_array_equal, assert_approx_equal, assert_array_almost_equal) -from nose.tools import assert_equal, raises, assert_true +from nose.tools import assert_equal, assert_not_equal, raises, assert_true import matplotlib.cbook as cbook import matplotlib.colors as mcolors @@ -243,3 +243,47 @@ def test_label_error(self): def test_bad_dims(self): data = np.random.normal(size=(34, 34, 34)) results = cbook.boxplot_stats(data) + + +class Test_callback_registry(object): + def setup(self): + self.signal = 'test' + self.callbacks = cbook.CallbackRegistry() + + def connect(self, s, func): + return self.callbacks.connect(s, func) + + def is_empty(self): + assert_equal(self.callbacks._func_cid_map, {}) + assert_equal(self.callbacks.callbacks, {}) + + def is_not_empty(self): + assert_not_equal(self.callbacks._func_cid_map, {}) + assert_not_equal(self.callbacks.callbacks, {}) + + def test_callback_complete(self): + # ensure we start with an empty registry + self.is_empty() + + # create a class for testing + mini_me = Test_callback_registry() + + # test that we can add a callback + cid1 = self.connect(self.signal, mini_me.dummy) + assert_equal(type(cid1), int) + self.is_not_empty() + + # test that we don't add a second callback + cid2 = self.connect(self.signal, mini_me.dummy) + assert_equal(cid1, cid2) + self.is_not_empty() + assert_equal(len(self.callbacks._func_cid_map), 1) + assert_equal(len(self.callbacks.callbacks), 1) + + del mini_me + + # check we now have no callbacks registered + self.is_empty() + + def dummy(self): + pass