From b629ff86f2b62258ea8d661db72279c5e4732ff4 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 31 May 2023 11:03:11 +0200 Subject: [PATCH 1/3] Record the connected signal in CallbackRegistry weakref cleanup function. ... to remove the need to loop over all signals in _remove_proxy. --- lib/matplotlib/cbook.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 3f410a811205..109bef224605 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -199,7 +199,7 @@ def __setstate__(self, state): cid_count = state.pop('_cid_gen') vars(self).update(state) self.callbacks = { - s: {cid: _weak_or_strong_ref(func, self._remove_proxy) + s: {cid: _weak_or_strong_ref(func, functools.partial(self._remove_proxy, s)) for cid, func in d.items()} for s, d in self.callbacks.items()} self._func_cid_map = { @@ -212,7 +212,7 @@ def connect(self, signal, func): if self._signals is not None: _api.check_in_list(self._signals, signal=signal) self._func_cid_map.setdefault(signal, {}) - proxy = _weak_or_strong_ref(func, self._remove_proxy) + proxy = _weak_or_strong_ref(func, functools.partial(self._remove_proxy, signal)) if proxy in self._func_cid_map[signal]: return self._func_cid_map[signal][proxy] cid = next(self._cid_gen) @@ -233,18 +233,15 @@ def _connect_picklable(self, signal, func): # Keep a reference to sys.is_finalizing, as sys may have been cleared out # at that point. - def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing): + def _remove_proxy(self, signal, proxy, *, _is_finalizing=sys.is_finalizing): if _is_finalizing(): # Weakrefs can't be properly torn down at that point anymore. return - for signal, proxy_to_cid in list(self._func_cid_map.items()): - cid = proxy_to_cid.pop(proxy, None) - if cid is not None: - del self.callbacks[signal][cid] - self._pickled_cids.discard(cid) - break - else: - # Not found + cid = self._func_cid_map[signal].pop(proxy, None) + if cid is not None: + del self.callbacks[signal][cid] + self._pickled_cids.discard(cid) + else: # Not found return # Clean up empty dicts if len(self.callbacks[signal]) == 0: @@ -263,10 +260,8 @@ def disconnect(self, cid): proxy = cid_to_proxy.pop(cid, None) if proxy is not None: break - else: - # Not found + else: # Not found return - proxy_to_cid = self._func_cid_map[signal] for current_proxy, current_cid in list(proxy_to_cid.items()): if current_cid == cid: From 920a6d26a9d7497bc190a416c79c86755cdffe94 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 31 May 2023 11:18:43 +0200 Subject: [PATCH 2/3] Flatten CallbackRegistry._func_cid_map. It is easier to manipulate a flat (signal, proxy) -> cid map rather than a nested signal -> (proxy -> cid) map. --- lib/matplotlib/cbook.py | 44 ++++++++++++------------------ lib/matplotlib/tests/test_cbook.py | 2 +- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 109bef224605..76ac2756936f 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -171,7 +171,7 @@ class CallbackRegistry: # We maintain two mappings: # callbacks: signal -> {cid -> weakref-to-callback} - # _func_cid_map: signal -> {weakref-to-callback -> cid} + # _func_cid_map: {(signal, weakref-to-callback) -> cid} def __init__(self, exception_handler=_exception_printer, *, signals=None): self._signals = None if signals is None else list(signals) # Copy it. @@ -203,23 +203,21 @@ def __setstate__(self, state): for cid, func in d.items()} for s, d in self.callbacks.items()} self._func_cid_map = { - s: {proxy: cid for cid, proxy in d.items()} - for s, d in self.callbacks.items()} + (s, proxy): cid + for s, d in self.callbacks.items() for cid, proxy in d.items()} self._cid_gen = itertools.count(cid_count) def connect(self, signal, func): """Register *func* to be called when signal *signal* is generated.""" if self._signals is not None: _api.check_in_list(self._signals, signal=signal) - self._func_cid_map.setdefault(signal, {}) proxy = _weak_or_strong_ref(func, functools.partial(self._remove_proxy, signal)) - if proxy in self._func_cid_map[signal]: - return self._func_cid_map[signal][proxy] - cid = next(self._cid_gen) - self._func_cid_map[signal][proxy] = cid - self.callbacks.setdefault(signal, {}) - self.callbacks[signal][cid] = proxy - return cid + try: + return self._func_cid_map[signal, proxy] + except KeyError: + cid = self._func_cid_map[signal, proxy] = next(self._cid_gen) + self.callbacks.setdefault(signal, {})[cid] = proxy + return cid def _connect_picklable(self, signal, func): """ @@ -237,16 +235,14 @@ def _remove_proxy(self, signal, proxy, *, _is_finalizing=sys.is_finalizing): if _is_finalizing(): # Weakrefs can't be properly torn down at that point anymore. return - cid = self._func_cid_map[signal].pop(proxy, None) + cid = self._func_cid_map.pop((signal, proxy), None) if cid is not None: del self.callbacks[signal][cid] self._pickled_cids.discard(cid) else: # Not found return - # Clean up empty dicts - if len(self.callbacks[signal]) == 0: + if len(self.callbacks[signal]) == 0: # Clean up empty dicts del self.callbacks[signal] - del self._func_cid_map[signal] def disconnect(self, cid): """ @@ -255,22 +251,16 @@ def disconnect(self, cid): No error is raised if such a callback does not exist. """ self._pickled_cids.discard(cid) - # Clean up callbacks - for signal, cid_to_proxy in list(self.callbacks.items()): - proxy = cid_to_proxy.pop(cid, None) - if proxy is not None: + for signal, proxy in self._func_cid_map: + if self._func_cid_map[signal, proxy] == cid: break else: # Not found return - proxy_to_cid = self._func_cid_map[signal] - for current_proxy, current_cid in list(proxy_to_cid.items()): - if current_cid == cid: - assert proxy is current_proxy - del proxy_to_cid[current_proxy] - # Clean up empty dicts - if len(self.callbacks[signal]) == 0: + assert self.callbacks[signal][cid] == proxy + del self.callbacks[signal][cid] + del self._func_cid_map[signal, proxy] + if len(self.callbacks[signal]) == 0: # Clean up empty dicts del self.callbacks[signal] - del self._func_cid_map[signal] def process(self, s, *args, **kwargs): """ diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 55dc934baf42..6b484af08608 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -193,7 +193,7 @@ def disconnect(self, cid): return self.callbacks.disconnect(cid) def count(self): - count1 = len(self.callbacks._func_cid_map.get(self.signal, [])) + count1 = sum(s == self.signal for s, p in self.callbacks._func_cid_map) count2 = len(self.callbacks.callbacks.get(self.signal)) assert count1 == count2 return count1 From 87f194f17785782be3877bd4a23f16118f0f56bb Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 31 May 2023 12:40:32 +0200 Subject: [PATCH 3/3] Support unhashable callbacks in CallbackRegistry. ... by replacing _func_cid_map by a dict-like structure (_UnhashDict) that also supports unhashable entries. Note that _func_cid_map (and thus _UnhashDict) can be dropped if we get rid of proxy deduplication in CallbackRegistry. --- lib/matplotlib/cbook.py | 65 +++++++++++++++++++++++++++--- lib/matplotlib/tests/test_cbook.py | 42 +++++++++++-------- 2 files changed, 86 insertions(+), 21 deletions(-) diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 76ac2756936f..2af408e75774 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -112,6 +112,61 @@ def _weak_or_strong_ref(func, callback): return _StrongRef(func) +class _UnhashDict: + """ + A minimal dict-like class that also supports unhashable keys, storing them + in a list of key-value pairs. + + This class only implements the interface needed for `CallbackRegistry`, and + tries to minimize the overhead for the hashable case. + """ + + def __init__(self, pairs): + self._dict = {} + self._pairs = [] + for k, v in pairs: + self[k] = v + + def __setitem__(self, key, value): + try: + self._dict[key] = value + except TypeError: + for i, (k, v) in enumerate(self._pairs): + if k == key: + self._pairs[i] = (key, value) + break + else: + self._pairs.append((key, value)) + + def __getitem__(self, key): + try: + return self._dict[key] + except TypeError: + pass + for k, v in self._pairs: + if k == key: + return v + raise KeyError(key) + + def pop(self, key, *args): + try: + if key in self._dict: + return self._dict.pop(key) + except TypeError: + for i, (k, v) in enumerate(self._pairs): + if k == key: + del self._pairs[i] + return v + if args: + return args[0] + raise KeyError(key) + + def __iter__(self): + yield from self._dict + for k, v in self._pairs: + yield k + + class CallbackRegistry: """ Handle registering, processing, blocking, and disconnecting @@ -178,7 +233,7 @@ def __init__(self, exception_handler=_exception_printer, *, signals=None): self.exception_handler = exception_handler self.callbacks = {} self._cid_gen = itertools.count() - self._func_cid_map = {} + self._func_cid_map = _UnhashDict([]) # A hidden variable that marks cids that need to be pickled. self._pickled_cids = set() @@ -202,9 +257,9 @@ def __setstate__(self, state): s: {cid: _weak_or_strong_ref(func, functools.partial(self._remove_proxy, s)) for cid, func in d.items()} for s, d in self.callbacks.items()} - self._func_cid_map = { - (s, proxy): cid - for s, d in self.callbacks.items() for cid, proxy in d.items()} + self._func_cid_map = _UnhashDict( + ((s, proxy), cid) + for s, d in self.callbacks.items() for cid, proxy in d.items()) self._cid_gen = itertools.count(cid_count) def connect(self, signal, func): @@ -258,7 +313,7 @@ def disconnect(self, cid): return assert self.callbacks[signal][cid] == proxy del self.callbacks[signal][cid] - del self._func_cid_map[signal, proxy] + self._func_cid_map.pop((signal, proxy)) if len(self.callbacks[signal]) == 0: # Clean up empty dicts del self.callbacks[signal] diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 6b484af08608..7142029f21d0 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -178,6 +178,15 @@ def test_boxplot_stats_autorange_false(self): assert_array_almost_equal(bstats_true[0]['fliers'], []) +class Hashable: + def dummy(self): pass + + +class Unhashable: + __hash__ = None # type: ignore + def dummy(self): pass + + class Test_callback_registry: def setup_method(self): self.signal = 'test' @@ -200,13 +209,13 @@ def count(self): def is_empty(self): np.testing.break_cycles() - assert self.callbacks._func_cid_map == {} + assert [*self.callbacks._func_cid_map] == [] assert self.callbacks.callbacks == {} assert self.callbacks._pickled_cids == set() def is_not_empty(self): np.testing.break_cycles() - assert self.callbacks._func_cid_map != {} + assert [*self.callbacks._func_cid_map] != [] assert self.callbacks.callbacks != {} def test_cid_restore(self): @@ -217,12 +226,13 @@ def test_cid_restore(self): assert cid == 1 @pytest.mark.parametrize('pickle', [True, False]) - def test_callback_complete(self, pickle): + @pytest.mark.parametrize('cls', [Hashable, Unhashable]) + def test_callback_complete(self, pickle, cls): # ensure we start with an empty registry self.is_empty() # create a class for testing - mini_me = Test_callback_registry() + mini_me = cls() # test that we can add a callback cid1 = self.connect(self.signal, mini_me.dummy, pickle) @@ -233,7 +243,7 @@ def test_callback_complete(self, pickle): cid2 = self.connect(self.signal, mini_me.dummy, pickle) assert cid1 == cid2 self.is_not_empty() - assert len(self.callbacks._func_cid_map) == 1 + assert len([*self.callbacks._func_cid_map]) == 1 assert len(self.callbacks.callbacks) == 1 del mini_me @@ -242,12 +252,13 @@ def test_callback_complete(self, pickle): self.is_empty() @pytest.mark.parametrize('pickle', [True, False]) - def test_callback_disconnect(self, pickle): + @pytest.mark.parametrize('cls', [Hashable, Unhashable]) + def test_callback_disconnect(self, pickle, cls): # ensure we start with an empty registry self.is_empty() # create a class for testing - mini_me = Test_callback_registry() + mini_me = cls() # test that we can add a callback cid1 = self.connect(self.signal, mini_me.dummy, pickle) @@ -260,12 +271,13 @@ def test_callback_disconnect(self, pickle): self.is_empty() @pytest.mark.parametrize('pickle', [True, False]) - def test_callback_wrong_disconnect(self, pickle): + @pytest.mark.parametrize('cls', [Hashable, Unhashable]) + def test_callback_wrong_disconnect(self, pickle, cls): # ensure we start with an empty registry self.is_empty() # create a class for testing - mini_me = Test_callback_registry() + mini_me = cls() # test that we can add a callback cid1 = self.connect(self.signal, mini_me.dummy, pickle) @@ -278,20 +290,21 @@ def test_callback_wrong_disconnect(self, pickle): self.is_not_empty() @pytest.mark.parametrize('pickle', [True, False]) - def test_registration_on_non_empty_registry(self, pickle): + @pytest.mark.parametrize('cls', [Hashable, Unhashable]) + def test_registration_on_non_empty_registry(self, pickle, cls): # ensure we start with an empty registry self.is_empty() # setup the registry with a callback - mini_me = Test_callback_registry() + mini_me = cls() self.connect(self.signal, mini_me.dummy, pickle) # Add another callback - mini_me2 = Test_callback_registry() + mini_me2 = cls() self.connect(self.signal, mini_me2.dummy, pickle) # Remove and add the second callback - mini_me2 = Test_callback_registry() + mini_me2 = cls() self.connect(self.signal, mini_me2.dummy, pickle) # We still have 2 references @@ -303,9 +316,6 @@ def test_registration_on_non_empty_registry(self, pickle): mini_me2 = None self.is_empty() - def dummy(self): - pass - def test_pickling(self): assert hasattr(pickle.loads(pickle.dumps(cbook.CallbackRegistry())), "callbacks")