@@ -396,6 +396,115 @@ def test_compress_empty(self):
396396 c = ZstdCompressor ()
397397 self .assertNotEqual (c .compress (b'' , c .FLUSH_FRAME ), b'' )
398398
399+ def test_set_pledged_input_size (self ):
400+ DAT = DECOMPRESSED_100_PLUS_32KB
401+ CHUNK_SIZE = len (DAT ) // 3
402+
403+ # wrong value
404+ c = ZstdCompressor ()
405+ with self .assertRaisesRegex (ValueError ,
406+ r'should be a positive int less than \d+' ):
407+ c .set_pledged_input_size (- 300 )
408+ # overflow
409+ with self .assertRaisesRegex (ValueError ,
410+ r'should be a positive int less than \d+' ):
411+ c .set_pledged_input_size (2 ** 64 )
412+ # ZSTD_CONTENTSIZE_ERROR is invalid
413+ with self .assertRaisesRegex (ValueError ,
414+ r'should be a positive int less than \d+' ):
415+ c .set_pledged_input_size (2 ** 64 - 2 )
416+ # ZSTD_CONTENTSIZE_UNKNOWN should use None
417+ with self .assertRaisesRegex (ValueError ,
418+ r'should be a positive int less than \d+' ):
419+ c .set_pledged_input_size (2 ** 64 - 1 )
420+
421+ # check valid values are settable
422+ c .set_pledged_input_size (2 ** 63 )
423+ c .set_pledged_input_size (2 ** 64 - 3 )
424+
425+ # check that zero means empty frame
426+ c = ZstdCompressor (level = 1 )
427+ c .set_pledged_input_size (0 )
428+ c .compress (b'' )
429+ dat = c .flush ()
430+ ret = get_frame_info (dat )
431+ self .assertEqual (ret .decompressed_size , 0 )
432+
433+
434+ # wrong mode
435+ c = ZstdCompressor (level = 1 )
436+ c .compress (b'123456' )
437+ self .assertEqual (c .last_mode , c .CONTINUE )
438+ with self .assertRaisesRegex (ValueError ,
439+ r'last_mode == FLUSH_FRAME' ):
440+ c .set_pledged_input_size (300 )
441+
442+ # None value
443+ c = ZstdCompressor (level = 1 )
444+ c .set_pledged_input_size (None )
445+ dat = c .compress (DAT ) + c .flush ()
446+
447+ ret = get_frame_info (dat )
448+ self .assertEqual (ret .decompressed_size , None )
449+
450+ # correct value
451+ c = ZstdCompressor (level = 1 )
452+ c .set_pledged_input_size (len (DAT ))
453+
454+ chunks = []
455+ posi = 0
456+ while posi < len (DAT ):
457+ dat = c .compress (DAT [posi :posi + CHUNK_SIZE ])
458+ posi += CHUNK_SIZE
459+ chunks .append (dat )
460+
461+ dat = c .flush ()
462+ chunks .append (dat )
463+ chunks = b'' .join (chunks )
464+
465+ ret = get_frame_info (chunks )
466+ self .assertEqual (ret .decompressed_size , len (DAT ))
467+ self .assertEqual (decompress (chunks ), DAT )
468+
469+ c .set_pledged_input_size (len (DAT )) # the second frame
470+ dat = c .compress (DAT ) + c .flush ()
471+
472+ ret = get_frame_info (dat )
473+ self .assertEqual (ret .decompressed_size , len (DAT ))
474+ self .assertEqual (decompress (dat ), DAT )
475+
476+ # not enough data
477+ c = ZstdCompressor (level = 1 )
478+ c .set_pledged_input_size (len (DAT )+ 1 )
479+
480+ for start in range (0 , len (DAT ), CHUNK_SIZE ):
481+ end = min (start + CHUNK_SIZE , len (DAT ))
482+ _dat = c .compress (DAT [start :end ])
483+
484+ with self .assertRaises (ZstdError ):
485+ c .flush ()
486+
487+ # too much data
488+ c = ZstdCompressor (level = 1 )
489+ c .set_pledged_input_size (len (DAT ))
490+
491+ for start in range (0 , len (DAT ), CHUNK_SIZE ):
492+ end = min (start + CHUNK_SIZE , len (DAT ))
493+ _dat = c .compress (DAT [start :end ])
494+
495+ with self .assertRaises (ZstdError ):
496+ c .compress (b'extra' , ZstdCompressor .FLUSH_FRAME )
497+
498+ # content size not set if content_size_flag == 0
499+ c = ZstdCompressor (options = {CompressionParameter .content_size_flag : 0 })
500+ c .set_pledged_input_size (10 )
501+ dat1 = c .compress (b"hello" )
502+ dat2 = c .compress (b"world" )
503+ dat3 = c .flush ()
504+ frame_data = get_frame_info (dat1 + dat2 + dat3 )
505+ self .assertIsNone (frame_data .decompressed_size )
506+
507+
399508class DecompressorTestCase (unittest .TestCase ):
400509
401510 def test_simple_decompress_bad_args (self ):
0 commit comments