77import unittest
88import re
99import contextlib
10+ import functools
1011import operator
1112import ipaddress
1213
@@ -528,6 +529,20 @@ def test_ip_network(self):
528529 self .assertFactoryError (ipaddress .ip_network , "network" )
529530
530531
532+ @functools .total_ordering
533+ class LargestObject :
534+ def __eq__ (self , other ):
535+ return isinstance (other , LargestObject )
536+ def __lt__ (self , other ):
537+ return False
538+
539+ @functools .total_ordering
540+ class SmallestObject :
541+ def __eq__ (self , other ):
542+ return isinstance (other , SmallestObject )
543+ def __gt__ (self , other ):
544+ return False
545+
531546class ComparisonTests (unittest .TestCase ):
532547
533548 v4addr = ipaddress .IPv4Address (1 )
@@ -581,6 +596,28 @@ def test_mixed_type_ordering(self):
581596 self .assertRaises (TypeError , lambda : lhs <= rhs )
582597 self .assertRaises (TypeError , lambda : lhs >= rhs )
583598
599+ def test_foreign_type_ordering (self ):
600+ other = object ()
601+ smallest = SmallestObject ()
602+ largest = LargestObject ()
603+ for obj in self .objects :
604+ with self .assertRaises (TypeError ):
605+ obj < other
606+ with self .assertRaises (TypeError ):
607+ obj > other
608+ with self .assertRaises (TypeError ):
609+ obj <= other
610+ with self .assertRaises (TypeError ):
611+ obj >= other
612+ self .assertTrue (obj < largest )
613+ self .assertFalse (obj > largest )
614+ self .assertTrue (obj <= largest )
615+ self .assertFalse (obj >= largest )
616+ self .assertFalse (obj < smallest )
617+ self .assertTrue (obj > smallest )
618+ self .assertFalse (obj <= smallest )
619+ self .assertTrue (obj >= smallest )
620+
584621 def test_mixed_type_key (self ):
585622 # with get_mixed_type_key, you can sort addresses and network.
586623 v4_ordered = [self .v4addr , self .v4net , self .v4intf ]
@@ -601,7 +638,7 @@ def test_incompatible_versions(self):
601638 v4addr = ipaddress .ip_address ('1.1.1.1' )
602639 v4net = ipaddress .ip_network ('1.1.1.1' )
603640 v6addr = ipaddress .ip_address ('::1' )
604- v6net = ipaddress .ip_address ('::1' )
641+ v6net = ipaddress .ip_network ('::1' )
605642
606643 self .assertRaises (TypeError , v4addr .__lt__ , v6addr )
607644 self .assertRaises (TypeError , v4addr .__gt__ , v6addr )
@@ -1248,10 +1285,10 @@ def testNetworkComparison(self):
12481285 unsorted = [ip4 , ip1 , ip3 , ip2 ]
12491286 unsorted .sort ()
12501287 self .assertEqual (sorted , unsorted )
1251- self .assertRaises ( TypeError , ip1 .__lt__ ,
1252- ipaddress . ip_address ( '10.10.10.0' ) )
1253- self .assertRaises ( TypeError , ip2 .__lt__ ,
1254- ipaddress . ip_address ( '10.10.10.0' ) )
1288+ self .assertIs ( ip1 .__lt__ ( ipaddress . ip_address ( '10.10.10.0' )) ,
1289+ NotImplemented )
1290+ self .assertIs ( ip2 .__lt__ ( ipaddress . ip_address ( '10.10.10.0' )) ,
1291+ NotImplemented )
12551292
12561293 # <=, >=
12571294 self .assertTrue (ipaddress .ip_network ('1.1.1.1' ) <=
0 commit comments