@@ -847,21 +847,27 @@ class Grouper:
847
847
"""
848
848
849
849
def __init__ (self , init = ()):
850
- self ._mapping = weakref .WeakKeyDictionary (
851
- {x : weakref .WeakSet ([x ]) for x in init })
850
+ self ._count = itertools .count ()
851
+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
852
+ # lets __iter__ and get_siblings return items in the order in which they have
853
+ # been seen.
854
+ self ._mapping = weakref .WeakKeyDictionary ()
855
+ for x in init :
856
+ if x not in self ._mapping :
857
+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
852
858
853
859
def __getstate__ (self ):
854
860
return {
855
861
** vars (self ),
856
862
# Convert weak refs to strong ones.
857
- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
863
+ "_mapping" : {k : ( i , set (v )) for k , ( i , v ) in self ._mapping .items ()},
858
864
}
859
865
860
866
def __setstate__ (self , state ):
861
867
vars (self ).update (state )
862
868
# Convert strong refs to weak ones.
863
869
self ._mapping = weakref .WeakKeyDictionary (
864
- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
870
+ {k : ( i , weakref .WeakSet (v )) for k , ( i , v ) in self ._mapping .items ()})
865
871
866
872
def __contains__ (self , item ):
867
873
return item in self ._mapping
@@ -874,25 +880,32 @@ def join(self, a, *args):
874
880
"""
875
881
Join given arguments into the same set. Accepts one or more arguments.
876
882
"""
877
- mapping = self ._mapping
878
- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
879
-
880
- for arg in args :
881
- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
883
+ m = self ._mapping
884
+ try :
885
+ _ , set_a = m [a ]
886
+ except KeyError :
887
+ _ , set_a = m [a ] = (next (self ._count ), weakref .WeakSet ([a ]))
888
+ for b in args :
889
+ try :
890
+ _ , set_b = m [b ]
891
+ except KeyError :
892
+ _ , set_b = m [b ] = (next (self ._count ), weakref .WeakSet ([b ]))
882
893
if set_b is not set_a :
883
894
if len (set_b ) > len (set_a ):
884
895
set_a , set_b = set_b , set_a
885
896
set_a .update (set_b )
886
897
for elem in set_b :
887
- mapping [elem ] = set_a
898
+ i , _ = m [elem ]
899
+ m [elem ] = (i , set_a )
888
900
889
901
def joined (self , a , b ):
890
902
"""Return whether *a* and *b* are members of the same set."""
891
- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
903
+ return (self ._mapping .get (a , (None , object ()))[1 ]
904
+ is self ._mapping .get (b , (None , object ()))[1 ])
892
905
893
906
def remove (self , a ):
894
907
"""Remove *a* from the grouper, doing nothing if it is not there."""
895
- set_a = self ._mapping .pop (a , None )
908
+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
896
909
if set_a :
897
910
set_a .remove (a )
898
911
@@ -902,14 +915,14 @@ def __iter__(self):
902
915
903
916
The iterator is invalid if interleaved with calls to join().
904
917
"""
905
- unique_groups = {id (group ): group for group in self ._mapping .values ()}
918
+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
906
919
for group in unique_groups .values ():
907
- yield [ x for x in group ]
920
+ yield sorted ( group , key = self . _mapping . __getitem__ )
908
921
909
922
def get_siblings (self , a ):
910
923
"""Return all of the items joined with *a*, including itself."""
911
- siblings = self ._mapping .get (a , [a ])
912
- return [ x for x in siblings ]
924
+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
925
+ return sorted ( siblings , key = self . _mapping . get )
913
926
914
927
915
928
class GrouperView :
0 commit comments