@@ -783,21 +783,27 @@ class Grouper:
783
783
"""
784
784
785
785
def __init__ (self , init = ()):
786
- self ._mapping = weakref .WeakKeyDictionary (
787
- {x : weakref .WeakSet ([x ]) for x in init })
786
+ self ._count = itertools .count ()
787
+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
788
+ # lets __iter__ and get_siblings return items in the order in which they have
789
+ # been seen.
790
+ self ._mapping = weakref .WeakKeyDictionary ()
791
+ for x in init :
792
+ if x not in self ._mapping :
793
+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
788
794
789
795
def __getstate__ (self ):
790
796
return {
791
797
** vars (self ),
792
798
# Convert weak refs to strong ones.
793
- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
799
+ "_mapping" : {k : ( i , set (v )) for k , ( i , v ) in self ._mapping .items ()},
794
800
}
795
801
796
802
def __setstate__ (self , state ):
797
803
vars (self ).update (state )
798
804
# Convert strong refs to weak ones.
799
805
self ._mapping = weakref .WeakKeyDictionary (
800
- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
806
+ {k : ( i , weakref .WeakSet (v )) for k , ( i , v ) in self ._mapping .items ()})
801
807
802
808
def __contains__ (self , item ):
803
809
return item in self ._mapping
@@ -810,24 +816,30 @@ def join(self, a, *args):
810
816
"""
811
817
Join given arguments into the same set. Accepts one or more arguments.
812
818
"""
813
- mapping = self ._mapping
814
- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
815
-
816
- for arg in args :
817
- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
819
+ m = self ._mapping
820
+ try :
821
+ _ , set_a = m [a ]
822
+ except KeyError :
823
+ _ , set_a = m .setdefault (a , (next (self ._count ), weakref .WeakSet ([a ])))
824
+ for b in args :
825
+ try :
826
+ _ , set_b = m [b ]
827
+ except KeyError :
828
+ _ , set_b = m .setdefault (b , (next (self ._count ), weakref .WeakSet ([b ])))
818
829
if set_b is not set_a :
819
830
if len (set_b ) > len (set_a ):
820
831
set_a , set_b = set_b , set_a
821
832
set_a .update (set_b )
822
833
for elem in set_b :
823
- mapping [elem ] = set_a
834
+ i , _ = m [elem ]
835
+ m [elem ] = (i , set_a )
824
836
825
837
def joined (self , a , b ):
826
838
"""Return whether *a* and *b* are members of the same set."""
827
- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
839
+ return (self ._mapping .get (a , ( None , object ()))[ 1 ] is self ._mapping .get (b )[ 1 ] )
828
840
829
841
def remove (self , a ):
830
- set_a = self ._mapping .pop (a , None )
842
+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
831
843
if set_a :
832
844
set_a .remove (a )
833
845
@@ -837,14 +849,14 @@ def __iter__(self):
837
849
838
850
The iterator is invalid if interleaved with calls to join().
839
851
"""
840
- unique_groups = {id (group ): group for group in self ._mapping .values ()}
852
+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
841
853
for group in unique_groups .values ():
842
- yield [ x for x in group ]
854
+ yield sorted ( group , key = self . _mapping . __getitem__ )
843
855
844
856
def get_siblings (self , a ):
845
857
"""Return all of the items joined with *a*, including itself."""
846
- siblings = self ._mapping .get (a , [a ])
847
- return [ x for x in siblings ]
858
+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
859
+ return sorted ( siblings , key = self . _mapping . get )
848
860
849
861
850
862
class GrouperView :
0 commit comments