Issue #11707: Fast C version of functools.cmp_to_key()

This commit is contained in:
Raymond Hettinger 2011-04-05 02:33:54 -07:00
parent 271b27e5fe
commit 7ab9e22e34
4 changed files with 235 additions and 2 deletions

View file

@ -97,7 +97,7 @@ def cmp_to_key(mycmp):
"""Convert a cmp= function into a key= function"""
class K(object):
__slots__ = ['obj']
def __init__(self, obj, *args):
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj) < 0
@ -115,6 +115,11 @@ def cmp_to_key(mycmp):
raise TypeError('hash not implemented')
return K
try:
from _functools import cmp_to_key
except ImportError:
pass
_CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
def lru_cache(maxsize=100):

View file

@ -435,18 +435,81 @@ class TestReduce(unittest.TestCase):
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
def test_cmp_to_key(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(cmp1)
self.assertEqual(key(3), key(3))
self.assertGreater(key(3), key(1))
def cmp2(x, y):
return int(x) - int(y)
key = functools.cmp_to_key(cmp2)
self.assertEqual(key(4.0), key('4'))
self.assertLess(key(2), key('35'))
def test_cmp_to_key_arguments(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(obj=3), key(obj=3))
self.assertGreater(key(obj=3), key(obj=1))
with self.assertRaises((TypeError, AttributeError)):
key(3) > 1 # rhs is not a K object
with self.assertRaises((TypeError, AttributeError)):
1 < key(3) # lhs is not a K object
with self.assertRaises(TypeError):
key = functools.cmp_to_key() # too few args
with self.assertRaises(TypeError):
key = functools.cmp_to_key(cmp1, None) # too many args
key = functools.cmp_to_key(cmp1)
with self.assertRaises(TypeError):
key() # too few args
with self.assertRaises(TypeError):
key(None, None) # too many args
def test_bad_cmp(self):
def cmp1(x, y):
raise ZeroDivisionError
key = functools.cmp_to_key(cmp1)
with self.assertRaises(ZeroDivisionError):
key(3) > key(1)
class BadCmp:
def __lt__(self, other):
raise ZeroDivisionError
def cmp1(x, y):
return BadCmp()
with self.assertRaises(ZeroDivisionError):
key(3) > key(1)
def test_obj_field(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(50).obj, 50)
def test_sort_int(self):
def mycmp(x, y):
return y - x
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
def test_sort_int_str(self):
def mycmp(x, y):
x, y = int(x), int(y)
return (x > y) - (x < y)
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
values = sorted(values, key=functools.cmp_to_key(mycmp))
self.assertEqual([int(value) for value in values],
[0, 1, 1, 2, 3, 4, 5, 7, 10])
def test_hash(self):
def mycmp(x, y):
return y - x
key = functools.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash(k))
self.assertRaises(TypeError, hash, k)
class TestTotalOrdering(unittest.TestCase):
@ -655,6 +718,7 @@ class TestLRU(unittest.TestCase):
def test_main(verbose=None):
test_classes = (
TestCmpToKey,
TestPartial,
TestPartialSubclass,
TestPythonPartial,