Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 457fc9a

Browse files
committed
Issue #27137: align Python & C implementations of functools.partial
The pure Python fallback implementation of functools.partial now matches the behaviour of its accelerated C counterpart for subclassing, pickling and text representation purposes. Patch by Emanuel Barry and Serhiy Storchaka.
1 parent eddc4b7 commit 457fc9a

4 files changed

Lines changed: 182 additions & 92 deletions

File tree

Lib/functools.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections import namedtuple
2222
from types import MappingProxyType
2323
from weakref import WeakKeyDictionary
24+
from reprlib import recursive_repr
2425
try:
2526
from _thread import RLock
2627
except ImportError:
@@ -237,26 +238,83 @@ def __ge__(self, other):
237238
################################################################################
238239

239240
# Purely functional, no descriptor behaviour
240-
def partial(func, *args, **keywords):
241+
class partial:
241242
"""New function with partial application of the given arguments
242243
and keywords.
243244
"""
244-
if hasattr(func, 'func'):
245-
args = func.args + args
246-
tmpkw = func.keywords.copy()
247-
tmpkw.update(keywords)
248-
keywords = tmpkw
249-
del tmpkw
250-
func = func.func
251-
252-
def newfunc(*fargs, **fkeywords):
253-
newkeywords = keywords.copy()
254-
newkeywords.update(fkeywords)
255-
return func(*(args + fargs), **newkeywords)
256-
newfunc.func = func
257-
newfunc.args = args
258-
newfunc.keywords = keywords
259-
return newfunc
245+
246+
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
247+
248+
def __new__(*args, **keywords):
249+
if not args:
250+
raise TypeError("descriptor '__new__' of partial needs an argument")
251+
if len(args) < 2:
252+
raise TypeError("type 'partial' takes at least one argument")
253+
cls, func, *args = args
254+
if not callable(func):
255+
raise TypeError("the first argument must be callable")
256+
args = tuple(args)
257+
258+
if hasattr(func, "func"):
259+
args = func.args + args
260+
tmpkw = func.keywords.copy()
261+
tmpkw.update(keywords)
262+
keywords = tmpkw
263+
del tmpkw
264+
func = func.func
265+
266+
self = super(partial, cls).__new__(cls)
267+
268+
self.func = func
269+
self.args = args
270+
self.keywords = keywords
271+
return self
272+
273+
def __call__(*args, **keywords):
274+
if not args:
275+
raise TypeError("descriptor '__call__' of partial needs an argument")
276+
self, *args = args
277+
newkeywords = self.keywords.copy()
278+
newkeywords.update(keywords)
279+
return self.func(*self.args, *args, **newkeywords)
280+
281+
@recursive_repr()
282+
def __repr__(self):
283+
qualname = type(self).__qualname__
284+
args = [repr(self.func)]
285+
args.extend(repr(x) for x in self.args)
286+
args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
287+
if type(self).__module__ == "functools":
288+
return f"functools.{qualname}({', '.join(args)})"
289+
return f"{qualname}({', '.join(args)})"
290+
291+
def __reduce__(self):
292+
return type(self), (self.func,), (self.func, self.args,
293+
self.keywords or None, self.__dict__ or None)
294+
295+
def __setstate__(self, state):
296+
if not isinstance(state, tuple):
297+
raise TypeError("argument to __setstate__ must be a tuple")
298+
if len(state) != 4:
299+
raise TypeError(f"expected 4 items in state, got {len(state)}")
300+
func, args, kwds, namespace = state
301+
if (not callable(func) or not isinstance(args, tuple) or
302+
(kwds is not None and not isinstance(kwds, dict)) or
303+
(namespace is not None and not isinstance(namespace, dict))):
304+
raise TypeError("invalid partial state")
305+
306+
args = tuple(args) # just in case it's a subclass
307+
if kwds is None:
308+
kwds = {}
309+
elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
310+
kwds = dict(kwds)
311+
if namespace is None:
312+
namespace = {}
313+
314+
self.__dict__ = namespace
315+
self.func = func
316+
self.args = args
317+
self.keywords = kwds
260318

261319
try:
262320
from _functools import partial

Lib/test/test_functools.py

Lines changed: 101 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from test import support
99
import unittest
1010
from weakref import proxy
11+
import contextlib
1112
try:
1213
import threading
1314
except ImportError:
@@ -20,6 +21,14 @@
2021

2122
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
2223

24+
@contextlib.contextmanager
25+
def replaced_module(name, replacement):
26+
original_module = sys.modules[name]
27+
sys.modules[name] = replacement
28+
try:
29+
yield
30+
finally:
31+
sys.modules[name] = original_module
2332

2433
def capture(*args, **kw):
2534
"""capture all positional and keyword arguments"""
@@ -167,89 +176,67 @@ def foo(bar):
167176
p2.new_attr = 'spam'
168177
self.assertEqual(p2.new_attr, 'spam')
169178

170-
171-
@unittest.skipUnless(c_functools, 'requires the C _functools module')
172-
class TestPartialC(TestPartial, unittest.TestCase):
173-
if c_functools:
174-
partial = c_functools.partial
175-
176-
def test_attributes_unwritable(self):
177-
# attributes should not be writable
178-
p = self.partial(capture, 1, 2, a=10, b=20)
179-
self.assertRaises(AttributeError, setattr, p, 'func', map)
180-
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
181-
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
182-
183-
p = self.partial(hex)
184-
try:
185-
del p.__dict__
186-
except TypeError:
187-
pass
188-
else:
189-
self.fail('partial object allowed __dict__ to be deleted')
190-
191179
def test_repr(self):
192180
args = (object(), object())
193181
args_repr = ', '.join(repr(a) for a in args)
194182
kwargs = {'a': object(), 'b': object()}
195183
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
196184
'b={b!r}, a={a!r}'.format_map(kwargs)]
197-
if self.partial is c_functools.partial:
185+
if self.partial in (c_functools.partial, py_functools.partial):
198186
name = 'functools.partial'
199187
else:
200188
name = self.partial.__name__
201189

202190
f = self.partial(capture)
203-
self.assertEqual('{}({!r})'.format(name, capture),
204-
repr(f))
191+
self.assertEqual(f'{name}({capture!r})', repr(f))
205192

206193
f = self.partial(capture, *args)
207-
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
208-
repr(f))
194+
self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
209195

210196
f = self.partial(capture, **kwargs)
211197
self.assertIn(repr(f),
212-
['{}({!r}, {})'.format(name, capture, kwargs_repr)
198+
[f'{name}({capture!r}, {kwargs_repr})'
213199
for kwargs_repr in kwargs_reprs])
214200

215201
f = self.partial(capture, *args, **kwargs)
216202
self.assertIn(repr(f),
217-
['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
203+
[f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
218204
for kwargs_repr in kwargs_reprs])
219205

220206
def test_recursive_repr(self):
221-
if self.partial is c_functools.partial:
207+
if self.partial in (c_functools.partial, py_functools.partial):
222208
name = 'functools.partial'
223209
else:
224210
name = self.partial.__name__
225211

226212
f = self.partial(capture)
227213
f.__setstate__((f, (), {}, {}))
228214
try:
229-
self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
215+
self.assertEqual(repr(f), '%s(...)' % (name,))
230216
finally:
231217
f.__setstate__((capture, (), {}, {}))
232218

233219
f = self.partial(capture)
234220
f.__setstate__((capture, (f,), {}, {}))
235221
try:
236-
self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
222+
self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
237223
finally:
238224
f.__setstate__((capture, (), {}, {}))
239225

240226
f = self.partial(capture)
241227
f.__setstate__((capture, (), {'a': f}, {}))
242228
try:
243-
self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
229+
self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
244230
finally:
245231
f.__setstate__((capture, (), {}, {}))
246232

247233
def test_pickle(self):
248-
f = self.partial(signature, ['asdf'], bar=[True])
249-
f.attr = []
250-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
251-
f_copy = pickle.loads(pickle.dumps(f, proto))
252-
self.assertEqual(signature(f_copy), signature(f))
234+
with self.AllowPickle():
235+
f = self.partial(signature, ['asdf'], bar=[True])
236+
f.attr = []
237+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
238+
f_copy = pickle.loads(pickle.dumps(f, proto))
239+
self.assertEqual(signature(f_copy), signature(f))
253240

254241
def test_copy(self):
255242
f = self.partial(signature, ['asdf'], bar=[True])
@@ -274,11 +261,13 @@ def test_deepcopy(self):
274261
def test_setstate(self):
275262
f = self.partial(signature)
276263
f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
264+
277265
self.assertEqual(signature(f),
278266
(capture, (1,), dict(a=10), dict(attr=[])))
279267
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
280268

281269
f.__setstate__((capture, (1,), dict(a=10), None))
270+
282271
self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
283272
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
284273

@@ -325,38 +314,39 @@ def test_setstate_subclasses(self):
325314
self.assertIs(type(r[0]), tuple)
326315

327316
def test_recursive_pickle(self):
328-
f = self.partial(capture)
329-
f.__setstate__((f, (), {}, {}))
330-
try:
331-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
332-
with self.assertRaises(RecursionError):
333-
pickle.dumps(f, proto)
334-
finally:
335-
f.__setstate__((capture, (), {}, {}))
336-
337-
f = self.partial(capture)
338-
f.__setstate__((capture, (f,), {}, {}))
339-
try:
340-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
341-
f_copy = pickle.loads(pickle.dumps(f, proto))
342-
try:
343-
self.assertIs(f_copy.args[0], f_copy)
344-
finally:
345-
f_copy.__setstate__((capture, (), {}, {}))
346-
finally:
347-
f.__setstate__((capture, (), {}, {}))
348-
349-
f = self.partial(capture)
350-
f.__setstate__((capture, (), {'a': f}, {}))
351-
try:
352-
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
353-
f_copy = pickle.loads(pickle.dumps(f, proto))
354-
try:
355-
self.assertIs(f_copy.keywords['a'], f_copy)
356-
finally:
357-
f_copy.__setstate__((capture, (), {}, {}))
358-
finally:
359-
f.__setstate__((capture, (), {}, {}))
317+
with self.AllowPickle():
318+
f = self.partial(capture)
319+
f.__setstate__((f, (), {}, {}))
320+
try:
321+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
322+
with self.assertRaises(RecursionError):
323+
pickle.dumps(f, proto)
324+
finally:
325+
f.__setstate__((capture, (), {}, {}))
326+
327+
f = self.partial(capture)
328+
f.__setstate__((capture, (f,), {}, {}))
329+
try:
330+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
331+
f_copy = pickle.loads(pickle.dumps(f, proto))
332+
try:
333+
self.assertIs(f_copy.args[0], f_copy)
334+
finally:
335+
f_copy.__setstate__((capture, (), {}, {}))
336+
finally:
337+
f.__setstate__((capture, (), {}, {}))
338+
339+
f = self.partial(capture)
340+
f.__setstate__((capture, (), {'a': f}, {}))
341+
try:
342+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
343+
f_copy = pickle.loads(pickle.dumps(f, proto))
344+
try:
345+
self.assertIs(f_copy.keywords['a'], f_copy)
346+
finally:
347+
f_copy.__setstate__((capture, (), {}, {}))
348+
finally:
349+
f.__setstate__((capture, (), {}, {}))
360350

361351
# Issue 6083: Reference counting bug
362352
def test_setstate_refcount(self):
@@ -375,24 +365,60 @@ def __getitem__(self, key):
375365
f = self.partial(object)
376366
self.assertRaises(TypeError, f.__setstate__, BadSequence())
377367

368+
@unittest.skipUnless(c_functools, 'requires the C _functools module')
369+
class TestPartialC(TestPartial, unittest.TestCase):
370+
if c_functools:
371+
partial = c_functools.partial
372+
373+
class AllowPickle:
374+
def __enter__(self):
375+
return self
376+
def __exit__(self, type, value, tb):
377+
return False
378+
379+
def test_attributes_unwritable(self):
380+
# attributes should not be writable
381+
p = self.partial(capture, 1, 2, a=10, b=20)
382+
self.assertRaises(AttributeError, setattr, p, 'func', map)
383+
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
384+
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
385+
386+
p = self.partial(hex)
387+
try:
388+
del p.__dict__
389+
except TypeError:
390+
pass
391+
else:
392+
self.fail('partial object allowed __dict__ to be deleted')
378393

379394
class TestPartialPy(TestPartial, unittest.TestCase):
380-
partial = staticmethod(py_functools.partial)
395+
partial = py_functools.partial
381396

397+
class AllowPickle:
398+
def __init__(self):
399+
self._cm = replaced_module("functools", py_functools)
400+
def __enter__(self):
401+
return self._cm.__enter__()
402+
def __exit__(self, type, value, tb):
403+
return self._cm.__exit__(type, value, tb)
382404

383405
if c_functools:
384-
class PartialSubclass(c_functools.partial):
406+
class CPartialSubclass(c_functools.partial):
385407
pass
386408

409+
class PyPartialSubclass(py_functools.partial):
410+
pass
387411

388412
@unittest.skipUnless(c_functools, 'requires the C _functools module')
389413
class TestPartialCSubclass(TestPartialC):
390414
if c_functools:
391-
partial = PartialSubclass
415+
partial = CPartialSubclass
392416

393417
# partial subclasses are not optimized for nested calls
394418
test_nested_optimization = None
395419

420+
class TestPartialPySubclass(TestPartialPy):
421+
partial = PyPartialSubclass
396422

397423
class TestPartialMethod(unittest.TestCase):
398424

@@ -683,9 +709,10 @@ def wrapper():
683709
self.assertEqual(wrapper.attr, 'This is a different test')
684710
self.assertEqual(wrapper.dict_attr, f.dict_attr)
685711

686-
712+
@unittest.skipUnless(c_functools, 'requires the C _functools module')
687713
class TestReduce(unittest.TestCase):
688-
func = functools.reduce
714+
if c_functools:
715+
func = c_functools.reduce
689716

690717
def test_reduce(self):
691718
class Squares:

0 commit comments

Comments
 (0)