Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 46e66a7

Browse files
committed
Make Grouper return siblings in the order in which they have been seen.
1 parent bba391d commit 46e66a7

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

lib/matplotlib/cbook.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -846,21 +846,27 @@ class Grouper:
846846
"""
847847

848848
def __init__(self, init=()):
849-
self._mapping = weakref.WeakKeyDictionary(
850-
{x: weakref.WeakSet([x]) for x in init})
849+
self._count = itertools.count()
850+
# For each item, we store (order_in_which_item_was_seen, group_of_item), which
851+
# lets __iter__ and get_siblings return items in the order in which they have
852+
# been seen.
853+
self._mapping = weakref.WeakKeyDictionary()
854+
for x in init:
855+
if x not in self._mapping:
856+
self._mapping[x] = (next(self._count), weakref.WeakSet([x]))
851857

852858
def __getstate__(self):
853859
return {
854860
**vars(self),
855861
# Convert weak refs to strong ones.
856-
"_mapping": {k: set(v) for k, v in self._mapping.items()},
862+
"_mapping": {k: (i, set(v)) for k, (i, v) in self._mapping.items()},
857863
}
858864

859865
def __setstate__(self, state):
860866
vars(self).update(state)
861867
# Convert strong refs to weak ones.
862868
self._mapping = weakref.WeakKeyDictionary(
863-
{k: weakref.WeakSet(v) for k, v in self._mapping.items()})
869+
{k: (i, weakref.WeakSet(v)) for k, (i, v) in self._mapping.items()})
864870

865871
def __contains__(self, item):
866872
return item in self._mapping
@@ -873,25 +879,32 @@ def join(self, a, *args):
873879
"""
874880
Join given arguments into the same set. Accepts one or more arguments.
875881
"""
876-
mapping = self._mapping
877-
set_a = mapping.setdefault(a, weakref.WeakSet([a]))
878-
879-
for arg in args:
880-
set_b = mapping.get(arg, weakref.WeakSet([arg]))
882+
m = self._mapping
883+
try:
884+
_, set_a = m[a]
885+
except KeyError:
886+
_, set_a = m[a] = (next(self._count), weakref.WeakSet([a]))
887+
for b in args:
888+
try:
889+
_, set_b = m[b]
890+
except KeyError:
891+
_, set_b = m[b] = (next(self._count), weakref.WeakSet([b]))
881892
if set_b is not set_a:
882893
if len(set_b) > len(set_a):
883894
set_a, set_b = set_b, set_a
884895
set_a.update(set_b)
885896
for elem in set_b:
886-
mapping[elem] = set_a
897+
i, _ = m[elem]
898+
m[elem] = (i, set_a)
887899

888900
def joined(self, a, b):
889901
"""Return whether *a* and *b* are members of the same set."""
890-
return (self._mapping.get(a, object()) is self._mapping.get(b))
902+
return (self._mapping.get(a, (None, object()))[1]
903+
is self._mapping.get(b, (None, object()))[1])
891904

892905
def remove(self, a):
893906
"""Remove *a* from the grouper, doing nothing if it is not there."""
894-
set_a = self._mapping.pop(a, None)
907+
_, set_a = self._mapping.pop(a, (None, None))
895908
if set_a:
896909
set_a.remove(a)
897910

@@ -901,14 +914,14 @@ def __iter__(self):
901914
902915
The iterator is invalid if interleaved with calls to join().
903916
"""
904-
unique_groups = {id(group): group for group in self._mapping.values()}
917+
unique_groups = {id(group): group for _, group in self._mapping.values()}
905918
for group in unique_groups.values():
906-
yield [x for x in group]
919+
yield sorted(group, key=self._mapping.__getitem__)
907920

908921
def get_siblings(self, a):
909922
"""Return all of the items joined with *a*, including itself."""
910-
siblings = self._mapping.get(a, [a])
911-
return [x for x in siblings]
923+
_, siblings = self._mapping.get(a, (None, [a]))
924+
return sorted(siblings, key=self._mapping.get)
912925

913926

914927
class GrouperView:

lib/matplotlib/tests/test_cbook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,9 @@ class Dummy:
601601
for o in objs:
602602
assert o in mapping
603603

604-
base_set = mapping[objs[0]]
604+
_, base_set = mapping[objs[0]]
605605
for o in objs[1:]:
606-
assert mapping[o] is base_set
606+
assert mapping[o][1] is base_set
607607

608608

609609
def test_flatiter():

0 commit comments

Comments
 (0)