Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 43cbccc

Browse files
committed
Make Grouper return siblings in the order in which they have been seen.
1 parent 6799367 commit 43cbccc

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

lib/matplotlib/cbook.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -847,21 +847,29 @@ class Grouper:
847847
"""
848848

849849
def __init__(self, init=()):
850-
self._mapping = weakref.WeakKeyDictionary(
851-
{x: weakref.WeakSet([x]) for x in init})
850+
self._count = itertools.count()
851+
# For each item, we store (order_in_which_item_was_seen, group_of_item), which
852+
# lets __iter__ and get_siblings return items in the order in which they have
853+
# been seen.
854+
self._mapping = weakref.WeakKeyDictionary()
855+
for x in init:
856+
if x not in self._mapping:
857+
self._mapping[x] = (next(self._count), weakref.WeakSet([x]))
852858

853859
def __getstate__(self):
854860
return {
855861
**vars(self),
856-
# Convert weak refs to strong ones.
857-
"_mapping": {k: set(v) for k, v in self._mapping.items()},
862+
# Convert weak refs to strong ones, and counter to index.
863+
"_mapping": {k: (i, set(v)) for k, (i, v) in self._mapping.items()},
864+
"_count": next(self._count),
858865
}
859866

860867
def __setstate__(self, state):
861868
vars(self).update(state)
862-
# Convert strong refs to weak ones.
869+
# Convert strong refs to weak ones, and index to counter.
863870
self._mapping = weakref.WeakKeyDictionary(
864-
{k: weakref.WeakSet(v) for k, v in self._mapping.items()})
871+
{k: (i, weakref.WeakSet(v)) for k, (i, v) in self._mapping.items()})
872+
self._count = itertools.count(self._count)
865873

866874
def __contains__(self, item):
867875
return item in self._mapping
@@ -874,25 +882,32 @@ def join(self, a, *args):
874882
"""
875883
Join given arguments into the same set. Accepts one or more arguments.
876884
"""
877-
mapping = self._mapping
878-
set_a = mapping.setdefault(a, weakref.WeakSet([a]))
879-
880-
for arg in args:
881-
set_b = mapping.get(arg, weakref.WeakSet([arg]))
885+
m = self._mapping
886+
try:
887+
_, set_a = m[a]
888+
except KeyError:
889+
_, set_a = m[a] = (next(self._count), weakref.WeakSet([a]))
890+
for b in args:
891+
try:
892+
_, set_b = m[b]
893+
except KeyError:
894+
_, set_b = m[b] = (next(self._count), weakref.WeakSet([b]))
882895
if set_b is not set_a:
883896
if len(set_b) > len(set_a):
884897
set_a, set_b = set_b, set_a
885898
set_a.update(set_b)
886899
for elem in set_b:
887-
mapping[elem] = set_a
900+
i, _ = m[elem]
901+
m[elem] = (i, set_a)
888902

889903
def joined(self, a, b):
890904
"""Return whether *a* and *b* are members of the same set."""
891-
return (self._mapping.get(a, object()) is self._mapping.get(b))
905+
return (self._mapping.get(a, (None, object()))[1]
906+
is self._mapping.get(b, (None, object()))[1])
892907

893908
def remove(self, a):
894909
"""Remove *a* from the grouper, doing nothing if it is not there."""
895-
set_a = self._mapping.pop(a, None)
910+
_, set_a = self._mapping.pop(a, (None, None))
896911
if set_a:
897912
set_a.remove(a)
898913

@@ -902,14 +917,14 @@ def __iter__(self):
902917
903918
The iterator is invalid if interleaved with calls to join().
904919
"""
905-
unique_groups = {id(group): group for group in self._mapping.values()}
920+
unique_groups = {id(group): group for _, group in self._mapping.values()}
906921
for group in unique_groups.values():
907-
yield [x for x in group]
922+
yield sorted(group, key=self._mapping.__getitem__)
908923

909924
def get_siblings(self, a):
910925
"""Return all of the items joined with *a*, including itself."""
911-
siblings = self._mapping.get(a, [a])
912-
return [x for x in siblings]
926+
_, siblings = self._mapping.get(a, (None, [a]))
927+
return sorted(siblings, key=self._mapping.get)
913928

914929

915930
class GrouperView:

lib/matplotlib/tests/test_cbook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,9 @@ class Dummy:
603603
for o in objs:
604604
assert o in mapping
605605

606-
base_set = mapping[objs[0]]
606+
_, base_set = mapping[objs[0]]
607607
for o in objs[1:]:
608-
assert mapping[o] is base_set
608+
assert mapping[o][1] is base_set
609609

610610

611611
def test_flatiter():

0 commit comments

Comments
 (0)