Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7ab9e22

Browse files
committed
Issue #11707: Fast C version of functools.cmp_to_key()
1 parent 271b27e commit 7ab9e22

4 files changed

Lines changed: 235 additions & 2 deletions

File tree

Lib/functools.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def cmp_to_key(mycmp):
9797
"""Convert a cmp= function into a key= function"""
9898
class K(object):
9999
__slots__ = ['obj']
100-
def __init__(self, obj, *args):
100+
def __init__(self, obj):
101101
self.obj = obj
102102
def __lt__(self, other):
103103
return mycmp(self.obj, other.obj) < 0
@@ -115,6 +115,11 @@ def __hash__(self):
115115
raise TypeError('hash not implemented')
116116
return K
117117

118+
try:
119+
from _functools import cmp_to_key
120+
except ImportError:
121+
pass
122+
118123
_CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
119124

120125
def lru_cache(maxsize=100):

Lib/test/test_functools.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,18 +435,81 @@ def __getitem__(self, i):
435435
self.assertEqual(self.func(add, d), "".join(d.keys()))
436436

437437
class TestCmpToKey(unittest.TestCase):
438+
438439
def test_cmp_to_key(self):
440+
def cmp1(x, y):
441+
return (x > y) - (x < y)
442+
key = functools.cmp_to_key(cmp1)
443+
self.assertEqual(key(3), key(3))
444+
self.assertGreater(key(3), key(1))
445+
def cmp2(x, y):
446+
return int(x) - int(y)
447+
key = functools.cmp_to_key(cmp2)
448+
self.assertEqual(key(4.0), key('4'))
449+
self.assertLess(key(2), key('35'))
450+
451+
def test_cmp_to_key_arguments(self):
452+
def cmp1(x, y):
453+
return (x > y) - (x < y)
454+
key = functools.cmp_to_key(mycmp=cmp1)
455+
self.assertEqual(key(obj=3), key(obj=3))
456+
self.assertGreater(key(obj=3), key(obj=1))
457+
with self.assertRaises((TypeError, AttributeError)):
458+
key(3) > 1 # rhs is not a K object
459+
with self.assertRaises((TypeError, AttributeError)):
460+
1 < key(3) # lhs is not a K object
461+
with self.assertRaises(TypeError):
462+
key = functools.cmp_to_key() # too few args
463+
with self.assertRaises(TypeError):
464+
key = functools.cmp_to_key(cmp1, None) # too many args
465+
key = functools.cmp_to_key(cmp1)
466+
with self.assertRaises(TypeError):
467+
key() # too few args
468+
with self.assertRaises(TypeError):
469+
key(None, None) # too many args
470+
471+
def test_bad_cmp(self):
472+
def cmp1(x, y):
473+
raise ZeroDivisionError
474+
key = functools.cmp_to_key(cmp1)
475+
with self.assertRaises(ZeroDivisionError):
476+
key(3) > key(1)
477+
478+
class BadCmp:
479+
def __lt__(self, other):
480+
raise ZeroDivisionError
481+
def cmp1(x, y):
482+
return BadCmp()
483+
with self.assertRaises(ZeroDivisionError):
484+
key(3) > key(1)
485+
486+
def test_obj_field(self):
487+
def cmp1(x, y):
488+
return (x > y) - (x < y)
489+
key = functools.cmp_to_key(mycmp=cmp1)
490+
self.assertEqual(key(50).obj, 50)
491+
492+
def test_sort_int(self):
439493
def mycmp(x, y):
440494
return y - x
441495
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
442496
[4, 3, 2, 1, 0])
443497

498+
def test_sort_int_str(self):
499+
def mycmp(x, y):
500+
x, y = int(x), int(y)
501+
return (x > y) - (x < y)
502+
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
503+
values = sorted(values, key=functools.cmp_to_key(mycmp))
504+
self.assertEqual([int(value) for value in values],
505+
[0, 1, 1, 2, 3, 4, 5, 7, 10])
506+
444507
def test_hash(self):
445508
def mycmp(x, y):
446509
return y - x
447510
key = functools.cmp_to_key(mycmp)
448511
k = key(10)
449-
self.assertRaises(TypeError, hash(k))
512+
self.assertRaises(TypeError, hash, k)
450513

451514
class TestTotalOrdering(unittest.TestCase):
452515

@@ -655,6 +718,7 @@ def fib(n):
655718

656719
def test_main(verbose=None):
657720
test_classes = (
721+
TestCmpToKey,
658722
TestPartial,
659723
TestPartialSubclass,
660724
TestPythonPartial,

Misc/NEWS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ Library
9797
- Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile
9898
to be wrapped in a TextIOWrapper. Patch by Nadeem Vawda.
9999

100+
- Issue #11707: Added a fast C version of functools.cmp_to_key().
101+
Patch by Filip Gruszczyński.
102+
100103
- Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by
101104
Torsten Landschoff.
102105

Modules/_functoolsmodule.c

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,165 @@ static PyTypeObject partial_type = {
330330
};
331331

332332

333+
/* cmp_to_key ***************************************************************/
334+
335+
typedef struct {
336+
PyObject_HEAD;
337+
PyObject *cmp;
338+
PyObject *object;
339+
} keyobject;
340+
341+
static void
342+
keyobject_dealloc(keyobject *ko)
343+
{
344+
Py_DECREF(ko->cmp);
345+
Py_XDECREF(ko->object);
346+
PyObject_FREE(ko);
347+
}
348+
349+
static int
350+
keyobject_traverse(keyobject *ko, visitproc visit, void *arg)
351+
{
352+
Py_VISIT(ko->cmp);
353+
if (ko->object)
354+
Py_VISIT(ko->object);
355+
return 0;
356+
}
357+
358+
static PyMemberDef keyobject_members[] = {
359+
{"obj", T_OBJECT,
360+
offsetof(keyobject, object), 0,
361+
PyDoc_STR("Value wrapped by a key function.")},
362+
{NULL}
363+
};
364+
365+
static PyObject *
366+
keyobject_call(keyobject *ko, PyObject *args, PyObject *kw);
367+
368+
static PyObject *
369+
keyobject_richcompare(PyObject *ko, PyObject *other, int op);
370+
371+
static PyTypeObject keyobject_type = {
372+
PyVarObject_HEAD_INIT(&PyType_Type, 0)
373+
"functools.KeyWrapper", /* tp_name */
374+
sizeof(keyobject), /* tp_basicsize */
375+
0, /* tp_itemsize */
376+
/* methods */
377+
(destructor)keyobject_dealloc, /* tp_dealloc */
378+
0, /* tp_print */
379+
0, /* tp_getattr */
380+
0, /* tp_setattr */
381+
0, /* tp_reserved */
382+
0, /* tp_repr */
383+
0, /* tp_as_number */
384+
0, /* tp_as_sequence */
385+
0, /* tp_as_mapping */
386+
0, /* tp_hash */
387+
(ternaryfunc)keyobject_call, /* tp_call */
388+
0, /* tp_str */
389+
PyObject_GenericGetAttr, /* tp_getattro */
390+
0, /* tp_setattro */
391+
0, /* tp_as_buffer */
392+
Py_TPFLAGS_DEFAULT, /* tp_flags */
393+
0, /* tp_doc */
394+
(traverseproc)keyobject_traverse, /* tp_traverse */
395+
0, /* tp_clear */
396+
keyobject_richcompare, /* tp_richcompare */
397+
0, /* tp_weaklistoffset */
398+
0, /* tp_iter */
399+
0, /* tp_iternext */
400+
0, /* tp_methods */
401+
keyobject_members, /* tp_members */
402+
0, /* tp_getset */
403+
};
404+
405+
static PyObject *
406+
keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
407+
{
408+
PyObject *object;
409+
keyobject *result;
410+
static char *kwargs[] = {"obj", NULL};
411+
412+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwargs, &object))
413+
return NULL;
414+
result = PyObject_New(keyobject, &keyobject_type);
415+
if (!result)
416+
return NULL;
417+
Py_INCREF(ko->cmp);
418+
result->cmp = ko->cmp;
419+
Py_INCREF(object);
420+
result->object = object;
421+
return (PyObject *)result;
422+
}
423+
424+
static PyObject *
425+
keyobject_richcompare(PyObject *ko, PyObject *other, int op)
426+
{
427+
PyObject *res;
428+
PyObject *args;
429+
PyObject *x;
430+
PyObject *y;
431+
PyObject *compare;
432+
PyObject *answer;
433+
static PyObject *zero;
434+
435+
if (zero == NULL) {
436+
zero = PyLong_FromLong(0);
437+
if (!zero)
438+
return NULL;
439+
}
440+
441+
if (Py_TYPE(other) != &keyobject_type){
442+
PyErr_Format(PyExc_TypeError, "other argument must be K instance");
443+
return NULL;
444+
}
445+
compare = ((keyobject *) ko)->cmp;
446+
assert(compare != NULL);
447+
x = ((keyobject *) ko)->object;
448+
y = ((keyobject *) other)->object;
449+
if (!x || !y){
450+
PyErr_Format(PyExc_AttributeError, "object");
451+
return NULL;
452+
}
453+
454+
/* Call the user's comparison function and translate the 3-way
455+
* result into true or false (or error).
456+
*/
457+
args = PyTuple_New(2);
458+
if (args == NULL)
459+
return NULL;
460+
Py_INCREF(x);
461+
Py_INCREF(y);
462+
PyTuple_SET_ITEM(args, 0, x);
463+
PyTuple_SET_ITEM(args, 1, y);
464+
res = PyObject_Call(compare, args, NULL);
465+
Py_DECREF(args);
466+
if (res == NULL)
467+
return NULL;
468+
answer = PyObject_RichCompare(res, zero, op);
469+
Py_DECREF(res);
470+
return answer;
471+
}
472+
473+
static PyObject *
474+
functools_cmp_to_key(PyObject *self, PyObject *args, PyObject *kwds){
475+
PyObject *cmp;
476+
static char *kwargs[] = {"mycmp", NULL};
477+
478+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:cmp_to_key", kwargs, &cmp))
479+
return NULL;
480+
keyobject *object = PyObject_New(keyobject, &keyobject_type);
481+
if (!object)
482+
return NULL;
483+
Py_INCREF(cmp);
484+
object->cmp = cmp;
485+
object->object = NULL;
486+
return (PyObject *)object;
487+
}
488+
489+
PyDoc_STRVAR(functools_cmp_to_key_doc,
490+
"Convert a cmp= function into a key= function.");
491+
333492
/* reduce (used to be a builtin) ********************************************/
334493

335494
static PyObject *
@@ -413,6 +572,8 @@ PyDoc_STRVAR(module_doc,
413572

414573
static PyMethodDef module_methods[] = {
415574
{"reduce", functools_reduce, METH_VARARGS, functools_reduce_doc},
575+
{"cmp_to_key", functools_cmp_to_key, METH_VARARGS | METH_KEYWORDS,
576+
functools_cmp_to_key_doc},
416577
{NULL, NULL} /* sentinel */
417578
};
418579

0 commit comments

Comments
 (0)