@@ -847,21 +847,29 @@ class Grouper:
847
847
"""
848
848
849
849
def __init__ (self , init = ()):
850
- self ._mapping = weakref .WeakKeyDictionary (
851
- {x : weakref .WeakSet ([x ]) for x in init })
850
+ self ._count = itertools .count ()
851
+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
852
+ # lets __iter__ and get_siblings return items in the order in which they have
853
+ # been seen.
854
+ self ._mapping = weakref .WeakKeyDictionary ()
855
+ for x in init :
856
+ if x not in self ._mapping :
857
+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
852
858
853
859
def __getstate__ (self ):
854
860
return {
855
861
** vars (self ),
856
- # Convert weak refs to strong ones.
857
- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
862
+ # Convert weak refs to strong ones, and counter to index.
863
+ "_mapping" : {k : (i , set (v )) for k , (i , v ) in self ._mapping .items ()},
864
+ "_count" : next (self ._count ),
858
865
}
859
866
860
867
def __setstate__ (self , state ):
861
868
vars (self ).update (state )
862
- # Convert strong refs to weak ones.
869
+ # Convert strong refs to weak ones, and index to counter .
863
870
self ._mapping = weakref .WeakKeyDictionary (
864
- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
871
+ {k : (i , weakref .WeakSet (v )) for k , (i , v ) in self ._mapping .items ()})
872
+ self ._count = itertools .count (self ._count )
865
873
866
874
def __contains__ (self , item ):
867
875
return item in self ._mapping
@@ -874,25 +882,32 @@ def join(self, a, *args):
874
882
"""
875
883
Join given arguments into the same set. Accepts one or more arguments.
876
884
"""
877
- mapping = self ._mapping
878
- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
879
-
880
- for arg in args :
881
- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
885
+ m = self ._mapping
886
+ try :
887
+ _ , set_a = m [a ]
888
+ except KeyError :
889
+ _ , set_a = m [a ] = (next (self ._count ), weakref .WeakSet ([a ]))
890
+ for b in args :
891
+ try :
892
+ _ , set_b = m [b ]
893
+ except KeyError :
894
+ _ , set_b = m [b ] = (next (self ._count ), weakref .WeakSet ([b ]))
882
895
if set_b is not set_a :
883
896
if len (set_b ) > len (set_a ):
884
897
set_a , set_b = set_b , set_a
885
898
set_a .update (set_b )
886
899
for elem in set_b :
887
- mapping [elem ] = set_a
900
+ i , _ = m [elem ]
901
+ m [elem ] = (i , set_a )
888
902
889
903
def joined (self , a , b ):
890
904
"""Return whether *a* and *b* are members of the same set."""
891
- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
905
+ return (self ._mapping .get (a , (None , object ()))[1 ]
906
+ is self ._mapping .get (b , (None , object ()))[1 ])
892
907
893
908
def remove (self , a ):
894
909
"""Remove *a* from the grouper, doing nothing if it is not there."""
895
- set_a = self ._mapping .pop (a , None )
910
+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
896
911
if set_a :
897
912
set_a .remove (a )
898
913
@@ -902,14 +917,14 @@ def __iter__(self):
902
917
903
918
The iterator is invalid if interleaved with calls to join().
904
919
"""
905
- unique_groups = {id (group ): group for group in self ._mapping .values ()}
920
+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
906
921
for group in unique_groups .values ():
907
- yield [ x for x in group ]
922
+ yield sorted ( group , key = self . _mapping . __getitem__ )
908
923
909
924
def get_siblings (self , a ):
910
925
"""Return all of the items joined with *a*, including itself."""
911
- siblings = self ._mapping .get (a , [a ])
912
- return [ x for x in siblings ]
926
+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
927
+ return sorted ( siblings , key = self . _mapping . get )
913
928
914
929
915
930
class GrouperView :
0 commit comments