@@ -159,6 +159,44 @@ def test_accumulate(self):
159159 with self .assertRaises (TypeError ):
160160 list (accumulate ([10 , 20 ], 100 ))
161161
162+ def test_batched (self ):
163+ self .assertEqual (list (batched ('ABCDEFG' , 3 )),
164+ [['A' , 'B' , 'C' ], ['D' , 'E' , 'F' ], ['G' ]])
165+ self .assertEqual (list (batched ('ABCDEFG' , 2 )),
166+ [['A' , 'B' ], ['C' , 'D' ], ['E' , 'F' ], ['G' ]])
167+ self .assertEqual (list (batched ('ABCDEFG' , 1 )),
168+ [['A' ], ['B' ], ['C' ], ['D' ], ['E' ], ['F' ], ['G' ]])
169+
170+ with self .assertRaises (TypeError ): # Too few arguments
171+ list (batched ('ABCDEFG' ))
172+ with self .assertRaises (TypeError ):
173+ list (batched ('ABCDEFG' , 3 , None )) # Too many arguments
174+ with self .assertRaises (TypeError ):
175+ list (batched (None , 3 )) # Non-iterable input
176+ with self .assertRaises (TypeError ):
177+ list (batched ('ABCDEFG' , 'hello' )) # n is a string
178+ with self .assertRaises (ValueError ):
179+ list (batched ('ABCDEFG' , 0 )) # n is zero
180+ with self .assertRaises (ValueError ):
181+ list (batched ('ABCDEFG' , - 1 )) # n is negative
182+
183+ data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
184+ for n in range (1 , 6 ):
185+ for i in range (len (data )):
186+ s = data [:i ]
187+ batches = list (batched (s , n ))
188+ with self .subTest (s = s , n = n , batches = batches ):
189+ # Order is preserved and no data is lost
190+ self .assertEqual ('' .join (chain (* batches )), s )
191+ # Each batch is an exact list
192+ self .assertTrue (all (type (batch ) is list for batch in batches ))
193+ # All but the last batch is of size n
194+ if batches :
195+ last_batch = batches .pop ()
196+ self .assertTrue (all (len (batch ) == n for batch in batches ))
197+ self .assertTrue (len (last_batch ) <= n )
198+ batches .append (last_batch )
199+
162200 def test_chain (self ):
163201
164202 def chain2 (* iterables ):
@@ -1737,6 +1775,31 @@ def test_takewhile(self):
17371775
17381776class TestPurePythonRoughEquivalents (unittest .TestCase ):
17391777
1778+ def test_batched_recipe (self ):
1779+ def batched_recipe (iterable , n ):
1780+ "Batch data into lists of length n. The last batch may be shorter."
1781+ # batched('ABCDEFG', 3) --> ABC DEF G
1782+ if n < 1 :
1783+ raise ValueError ('n must be at least one' )
1784+ it = iter (iterable )
1785+ while (batch := list (islice (it , n ))):
1786+ yield batch
1787+
1788+ for iterable , n in product (
1789+ ['' , 'a' , 'ab' , 'abc' , 'abcd' , 'abcde' , 'abcdef' , 'abcdefg' , None ],
1790+ [- 1 , 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , None ]):
1791+ with self .subTest (iterable = iterable , n = n ):
1792+ try :
1793+ e1 , r1 = None , list (batched (iterable , n ))
1794+ except Exception as e :
1795+ e1 , r1 = type (e ), None
1796+ try :
1797+ e2 , r2 = None , list (batched_recipe (iterable , n ))
1798+ except Exception as e :
1799+ e2 , r2 = type (e ), None
1800+ self .assertEqual (r1 , r2 )
1801+ self .assertEqual (e1 , e2 )
1802+
17401803 @staticmethod
17411804 def islice (iterable , * args ):
17421805 s = slice (* args )
@@ -1788,6 +1851,10 @@ def test_accumulate(self):
17881851 a = []
17891852 self .makecycle (accumulate ([1 ,2 ,a ,3 ]), a )
17901853
1854+ def test_batched (self ):
1855+ a = []
1856+ self .makecycle (batched ([1 ,2 ,a ,3 ], 2 ), a )
1857+
17911858 def test_chain (self ):
17921859 a = []
17931860 self .makecycle (chain (a ), a )
@@ -1972,6 +2039,18 @@ def test_accumulate(self):
19722039 self .assertRaises (TypeError , accumulate , N (s ))
19732040 self .assertRaises (ZeroDivisionError , list , accumulate (E (s )))
19742041
2042+ def test_batched (self ):
2043+ s = 'abcde'
2044+ r = [['a' , 'b' ], ['c' , 'd' ], ['e' ]]
2045+ n = 2
2046+ for g in (G , I , Ig , L , R ):
2047+ with self .subTest (g = g ):
2048+ self .assertEqual (list (batched (g (s ), n )), r )
2049+ self .assertEqual (list (batched (S (s ), 2 )), [])
2050+ self .assertRaises (TypeError , batched , X (s ), 2 )
2051+ self .assertRaises (TypeError , batched , N (s ), 2 )
2052+ self .assertRaises (ZeroDivisionError , list , batched (E (s ), 2 ))
2053+
19752054 def test_chain (self ):
19762055 for s in ("123" , "" , range (1000 ), ('do' , 1.2 ), range (2000 ,2200 ,5 )):
19772056 for g in (G , I , Ig , S , L , R ):
0 commit comments