Forward port total_ordering() and cmp_to_key().

This commit is contained in:
Raymond Hettinger 2010-04-05 18:56:31 +00:00
parent 5daab45158
commit c50846aaef
10 changed files with 186 additions and 25 deletions

View file

@ -364,7 +364,89 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
def test_cmp_to_key(self):
def mycmp(x, y):
return y - x
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
def test_hash(self):
def mycmp(x, y):
return y - x
key = functools.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash(k))
class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_lt(self):
@functools.total_ordering
class A:
def __init__(self, value):
self.value = value
def __lt__(self, other):
return self.value < other.value
self.assert_(A(1) < A(2))
self.assert_(A(2) > A(1))
self.assert_(A(1) <= A(2))
self.assert_(A(2) >= A(1))
self.assert_(A(2) <= A(2))
self.assert_(A(2) >= A(2))
def test_total_ordering_le(self):
@functools.total_ordering
class A:
def __init__(self, value):
self.value = value
def __le__(self, other):
return self.value <= other.value
self.assert_(A(1) < A(2))
self.assert_(A(2) > A(1))
self.assert_(A(1) <= A(2))
self.assert_(A(2) >= A(1))
self.assert_(A(2) <= A(2))
self.assert_(A(2) >= A(2))
def test_total_ordering_gt(self):
@functools.total_ordering
class A:
def __init__(self, value):
self.value = value
def __gt__(self, other):
return self.value > other.value
self.assert_(A(1) < A(2))
self.assert_(A(2) > A(1))
self.assert_(A(1) <= A(2))
self.assert_(A(2) >= A(1))
self.assert_(A(2) <= A(2))
self.assert_(A(2) >= A(2))
def test_total_ordering_ge(self):
@functools.total_ordering
class A:
def __init__(self, value):
self.value = value
def __ge__(self, other):
return self.value >= other.value
self.assert_(A(1) < A(2))
self.assert_(A(2) > A(1))
self.assert_(A(1) <= A(2))
self.assert_(A(2) >= A(1))
self.assert_(A(2) <= A(2))
self.assert_(A(2) >= A(2))
def test_total_ordering_no_overwrite(self):
# new methods should not overwrite existing
@functools.total_ordering
class A(int):
raise Exception()
self.assert_(A(1) < A(2))
self.assert_(A(2) > A(1))
self.assert_(A(1) <= A(2))
self.assert_(A(2) >= A(1))
self.assert_(A(2) <= A(2))
self.assert_(A(2) >= A(2))
def test_main(verbose=None):