@@ -200,10 +200,14 @@ def test_fuzz(self):
200200 except Exception :
201201 pass
202202
203- def test_loads_recursion (self ):
203+ def test_loads_2x_code (self ):
204204 s = b'c' + (b'X' * 4 * 4 ) + b'{' * 2 ** 20
205205 self .assertRaises (ValueError , marshal .loads , s )
206206
207+ def test_loads_recursion (self ):
208+ s = b'c' + (b'X' * 4 * 5 ) + b'{' * 2 ** 20
209+ self .assertRaises (ValueError , marshal .loads , s )
210+
207211 def test_recursion_limit (self ):
208212 # Create a deeply nested structure.
209213 head = last = []
@@ -323,6 +327,122 @@ def test_frozenset(self, size):
323327 def test_bytearray (self , size ):
324328 self .check_unmarshallable (bytearray (size ))
325329
330+ def CollectObjectIDs (ids , obj ):
331+ """Collect object ids seen in a structure"""
332+ if id (obj ) in ids :
333+ return
334+ ids .add (id (obj ))
335+ if isinstance (obj , (list , tuple , set , frozenset )):
336+ for e in obj :
337+ CollectObjectIDs (ids , e )
338+ elif isinstance (obj , dict ):
339+ for k , v in obj .items ():
340+ CollectObjectIDs (ids , k )
341+ CollectObjectIDs (ids , v )
342+ return len (ids )
343+
344+ class InstancingTestCase (unittest .TestCase , HelperMixin ):
345+ intobj = 123321
346+ floatobj = 1.2345
347+ strobj = "abcde" * 3
348+ dictobj = {"hello" :floatobj , "goodbye" :floatobj , floatobj :"hello" }
349+
350+ def helper3 (self , rsample , recursive = False , simple = False ):
351+ #we have two instances
352+ sample = (rsample , rsample )
353+
354+ n0 = CollectObjectIDs (set (), sample )
355+
356+ s3 = marshal .dumps (sample , 3 )
357+ n3 = CollectObjectIDs (set (), marshal .loads (s3 ))
358+
359+ #same number of instances generated
360+ self .assertEqual (n3 , n0 )
361+
362+ if not recursive :
363+ #can compare with version 2
364+ s2 = marshal .dumps (sample , 2 )
365+ n2 = CollectObjectIDs (set (), marshal .loads (s2 ))
366+ #old format generated more instances
367+ self .assertGreater (n2 , n0 )
368+
369+ #if complex objects are in there, old format is larger
370+ if not simple :
371+ self .assertGreater (len (s2 ), len (s3 ))
372+ else :
373+ self .assertGreaterEqual (len (s2 ), len (s3 ))
374+
375+ def testInt (self ):
376+ self .helper (self .intobj )
377+ self .helper3 (self .intobj , simple = True )
378+
379+ def testFloat (self ):
380+ self .helper (self .floatobj )
381+ self .helper3 (self .floatobj )
382+
383+ def testStr (self ):
384+ self .helper (self .strobj )
385+ self .helper3 (self .strobj )
386+
387+ def testDict (self ):
388+ self .helper (self .dictobj )
389+ self .helper3 (self .dictobj )
390+
391+ def testModule (self ):
392+ with open (__file__ , "rb" ) as f :
393+ code = f .read ()
394+ if __file__ .endswith (".py" ):
395+ code = compile (code , __file__ , "exec" )
396+ self .helper (code )
397+ self .helper3 (code )
398+
399+ def testRecursion (self ):
400+ d = dict (self .dictobj )
401+ d ["self" ] = d
402+ self .helper3 (d , recursive = True )
403+ l = [self .dictobj ]
404+ l .append (l )
405+ self .helper3 (l , recursive = True )
406+
407+ class CompatibilityTestCase (unittest .TestCase ):
408+ def _test (self , version ):
409+ with open (__file__ , "rb" ) as f :
410+ code = f .read ()
411+ if __file__ .endswith (".py" ):
412+ code = compile (code , __file__ , "exec" )
413+ data = marshal .dumps (code , version )
414+ marshal .loads (data )
415+
416+ def test0To3 (self ):
417+ self ._test (0 )
418+
419+ def test1To3 (self ):
420+ self ._test (1 )
421+
422+ def test2To3 (self ):
423+ self ._test (2 )
424+
425+ def test3To3 (self ):
426+ self ._test (3 )
427+
428+ class InterningTestCase (unittest .TestCase , HelperMixin ):
429+ strobj = "this is an interned string"
430+ strobj = sys .intern (strobj )
431+
432+ def testIntern (self ):
433+ s = marshal .loads (marshal .dumps (self .strobj ))
434+ self .assertEqual (s , self .strobj )
435+ self .assertEqual (id (s ), id (self .strobj ))
436+ s2 = sys .intern (s )
437+ self .assertEqual (id (s2 ), id (s ))
438+
439+ def testNoIntern (self ):
440+ s = marshal .loads (marshal .dumps (self .strobj , 2 ))
441+ self .assertEqual (s , self .strobj )
442+ self .assertNotEqual (id (s ), id (self .strobj ))
443+ s2 = sys .intern (s )
444+ self .assertNotEqual (id (s2 ), id (s ))
445+
326446
327447def test_main ():
328448 support .run_unittest (IntTestCase ,
0 commit comments