@@ -259,6 +259,74 @@ def wrapper():
259259 self .assertEqual (wrapper .attr , 'This is a different test' )
260260 self .assertEqual (wrapper .dict_attr , f .dict_attr )
261261
262+ class TestReduce (unittest .TestCase ):
263+ func = functools .reduce
264+
265+ def test_reduce (self ):
266+ class Squares :
267+ def __init__ (self , max ):
268+ self .max = max
269+ self .sofar = []
270+
271+ def __len__ (self ):
272+ return len (self .sofar )
273+
274+ def __getitem__ (self , i ):
275+ if not 0 <= i < self .max : raise IndexError
276+ n = len (self .sofar )
277+ while n <= i :
278+ self .sofar .append (n * n )
279+ n += 1
280+ return self .sofar [i ]
281+
282+ self .assertEqual (self .func (lambda x , y : x + y , ['a' , 'b' , 'c' ], '' ), 'abc' )
283+ self .assertEqual (
284+ self .func (lambda x , y : x + y , [['a' , 'c' ], [], ['d' , 'w' ]], []),
285+ ['a' ,'c' ,'d' ,'w' ]
286+ )
287+ self .assertEqual (self .func (lambda x , y : x * y , range (2 ,8 ), 1 ), 5040 )
288+ self .assertEqual (
289+ self .func (lambda x , y : x * y , range (2 ,21 ), 1L ),
290+ 2432902008176640000L
291+ )
292+ self .assertEqual (self .func (lambda x , y : x + y , Squares (10 )), 285 )
293+ self .assertEqual (self .func (lambda x , y : x + y , Squares (10 ), 0 ), 285 )
294+ self .assertEqual (self .func (lambda x , y : x + y , Squares (0 ), 0 ), 0 )
295+ self .assertRaises (TypeError , self .func )
296+ self .assertRaises (TypeError , self .func , 42 , 42 )
297+ self .assertRaises (TypeError , self .func , 42 , 42 , 42 )
298+ self .assertEqual (self .func (42 , "1" ), "1" ) # func is never called with one item
299+ self .assertEqual (self .func (42 , "" , "1" ), "1" ) # func is never called with one item
300+ self .assertRaises (TypeError , self .func , 42 , (42 , 42 ))
301+
302+ class BadSeq :
303+ def __getitem__ (self , index ):
304+ raise ValueError
305+ self .assertRaises (ValueError , self .func , 42 , BadSeq ())
306+
307+ # Test reduce()'s use of iterators.
308+ def test_iterator_usage (self ):
309+ class SequenceClass :
310+ def __init__ (self , n ):
311+ self .n = n
312+ def __getitem__ (self , i ):
313+ if 0 <= i < self .n :
314+ return i
315+ else :
316+ raise IndexError
317+
318+ from operator import add
319+ self .assertEqual (self .func (add , SequenceClass (5 )), 10 )
320+ self .assertEqual (self .func (add , SequenceClass (5 ), 42 ), 52 )
321+ self .assertRaises (TypeError , self .func , add , SequenceClass (0 ))
322+ self .assertEqual (self .func (add , SequenceClass (0 ), 42 ), 42 )
323+ self .assertEqual (self .func (add , SequenceClass (1 )), 0 )
324+ self .assertEqual (self .func (add , SequenceClass (1 ), 42 ), 42 )
325+
326+ d = {"one" : 1 , "two" : 2 , "three" : 3 }
327+ self .assertEqual (self .func (add , d ), "" .join (d .keys ()))
328+
329+
262330
263331
264332def test_main (verbose = None ):
@@ -268,7 +336,8 @@ def test_main(verbose=None):
268336 TestPartialSubclass ,
269337 TestPythonPartial ,
270338 TestUpdateWrapper ,
271- TestWraps
339+ TestWraps ,
340+ TestReduce
272341 )
273342 test_support .run_unittest (* test_classes )
274343
0 commit comments