@@ -846,21 +846,27 @@ class Grouper:
846
846
"""
847
847
848
848
def __init__ (self , init = ()):
849
- self ._mapping = weakref .WeakKeyDictionary (
850
- {x : weakref .WeakSet ([x ]) for x in init })
849
+ self ._count = itertools .count ()
850
+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
851
+ # lets __iter__ and get_siblings return items in the order in which they have
852
+ # been seen.
853
+ self ._mapping = weakref .WeakKeyDictionary ()
854
+ for x in init :
855
+ if x not in self ._mapping :
856
+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
851
857
852
858
def __getstate__ (self ):
853
859
return {
854
860
** vars (self ),
855
861
# Convert weak refs to strong ones.
856
- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
862
+ "_mapping" : {k : ( i , set (v )) for k , ( i , v ) in self ._mapping .items ()},
857
863
}
858
864
859
865
def __setstate__ (self , state ):
860
866
vars (self ).update (state )
861
867
# Convert strong refs to weak ones.
862
868
self ._mapping = weakref .WeakKeyDictionary (
863
- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
869
+ {k : ( i , weakref .WeakSet (v )) for k , ( i , v ) in self ._mapping .items ()})
864
870
865
871
def __contains__ (self , item ):
866
872
return item in self ._mapping
@@ -873,25 +879,32 @@ def join(self, a, *args):
873
879
"""
874
880
Join given arguments into the same set. Accepts one or more arguments.
875
881
"""
876
- mapping = self ._mapping
877
- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
878
-
879
- for arg in args :
880
- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
882
+ m = self ._mapping
883
+ try :
884
+ _ , set_a = m [a ]
885
+ except KeyError :
886
+ _ , set_a = m [a ] = (next (self ._count ), weakref .WeakSet ([a ]))
887
+ for b in args :
888
+ try :
889
+ _ , set_b = m [b ]
890
+ except KeyError :
891
+ _ , set_b = m [b ] = (next (self ._count ), weakref .WeakSet ([b ]))
881
892
if set_b is not set_a :
882
893
if len (set_b ) > len (set_a ):
883
894
set_a , set_b = set_b , set_a
884
895
set_a .update (set_b )
885
896
for elem in set_b :
886
- mapping [elem ] = set_a
897
+ i , _ = m [elem ]
898
+ m [elem ] = (i , set_a )
887
899
888
900
def joined (self , a , b ):
889
901
"""Return whether *a* and *b* are members of the same set."""
890
- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
902
+ return (self ._mapping .get (a , (None , object ()))[1 ]
903
+ is self ._mapping .get (b , (None , object ()))[1 ])
891
904
892
905
def remove (self , a ):
893
906
"""Remove *a* from the grouper, doing nothing if it is not there."""
894
- set_a = self ._mapping .pop (a , None )
907
+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
895
908
if set_a :
896
909
set_a .remove (a )
897
910
@@ -901,14 +914,14 @@ def __iter__(self):
901
914
902
915
The iterator is invalid if interleaved with calls to join().
903
916
"""
904
- unique_groups = {id (group ): group for group in self ._mapping .values ()}
917
+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
905
918
for group in unique_groups .values ():
906
- yield [ x for x in group ]
919
+ yield sorted ( group , key = self . _mapping . __getitem__ )
907
920
908
921
def get_siblings (self , a ):
909
922
"""Return all of the items joined with *a*, including itself."""
910
- siblings = self ._mapping .get (a , [a ])
911
- return [ x for x in siblings ]
923
+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
924
+ return sorted ( siblings , key = self . _mapping . get )
912
925
913
926
914
927
class GrouperView :
0 commit comments