88from test import support
99import unittest
1010from weakref import proxy
11+ import contextlib
1112try :
1213 import threading
1314except ImportError :
2021
2122decimal = support .import_fresh_module ('decimal' , fresh = ['_decimal' ])
2223
24+ @contextlib .contextmanager
25+ def replaced_module (name , replacement ):
26+ original_module = sys .modules [name ]
27+ sys .modules [name ] = replacement
28+ try :
29+ yield
30+ finally :
31+ sys .modules [name ] = original_module
2332
2433def capture (* args , ** kw ):
2534 """capture all positional and keyword arguments"""
@@ -167,89 +176,67 @@ def foo(bar):
167176 p2 .new_attr = 'spam'
168177 self .assertEqual (p2 .new_attr , 'spam' )
169178
170-
171- @unittest .skipUnless (c_functools , 'requires the C _functools module' )
172- class TestPartialC (TestPartial , unittest .TestCase ):
173- if c_functools :
174- partial = c_functools .partial
175-
176- def test_attributes_unwritable (self ):
177- # attributes should not be writable
178- p = self .partial (capture , 1 , 2 , a = 10 , b = 20 )
179- self .assertRaises (AttributeError , setattr , p , 'func' , map )
180- self .assertRaises (AttributeError , setattr , p , 'args' , (1 , 2 ))
181- self .assertRaises (AttributeError , setattr , p , 'keywords' , dict (a = 1 , b = 2 ))
182-
183- p = self .partial (hex )
184- try :
185- del p .__dict__
186- except TypeError :
187- pass
188- else :
189- self .fail ('partial object allowed __dict__ to be deleted' )
190-
191179 def test_repr (self ):
192180 args = (object (), object ())
193181 args_repr = ', ' .join (repr (a ) for a in args )
194182 kwargs = {'a' : object (), 'b' : object ()}
195183 kwargs_reprs = ['a={a!r}, b={b!r}' .format_map (kwargs ),
196184 'b={b!r}, a={a!r}' .format_map (kwargs )]
197- if self .partial is c_functools .partial :
185+ if self .partial in ( c_functools .partial , py_functools . partial ) :
198186 name = 'functools.partial'
199187 else :
200188 name = self .partial .__name__
201189
202190 f = self .partial (capture )
203- self .assertEqual ('{}({!r})' .format (name , capture ),
204- repr (f ))
191+ self .assertEqual (f'{ name } ({ capture !r} )' , repr (f ))
205192
206193 f = self .partial (capture , * args )
207- self .assertEqual ('{}({!r}, {})' .format (name , capture , args_repr ),
208- repr (f ))
194+ self .assertEqual (f'{ name } ({ capture !r} , { args_repr } )' , repr (f ))
209195
210196 f = self .partial (capture , ** kwargs )
211197 self .assertIn (repr (f ),
212- ['{ }({!r}, {})'. format ( name , capture , kwargs_repr )
198+ [f' { name } ({ capture !r} , { kwargs_repr } )'
213199 for kwargs_repr in kwargs_reprs ])
214200
215201 f = self .partial (capture , * args , ** kwargs )
216202 self .assertIn (repr (f ),
217- ['{ }({!r}, {}, {})'. format ( name , capture , args_repr , kwargs_repr )
203+ [f' { name } ({ capture !r} , { args_repr } , { kwargs_repr } )'
218204 for kwargs_repr in kwargs_reprs ])
219205
220206 def test_recursive_repr (self ):
221- if self .partial is c_functools .partial :
207+ if self .partial in ( c_functools .partial , py_functools . partial ) :
222208 name = 'functools.partial'
223209 else :
224210 name = self .partial .__name__
225211
226212 f = self .partial (capture )
227213 f .__setstate__ ((f , (), {}, {}))
228214 try :
229- self .assertEqual (repr (f ), '%s(%s( ...)) ' % (name , name ))
215+ self .assertEqual (repr (f ), '%s(...)' % (name ,))
230216 finally :
231217 f .__setstate__ ((capture , (), {}, {}))
232218
233219 f = self .partial (capture )
234220 f .__setstate__ ((capture , (f ,), {}, {}))
235221 try :
236- self .assertEqual (repr (f ), '%s(%r, %s( ...)) ' % (name , capture , name ))
222+ self .assertEqual (repr (f ), '%s(%r, ...)' % (name , capture ,))
237223 finally :
238224 f .__setstate__ ((capture , (), {}, {}))
239225
240226 f = self .partial (capture )
241227 f .__setstate__ ((capture , (), {'a' : f }, {}))
242228 try :
243- self .assertEqual (repr (f ), '%s(%r, a=%s( ...)) ' % (name , capture , name ))
229+ self .assertEqual (repr (f ), '%s(%r, a=...)' % (name , capture ,))
244230 finally :
245231 f .__setstate__ ((capture , (), {}, {}))
246232
247233 def test_pickle (self ):
248- f = self .partial (signature , ['asdf' ], bar = [True ])
249- f .attr = []
250- for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
251- f_copy = pickle .loads (pickle .dumps (f , proto ))
252- self .assertEqual (signature (f_copy ), signature (f ))
234+ with self .AllowPickle ():
235+ f = self .partial (signature , ['asdf' ], bar = [True ])
236+ f .attr = []
237+ for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
238+ f_copy = pickle .loads (pickle .dumps (f , proto ))
239+ self .assertEqual (signature (f_copy ), signature (f ))
253240
254241 def test_copy (self ):
255242 f = self .partial (signature , ['asdf' ], bar = [True ])
@@ -274,11 +261,13 @@ def test_deepcopy(self):
274261 def test_setstate (self ):
275262 f = self .partial (signature )
276263 f .__setstate__ ((capture , (1 ,), dict (a = 10 ), dict (attr = [])))
264+
277265 self .assertEqual (signature (f ),
278266 (capture , (1 ,), dict (a = 10 ), dict (attr = [])))
279267 self .assertEqual (f (2 , b = 20 ), ((1 , 2 ), {'a' : 10 , 'b' : 20 }))
280268
281269 f .__setstate__ ((capture , (1 ,), dict (a = 10 ), None ))
270+
282271 self .assertEqual (signature (f ), (capture , (1 ,), dict (a = 10 ), {}))
283272 self .assertEqual (f (2 , b = 20 ), ((1 , 2 ), {'a' : 10 , 'b' : 20 }))
284273
@@ -325,38 +314,39 @@ def test_setstate_subclasses(self):
325314 self .assertIs (type (r [0 ]), tuple )
326315
327316 def test_recursive_pickle (self ):
328- f = self .partial (capture )
329- f .__setstate__ ((f , (), {}, {}))
330- try :
331- for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
332- with self .assertRaises (RecursionError ):
333- pickle .dumps (f , proto )
334- finally :
335- f .__setstate__ ((capture , (), {}, {}))
336-
337- f = self .partial (capture )
338- f .__setstate__ ((capture , (f ,), {}, {}))
339- try :
340- for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
341- f_copy = pickle .loads (pickle .dumps (f , proto ))
342- try :
343- self .assertIs (f_copy .args [0 ], f_copy )
344- finally :
345- f_copy .__setstate__ ((capture , (), {}, {}))
346- finally :
347- f .__setstate__ ((capture , (), {}, {}))
348-
349- f = self .partial (capture )
350- f .__setstate__ ((capture , (), {'a' : f }, {}))
351- try :
352- for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
353- f_copy = pickle .loads (pickle .dumps (f , proto ))
354- try :
355- self .assertIs (f_copy .keywords ['a' ], f_copy )
356- finally :
357- f_copy .__setstate__ ((capture , (), {}, {}))
358- finally :
359- f .__setstate__ ((capture , (), {}, {}))
317+ with self .AllowPickle ():
318+ f = self .partial (capture )
319+ f .__setstate__ ((f , (), {}, {}))
320+ try :
321+ for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
322+ with self .assertRaises (RecursionError ):
323+ pickle .dumps (f , proto )
324+ finally :
325+ f .__setstate__ ((capture , (), {}, {}))
326+
327+ f = self .partial (capture )
328+ f .__setstate__ ((capture , (f ,), {}, {}))
329+ try :
330+ for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
331+ f_copy = pickle .loads (pickle .dumps (f , proto ))
332+ try :
333+ self .assertIs (f_copy .args [0 ], f_copy )
334+ finally :
335+ f_copy .__setstate__ ((capture , (), {}, {}))
336+ finally :
337+ f .__setstate__ ((capture , (), {}, {}))
338+
339+ f = self .partial (capture )
340+ f .__setstate__ ((capture , (), {'a' : f }, {}))
341+ try :
342+ for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
343+ f_copy = pickle .loads (pickle .dumps (f , proto ))
344+ try :
345+ self .assertIs (f_copy .keywords ['a' ], f_copy )
346+ finally :
347+ f_copy .__setstate__ ((capture , (), {}, {}))
348+ finally :
349+ f .__setstate__ ((capture , (), {}, {}))
360350
361351 # Issue 6083: Reference counting bug
362352 def test_setstate_refcount (self ):
@@ -375,24 +365,60 @@ def __getitem__(self, key):
375365 f = self .partial (object )
376366 self .assertRaises (TypeError , f .__setstate__ , BadSequence ())
377367
368+ @unittest .skipUnless (c_functools , 'requires the C _functools module' )
369+ class TestPartialC (TestPartial , unittest .TestCase ):
370+ if c_functools :
371+ partial = c_functools .partial
372+
373+ class AllowPickle :
374+ def __enter__ (self ):
375+ return self
376+ def __exit__ (self , type , value , tb ):
377+ return False
378+
379+ def test_attributes_unwritable (self ):
380+ # attributes should not be writable
381+ p = self .partial (capture , 1 , 2 , a = 10 , b = 20 )
382+ self .assertRaises (AttributeError , setattr , p , 'func' , map )
383+ self .assertRaises (AttributeError , setattr , p , 'args' , (1 , 2 ))
384+ self .assertRaises (AttributeError , setattr , p , 'keywords' , dict (a = 1 , b = 2 ))
385+
386+ p = self .partial (hex )
387+ try :
388+ del p .__dict__
389+ except TypeError :
390+ pass
391+ else :
392+ self .fail ('partial object allowed __dict__ to be deleted' )
378393
379394class TestPartialPy (TestPartial , unittest .TestCase ):
380- partial = staticmethod ( py_functools .partial )
395+ partial = py_functools .partial
381396
397+ class AllowPickle :
398+ def __init__ (self ):
399+ self ._cm = replaced_module ("functools" , py_functools )
400+ def __enter__ (self ):
401+ return self ._cm .__enter__ ()
402+ def __exit__ (self , type , value , tb ):
403+ return self ._cm .__exit__ (type , value , tb )
382404
383405if c_functools :
384- class PartialSubclass (c_functools .partial ):
406+ class CPartialSubclass (c_functools .partial ):
385407 pass
386408
409+ class PyPartialSubclass (py_functools .partial ):
410+ pass
387411
388412@unittest .skipUnless (c_functools , 'requires the C _functools module' )
389413class TestPartialCSubclass (TestPartialC ):
390414 if c_functools :
391- partial = PartialSubclass
415+ partial = CPartialSubclass
392416
393417 # partial subclasses are not optimized for nested calls
394418 test_nested_optimization = None
395419
420+ class TestPartialPySubclass (TestPartialPy ):
421+ partial = PyPartialSubclass
396422
397423class TestPartialMethod (unittest .TestCase ):
398424
@@ -683,9 +709,10 @@ def wrapper():
683709 self .assertEqual (wrapper .attr , 'This is a different test' )
684710 self .assertEqual (wrapper .dict_attr , f .dict_attr )
685711
686-
712+ @ unittest . skipUnless ( c_functools , 'requires the C _functools module' )
687713class TestReduce (unittest .TestCase ):
688- func = functools .reduce
714+ if c_functools :
715+ func = c_functools .reduce
689716
690717 def test_reduce (self ):
691718 class Squares :
0 commit comments