Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Support unhashable callbacks in CallbackRegistry #26013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 82 additions & 42 deletions lib/matplotlib/cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,61 @@
return _StrongRef(func)


class _UnhashDict:
"""
A minimal dict-like class that also supports unhashable keys, storing them
in a list of key-value pairs.

This class only implements the interface needed for `CallbackRegistry`, and
tries to minimize the overhead for the hashable case.
"""

def __init__(self, pairs):
self._dict = {}
self._pairs = []
for k, v in pairs:
self[k] = v

def __setitem__(self, key, value):
try:
self._dict[key] = value
except TypeError:
for i, (k, v) in enumerate(self._pairs):
if k == key:
self._pairs[i] = (key, value)
break
else:
self._pairs.append((key, value))

def __getitem__(self, key):
try:
return self._dict[key]
except TypeError:
pass
for k, v in self._pairs:
if k == key:
return v
raise KeyError(key)

def pop(self, key, *args):
try:
if key in self._dict:
return self._dict.pop(key)
except TypeError:
for i, (k, v) in enumerate(self._pairs):
if k == key:
del self._pairs[i]
return v
if args:
return args[0]
raise KeyError(key)

def __iter__(self):
yield from self._dict
for k, v in self._pairs:
yield k


class CallbackRegistry:
"""
Handle registering, processing, blocking, and disconnecting
Expand Down Expand Up @@ -171,14 +226,14 @@

# We maintain two mappings:
# callbacks: signal -> {cid -> weakref-to-callback}
# _func_cid_map: signal -> {weakref-to-callback -> cid}
# _func_cid_map: {(signal, weakref-to-callback) -> cid}

def __init__(self, exception_handler=_exception_printer, *, signals=None):
self._signals = None if signals is None else list(signals) # Copy it.
self.exception_handler = exception_handler
self.callbacks = {}
self._cid_gen = itertools.count()
self._func_cid_map = {}
self._func_cid_map = _UnhashDict([])
# A hidden variable that marks cids that need to be pickled.
self._pickled_cids = set()

Expand All @@ -199,27 +254,25 @@
cid_count = state.pop('_cid_gen')
vars(self).update(state)
self.callbacks = {
s: {cid: _weak_or_strong_ref(func, self._remove_proxy)
s: {cid: _weak_or_strong_ref(func, functools.partial(self._remove_proxy, s))
for cid, func in d.items()}
for s, d in self.callbacks.items()}
self._func_cid_map = {
s: {proxy: cid for cid, proxy in d.items()}
for s, d in self.callbacks.items()}
self._func_cid_map = _UnhashDict(
((s, proxy), cid)
for s, d in self.callbacks.items() for cid, proxy in d.items())
self._cid_gen = itertools.count(cid_count)

def connect(self, signal, func):
"""Register *func* to be called when signal *signal* is generated."""
if self._signals is not None:
_api.check_in_list(self._signals, signal=signal)
self._func_cid_map.setdefault(signal, {})
proxy = _weak_or_strong_ref(func, self._remove_proxy)
if proxy in self._func_cid_map[signal]:
return self._func_cid_map[signal][proxy]
cid = next(self._cid_gen)
self._func_cid_map[signal][proxy] = cid
self.callbacks.setdefault(signal, {})
self.callbacks[signal][cid] = proxy
return cid
proxy = _weak_or_strong_ref(func, functools.partial(self._remove_proxy, signal))
try:
return self._func_cid_map[signal, proxy]
except KeyError:
cid = self._func_cid_map[signal, proxy] = next(self._cid_gen)
self.callbacks.setdefault(signal, {})[cid] = proxy
return cid

def _connect_picklable(self, signal, func):
"""
Expand All @@ -233,23 +286,18 @@

# Keep a reference to sys.is_finalizing, as sys may have been cleared out
# at that point.
def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing):
def _remove_proxy(self, signal, proxy, *, _is_finalizing=sys.is_finalizing):
if _is_finalizing():
# Weakrefs can't be properly torn down at that point anymore.
return
for signal, proxy_to_cid in list(self._func_cid_map.items()):
cid = proxy_to_cid.pop(proxy, None)
if cid is not None:
del self.callbacks[signal][cid]
self._pickled_cids.discard(cid)
break
else:
# Not found
cid = self._func_cid_map.pop((signal, proxy), None)
if cid is not None:
del self.callbacks[signal][cid]
self._pickled_cids.discard(cid)
else: # Not found

Check warning on line 297 in lib/matplotlib/cbook.py

View check run for this annotation

Codecov / codecov/patch

lib/matplotlib/cbook.py#L297

Added line #L297 was not covered by tests
return
# Clean up empty dicts
if len(self.callbacks[signal]) == 0:
if len(self.callbacks[signal]) == 0: # Clean up empty dicts
del self.callbacks[signal]
del self._func_cid_map[signal]

def disconnect(self, cid):
"""
Expand All @@ -258,24 +306,16 @@
No error is raised if such a callback does not exist.
"""
self._pickled_cids.discard(cid)
# Clean up callbacks
for signal, cid_to_proxy in list(self.callbacks.items()):
proxy = cid_to_proxy.pop(cid, None)
if proxy is not None:
for signal, proxy in self._func_cid_map:
if self._func_cid_map[signal, proxy] == cid:
break
else:
# Not found
else: # Not found
return

proxy_to_cid = self._func_cid_map[signal]
for current_proxy, current_cid in list(proxy_to_cid.items()):
if current_cid == cid:
assert proxy is current_proxy
del proxy_to_cid[current_proxy]
# Clean up empty dicts
if len(self.callbacks[signal]) == 0:
assert self.callbacks[signal][cid] == proxy
del self.callbacks[signal][cid]
self._func_cid_map.pop((signal, proxy))
if len(self.callbacks[signal]) == 0: # Clean up empty dicts
del self.callbacks[signal]
del self._func_cid_map[signal]

def process(self, s, *args, **kwargs):
"""
Expand Down
44 changes: 27 additions & 17 deletions lib/matplotlib/tests/test_cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ def test_boxplot_stats_autorange_false(self):
assert_array_almost_equal(bstats_true[0]['fliers'], [])


class Hashable:
def dummy(self): pass


class Unhashable:
__hash__ = None # type: ignore
def dummy(self): pass


class Test_callback_registry:
def setup_method(self):
self.signal = 'test'
Expand All @@ -193,20 +202,20 @@ def disconnect(self, cid):
return self.callbacks.disconnect(cid)

def count(self):
count1 = len(self.callbacks._func_cid_map.get(self.signal, []))
count1 = sum(s == self.signal for s, p in self.callbacks._func_cid_map)
count2 = len(self.callbacks.callbacks.get(self.signal))
assert count1 == count2
return count1

def is_empty(self):
np.testing.break_cycles()
assert self.callbacks._func_cid_map == {}
assert [*self.callbacks._func_cid_map] == []
assert self.callbacks.callbacks == {}
assert self.callbacks._pickled_cids == set()

def is_not_empty(self):
np.testing.break_cycles()
assert self.callbacks._func_cid_map != {}
assert [*self.callbacks._func_cid_map] != []
assert self.callbacks.callbacks != {}

def test_cid_restore(self):
Expand All @@ -217,12 +226,13 @@ def test_cid_restore(self):
assert cid == 1

@pytest.mark.parametrize('pickle', [True, False])
def test_callback_complete(self, pickle):
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
def test_callback_complete(self, pickle, cls):
# ensure we start with an empty registry
self.is_empty()

# create a class for testing
mini_me = Test_callback_registry()
mini_me = cls()

# test that we can add a callback
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
Expand All @@ -233,7 +243,7 @@ def test_callback_complete(self, pickle):
cid2 = self.connect(self.signal, mini_me.dummy, pickle)
assert cid1 == cid2
self.is_not_empty()
assert len(self.callbacks._func_cid_map) == 1
assert len([*self.callbacks._func_cid_map]) == 1
assert len(self.callbacks.callbacks) == 1

del mini_me
Expand All @@ -242,12 +252,13 @@ def test_callback_complete(self, pickle):
self.is_empty()

@pytest.mark.parametrize('pickle', [True, False])
def test_callback_disconnect(self, pickle):
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
def test_callback_disconnect(self, pickle, cls):
# ensure we start with an empty registry
self.is_empty()

# create a class for testing
mini_me = Test_callback_registry()
mini_me = cls()

# test that we can add a callback
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
Expand All @@ -260,12 +271,13 @@ def test_callback_disconnect(self, pickle):
self.is_empty()

@pytest.mark.parametrize('pickle', [True, False])
def test_callback_wrong_disconnect(self, pickle):
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
def test_callback_wrong_disconnect(self, pickle, cls):
# ensure we start with an empty registry
self.is_empty()

# create a class for testing
mini_me = Test_callback_registry()
mini_me = cls()

# test that we can add a callback
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
Expand All @@ -278,20 +290,21 @@ def test_callback_wrong_disconnect(self, pickle):
self.is_not_empty()

@pytest.mark.parametrize('pickle', [True, False])
def test_registration_on_non_empty_registry(self, pickle):
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
def test_registration_on_non_empty_registry(self, pickle, cls):
# ensure we start with an empty registry
self.is_empty()

# setup the registry with a callback
mini_me = Test_callback_registry()
mini_me = cls()
self.connect(self.signal, mini_me.dummy, pickle)

# Add another callback
mini_me2 = Test_callback_registry()
mini_me2 = cls()
self.connect(self.signal, mini_me2.dummy, pickle)

# Remove and add the second callback
mini_me2 = Test_callback_registry()
mini_me2 = cls()
self.connect(self.signal, mini_me2.dummy, pickle)

# We still have 2 references
Expand All @@ -303,9 +316,6 @@ def test_registration_on_non_empty_registry(self, pickle):
mini_me2 = None
self.is_empty()

def dummy(self):
pass

def test_pickling(self):
assert hasattr(pickle.loads(pickle.dumps(cbook.CallbackRegistry())),
"callbacks")
Expand Down
Loading