Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9675122

Browse files
authored
Support unhashable callbacks in CallbackRegistry (#26013)
* Record the connected signal in CallbackRegistry weakref cleanup function. ... to remove the need to loop over all signals in _remove_proxy. * Flatten CallbackRegistry._func_cid_map. It is easier to manipulate a flat (signal, proxy) -> cid map rather than a nested signal -> (proxy -> cid) map. * Support unhashable callbacks in CallbackRegistry. ... by replacing _func_cid_map by a dict-like structure (_UnhashDict) that also supports unhashable entries. Note that _func_cid_map (and thus _UnhashDict) can be dropped if we get rid of proxy deduplication in CallbackRegistry.
1 parent d4ea011 commit 9675122

File tree

2 files changed

+109
-59
lines changed

2 files changed

+109
-59
lines changed

lib/matplotlib/cbook.py

Lines changed: 82 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,61 @@ def _weak_or_strong_ref(func, callback):
117117
return _StrongRef(func)
118118

119119

120+
class _UnhashDict:
121+
"""
122+
A minimal dict-like class that also supports unhashable keys, storing them
123+
in a list of key-value pairs.
124+
125+
This class only implements the interface needed for `CallbackRegistry`, and
126+
tries to minimize the overhead for the hashable case.
127+
"""
128+
129+
def __init__(self, pairs):
130+
self._dict = {}
131+
self._pairs = []
132+
for k, v in pairs:
133+
self[k] = v
134+
135+
def __setitem__(self, key, value):
136+
try:
137+
self._dict[key] = value
138+
except TypeError:
139+
for i, (k, v) in enumerate(self._pairs):
140+
if k == key:
141+
self._pairs[i] = (key, value)
142+
break
143+
else:
144+
self._pairs.append((key, value))
145+
146+
def __getitem__(self, key):
147+
try:
148+
return self._dict[key]
149+
except TypeError:
150+
pass
151+
for k, v in self._pairs:
152+
if k == key:
153+
return v
154+
raise KeyError(key)
155+
156+
def pop(self, key, *args):
157+
try:
158+
if key in self._dict:
159+
return self._dict.pop(key)
160+
except TypeError:
161+
for i, (k, v) in enumerate(self._pairs):
162+
if k == key:
163+
del self._pairs[i]
164+
return v
165+
if args:
166+
return args[0]
167+
raise KeyError(key)
168+
169+
def __iter__(self):
170+
yield from self._dict
171+
for k, v in self._pairs:
172+
yield k
173+
174+
120175
class CallbackRegistry:
121176
"""
122177
Handle registering, processing, blocking, and disconnecting
@@ -176,14 +231,14 @@ class CallbackRegistry:
176231

177232
# We maintain two mappings:
178233
# callbacks: signal -> {cid -> weakref-to-callback}
179-
# _func_cid_map: signal -> {weakref-to-callback -> cid}
234+
# _func_cid_map: {(signal, weakref-to-callback) -> cid}
180235

181236
def __init__(self, exception_handler=_exception_printer, *, signals=None):
182237
self._signals = None if signals is None else list(signals) # Copy it.
183238
self.exception_handler = exception_handler
184239
self.callbacks = {}
185240
self._cid_gen = itertools.count()
186-
self._func_cid_map = {}
241+
self._func_cid_map = _UnhashDict([])
187242
# A hidden variable that marks cids that need to be pickled.
188243
self._pickled_cids = set()
189244

@@ -204,27 +259,25 @@ def __setstate__(self, state):
204259
cid_count = state.pop('_cid_gen')
205260
vars(self).update(state)
206261
self.callbacks = {
207-
s: {cid: _weak_or_strong_ref(func, self._remove_proxy)
262+
s: {cid: _weak_or_strong_ref(func, functools.partial(self._remove_proxy, s))
208263
for cid, func in d.items()}
209264
for s, d in self.callbacks.items()}
210-
self._func_cid_map = {
211-
s: {proxy: cid for cid, proxy in d.items()}
212-
for s, d in self.callbacks.items()}
265+
self._func_cid_map = _UnhashDict(
266+
((s, proxy), cid)
267+
for s, d in self.callbacks.items() for cid, proxy in d.items())
213268
self._cid_gen = itertools.count(cid_count)
214269

215270
def connect(self, signal, func):
216271
"""Register *func* to be called when signal *signal* is generated."""
217272
if self._signals is not None:
218273
_api.check_in_list(self._signals, signal=signal)
219-
self._func_cid_map.setdefault(signal, {})
220-
proxy = _weak_or_strong_ref(func, self._remove_proxy)
221-
if proxy in self._func_cid_map[signal]:
222-
return self._func_cid_map[signal][proxy]
223-
cid = next(self._cid_gen)
224-
self._func_cid_map[signal][proxy] = cid
225-
self.callbacks.setdefault(signal, {})
226-
self.callbacks[signal][cid] = proxy
227-
return cid
274+
proxy = _weak_or_strong_ref(func, functools.partial(self._remove_proxy, signal))
275+
try:
276+
return self._func_cid_map[signal, proxy]
277+
except KeyError:
278+
cid = self._func_cid_map[signal, proxy] = next(self._cid_gen)
279+
self.callbacks.setdefault(signal, {})[cid] = proxy
280+
return cid
228281

229282
def _connect_picklable(self, signal, func):
230283
"""
@@ -238,23 +291,18 @@ def _connect_picklable(self, signal, func):
238291

239292
# Keep a reference to sys.is_finalizing, as sys may have been cleared out
240293
# at that point.
241-
def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing):
294+
def _remove_proxy(self, signal, proxy, *, _is_finalizing=sys.is_finalizing):
242295
if _is_finalizing():
243296
# Weakrefs can't be properly torn down at that point anymore.
244297
return
245-
for signal, proxy_to_cid in list(self._func_cid_map.items()):
246-
cid = proxy_to_cid.pop(proxy, None)
247-
if cid is not None:
248-
del self.callbacks[signal][cid]
249-
self._pickled_cids.discard(cid)
250-
break
251-
else:
252-
# Not found
298+
cid = self._func_cid_map.pop((signal, proxy), None)
299+
if cid is not None:
300+
del self.callbacks[signal][cid]
301+
self._pickled_cids.discard(cid)
302+
else: # Not found
253303
return
254-
# Clean up empty dicts
255-
if len(self.callbacks[signal]) == 0:
304+
if len(self.callbacks[signal]) == 0: # Clean up empty dicts
256305
del self.callbacks[signal]
257-
del self._func_cid_map[signal]
258306

259307
def disconnect(self, cid):
260308
"""
@@ -263,24 +311,16 @@ def disconnect(self, cid):
263311
No error is raised if such a callback does not exist.
264312
"""
265313
self._pickled_cids.discard(cid)
266-
# Clean up callbacks
267-
for signal, cid_to_proxy in list(self.callbacks.items()):
268-
proxy = cid_to_proxy.pop(cid, None)
269-
if proxy is not None:
314+
for signal, proxy in self._func_cid_map:
315+
if self._func_cid_map[signal, proxy] == cid:
270316
break
271-
else:
272-
# Not found
317+
else: # Not found
273318
return
274-
275-
proxy_to_cid = self._func_cid_map[signal]
276-
for current_proxy, current_cid in list(proxy_to_cid.items()):
277-
if current_cid == cid:
278-
assert proxy is current_proxy
279-
del proxy_to_cid[current_proxy]
280-
# Clean up empty dicts
281-
if len(self.callbacks[signal]) == 0:
319+
assert self.callbacks[signal][cid] == proxy
320+
del self.callbacks[signal][cid]
321+
self._func_cid_map.pop((signal, proxy))
322+
if len(self.callbacks[signal]) == 0: # Clean up empty dicts
282323
del self.callbacks[signal]
283-
del self._func_cid_map[signal]
284324

285325
def process(self, s, *args, **kwargs):
286326
"""

lib/matplotlib/tests/test_cbook.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ def test_boxplot_stats_autorange_false(self):
181181
assert_array_almost_equal(bstats_true[0]['fliers'], [])
182182

183183

184+
class Hashable:
185+
def dummy(self): pass
186+
187+
188+
class Unhashable:
189+
__hash__ = None # type: ignore
190+
def dummy(self): pass
191+
192+
184193
class Test_callback_registry:
185194
def setup_method(self):
186195
self.signal = 'test'
@@ -196,20 +205,20 @@ def disconnect(self, cid):
196205
return self.callbacks.disconnect(cid)
197206

198207
def count(self):
199-
count1 = len(self.callbacks._func_cid_map.get(self.signal, []))
208+
count1 = sum(s == self.signal for s, p in self.callbacks._func_cid_map)
200209
count2 = len(self.callbacks.callbacks.get(self.signal))
201210
assert count1 == count2
202211
return count1
203212

204213
def is_empty(self):
205214
np.testing.break_cycles()
206-
assert self.callbacks._func_cid_map == {}
215+
assert [*self.callbacks._func_cid_map] == []
207216
assert self.callbacks.callbacks == {}
208217
assert self.callbacks._pickled_cids == set()
209218

210219
def is_not_empty(self):
211220
np.testing.break_cycles()
212-
assert self.callbacks._func_cid_map != {}
221+
assert [*self.callbacks._func_cid_map] != []
213222
assert self.callbacks.callbacks != {}
214223

215224
def test_cid_restore(self):
@@ -220,12 +229,13 @@ def test_cid_restore(self):
220229
assert cid == 1
221230

222231
@pytest.mark.parametrize('pickle', [True, False])
223-
def test_callback_complete(self, pickle):
232+
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
233+
def test_callback_complete(self, pickle, cls):
224234
# ensure we start with an empty registry
225235
self.is_empty()
226236

227237
# create a class for testing
228-
mini_me = Test_callback_registry()
238+
mini_me = cls()
229239

230240
# test that we can add a callback
231241
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
@@ -236,7 +246,7 @@ def test_callback_complete(self, pickle):
236246
cid2 = self.connect(self.signal, mini_me.dummy, pickle)
237247
assert cid1 == cid2
238248
self.is_not_empty()
239-
assert len(self.callbacks._func_cid_map) == 1
249+
assert len([*self.callbacks._func_cid_map]) == 1
240250
assert len(self.callbacks.callbacks) == 1
241251

242252
del mini_me
@@ -245,12 +255,13 @@ def test_callback_complete(self, pickle):
245255
self.is_empty()
246256

247257
@pytest.mark.parametrize('pickle', [True, False])
248-
def test_callback_disconnect(self, pickle):
258+
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
259+
def test_callback_disconnect(self, pickle, cls):
249260
# ensure we start with an empty registry
250261
self.is_empty()
251262

252263
# create a class for testing
253-
mini_me = Test_callback_registry()
264+
mini_me = cls()
254265

255266
# test that we can add a callback
256267
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
@@ -263,12 +274,13 @@ def test_callback_disconnect(self, pickle):
263274
self.is_empty()
264275

265276
@pytest.mark.parametrize('pickle', [True, False])
266-
def test_callback_wrong_disconnect(self, pickle):
277+
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
278+
def test_callback_wrong_disconnect(self, pickle, cls):
267279
# ensure we start with an empty registry
268280
self.is_empty()
269281

270282
# create a class for testing
271-
mini_me = Test_callback_registry()
283+
mini_me = cls()
272284

273285
# test that we can add a callback
274286
cid1 = self.connect(self.signal, mini_me.dummy, pickle)
@@ -281,20 +293,21 @@ def test_callback_wrong_disconnect(self, pickle):
281293
self.is_not_empty()
282294

283295
@pytest.mark.parametrize('pickle', [True, False])
284-
def test_registration_on_non_empty_registry(self, pickle):
296+
@pytest.mark.parametrize('cls', [Hashable, Unhashable])
297+
def test_registration_on_non_empty_registry(self, pickle, cls):
285298
# ensure we start with an empty registry
286299
self.is_empty()
287300

288301
# setup the registry with a callback
289-
mini_me = Test_callback_registry()
302+
mini_me = cls()
290303
self.connect(self.signal, mini_me.dummy, pickle)
291304

292305
# Add another callback
293-
mini_me2 = Test_callback_registry()
306+
mini_me2 = cls()
294307
self.connect(self.signal, mini_me2.dummy, pickle)
295308

296309
# Remove and add the second callback
297-
mini_me2 = Test_callback_registry()
310+
mini_me2 = cls()
298311
self.connect(self.signal, mini_me2.dummy, pickle)
299312

300313
# We still have 2 references
@@ -306,9 +319,6 @@ def test_registration_on_non_empty_registry(self, pickle):
306319
mini_me2 = None
307320
self.is_empty()
308321

309-
def dummy(self):
310-
pass
311-
312322
def test_pickling(self):
313323
assert hasattr(pickle.loads(pickle.dumps(cbook.CallbackRegistry())),
314324
"callbacks")

0 commit comments

Comments
 (0)