@@ -783,21 +783,27 @@ class Grouper:
783783 """
784784
785785 def __init__ (self , init = ()):
786- self ._mapping = weakref .WeakKeyDictionary (
787- {x : weakref .WeakSet ([x ]) for x in init })
786+ self ._count = itertools .count ()
787+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
788+ # lets __iter__ and get_siblings return items in the order in which they have
789+ # been seen.
790+ self ._mapping = weakref .WeakKeyDictionary ()
791+ for x in init :
792+ if x not in self ._mapping :
793+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
788794
789795 def __getstate__ (self ):
790796 return {
791797 ** vars (self ),
792798 # Convert weak refs to strong ones.
793- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
799+ "_mapping" : {k : ( i , set (v )) for k , ( i , v ) in self ._mapping .items ()},
794800 }
795801
796802 def __setstate__ (self , state ):
797803 vars (self ).update (state )
798804 # Convert strong refs to weak ones.
799805 self ._mapping = weakref .WeakKeyDictionary (
800- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
806+ {k : ( i , weakref .WeakSet (v )) for k , ( i , v ) in self ._mapping .items ()})
801807
802808 def __contains__ (self , item ):
803809 return item in self ._mapping
@@ -810,24 +816,30 @@ def join(self, a, *args):
810816 """
811817 Join given arguments into the same set. Accepts one or more arguments.
812818 """
813- mapping = self ._mapping
814- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
815-
816- for arg in args :
817- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
819+ m = self ._mapping
820+ try :
821+ _ , set_a = m [a ]
822+ except KeyError :
823+ _ , set_a = m .setdefault (a , (next (self ._count ), weakref .WeakSet ([a ])))
824+ for b in args :
825+ try :
826+ _ , set_b = m [b ]
827+ except KeyError :
828+ _ , set_b = m .setdefault (b , (next (self ._count ), weakref .WeakSet ([b ])))
818829 if set_b is not set_a :
819830 if len (set_b ) > len (set_a ):
820831 set_a , set_b = set_b , set_a
821832 set_a .update (set_b )
822833 for elem in set_b :
823- mapping [elem ] = set_a
834+ i , _ = m [elem ]
835+ m [elem ] = (i , set_a )
824836
825837 def joined (self , a , b ):
826838 """Return whether *a* and *b* are members of the same set."""
827- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
839+ return (self ._mapping .get (a , ( None , object ()))[ 1 ] is self ._mapping .get (b )[ 1 ] )
828840
829841 def remove (self , a ):
830- set_a = self ._mapping .pop (a , None )
842+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
831843 if set_a :
832844 set_a .remove (a )
833845
@@ -837,14 +849,14 @@ def __iter__(self):
837849
838850 The iterator is invalid if interleaved with calls to join().
839851 """
840- unique_groups = {id (group ): group for group in self ._mapping .values ()}
852+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
841853 for group in unique_groups .values ():
842- yield [ x for x in group ]
854+ yield sorted ( group , key = self . _mapping . __getitem__ )
843855
844856 def get_siblings (self , a ):
845857 """Return all of the items joined with *a*, including itself."""
846- siblings = self ._mapping .get (a , [a ])
847- return [ x for x in siblings ]
858+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
859+ return sorted ( siblings , key = self . _mapping . get )
848860
849861
850862class GrouperView :
0 commit comments