Add reference implementation for PEP 443

PEP accepted: http://mail.python.org/pipermail/python-dev/2013-June/126734.html
This commit is contained in:
Łukasz Langa 2013-06-05 12:20:24 +02:00
parent 072318b178
commit 6f69251980
6 changed files with 614 additions and 55 deletions

View file

@ -1,24 +1,30 @@
import collections
import sys
import unittest
from test import support
from weakref import proxy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import unittest
from weakref import proxy
import functools
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
return args, kw
def signature(part):
""" return the signature of a partial object """
return (part.func, part.args, part.keywords, part.__dict__)
class TestPartial:
def test_basic_examples(self):
@ -138,6 +144,7 @@ class TestPartial:
join = self.partial(''.join)
self.assertEqual(join(data), '0123456789')
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
if c_functools:
@ -194,18 +201,22 @@ class TestPartialC(TestPartial, unittest.TestCase):
"new style getargs format but argument is not a tuple",
f.__setstate__, BadSequence())
class TestPartialPy(TestPartial, unittest.TestCase):
partial = staticmethod(py_functools.partial)
if c_functools:
class PartialSubclass(c_functools.partial):
pass
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
if c_functools:
partial = PartialSubclass
class TestUpdateWrapper(unittest.TestCase):
def check_wrapper(self, wrapper, wrapped,
@ -312,6 +323,7 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertTrue(wrapper.__doc__.startswith('max('))
self.assertEqual(wrapper.__annotations__, {})
class TestWraps(TestUpdateWrapper):
def _default_update(self):
@ -372,6 +384,7 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
class TestReduce(unittest.TestCase):
func = functools.reduce
@ -452,6 +465,7 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey:
def test_cmp_to_key(self):
@ -534,14 +548,17 @@ class TestCmpToKey:
self.assertRaises(TypeError, hash, k)
self.assertNotIsInstance(k, collections.Hashable)
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
if c_functools:
cmp_to_key = c_functools.cmp_to_key
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
cmp_to_key = staticmethod(py_functools.cmp_to_key)
class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_lt(self):
@ -642,6 +659,7 @@ class TestTotalOrdering(unittest.TestCase):
with self.assertRaises(TypeError):
TestTO(8) <= ()
class TestLRU(unittest.TestCase):
def test_lru(self):
@ -834,6 +852,353 @@ class TestLRU(unittest.TestCase):
DoubleEq(2)) # Verify the correct return value
class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
@functools.singledispatch
def g(obj):
return "base"
def g_int(i):
return "integer"
g.register(int, g_int)
self.assertEqual(g("str"), "base")
self.assertEqual(g(1), "integer")
self.assertEqual(g([1,2,3]), "base")
def test_mro(self):
@functools.singledispatch
def g(obj):
return "base"
class C:
pass
class D(C):
pass
def g_C(c):
return "C"
g.register(C, g_C)
self.assertEqual(g(C()), "C")
self.assertEqual(g(D()), "C")
def test_classic_classes(self):
@functools.singledispatch
def g(obj):
return "base"
class C:
pass
class D(C):
pass
def g_C(c):
return "C"
g.register(C, g_C)
self.assertEqual(g(C()), "C")
self.assertEqual(g(D()), "C")
def test_register_decorator(self):
@functools.singledispatch
def g(obj):
return "base"
@g.register(int)
def g_int(i):
return "int %s" % (i,)
self.assertEqual(g(""), "base")
self.assertEqual(g(12), "int 12")
self.assertIs(g.dispatch(int), g_int)
self.assertIs(g.dispatch(object), g.dispatch(str))
# Note: in the assert above this is not g.
# @singledispatch returns the wrapper.
def test_wrapping_attributes(self):
@functools.singledispatch
def g(obj):
"Simple test"
return "Test"
self.assertEqual(g.__name__, "g")
self.assertEqual(g.__doc__, "Simple test")
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
@functools.singledispatch
def g(obj):
return "base"
@g.register(decimal.DecimalException)
def _(obj):
return obj.args
subn = decimal.Subnormal("Exponent < Emin")
rnd = decimal.Rounded("Number got rounded")
self.assertEqual(g(subn), ("Exponent < Emin",))
self.assertEqual(g(rnd), ("Number got rounded",))
@g.register(decimal.Subnormal)
def _(obj):
return "Too small to care."
self.assertEqual(g(subn), "Too small to care.")
self.assertEqual(g(rnd), ("Number got rounded",))
def test_compose_mro(self):
c = collections
mro = functools._compose_mro
bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
for haystack in permutations(bases):
m = mro(dict, haystack)
self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object])
bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
for haystack in permutations(bases):
m = mro(c.ChainMap, haystack)
self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
c.Sized, c.Iterable, c.Container, object])
# Note: The MRO order below depends on haystack ordering.
m = mro(c.defaultdict, [c.Sized, c.Container, str])
self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object])
m = mro(c.defaultdict, [c.Container, c.Sized, str])
self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object])
def test_register_abc(self):
c = collections
d = {"a": "b"}
l = [1, 2, 3]
s = {object(), None}
f = frozenset(s)
t = (1, 2, 3)
@functools.singledispatch
def g(obj):
return "base"
self.assertEqual(g(d), "base")
self.assertEqual(g(l), "base")
self.assertEqual(g(s), "base")
self.assertEqual(g(f), "base")
self.assertEqual(g(t), "base")
g.register(c.Sized, lambda obj: "sized")
self.assertEqual(g(d), "sized")
self.assertEqual(g(l), "sized")
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.MutableMapping, lambda obj: "mutablemapping")
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(g(l), "sized")
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.ChainMap, lambda obj: "chainmap")
self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
self.assertEqual(g(l), "sized")
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.MutableSequence, lambda obj: "mutablesequence")
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.MutableSet, lambda obj: "mutableset")
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.Mapping, lambda obj: "mapping")
self.assertEqual(g(d), "mutablemapping") # not specific enough
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
g.register(c.Sequence, lambda obj: "sequence")
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sequence")
g.register(c.Set, lambda obj: "set")
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "set")
self.assertEqual(g(t), "sequence")
g.register(dict, lambda obj: "dict")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "mutablesequence")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "set")
self.assertEqual(g(t), "sequence")
g.register(list, lambda obj: "list")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
self.assertEqual(g(s), "mutableset")
self.assertEqual(g(f), "set")
self.assertEqual(g(t), "sequence")
g.register(set, lambda obj: "concrete-set")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
self.assertEqual(g(s), "concrete-set")
self.assertEqual(g(f), "set")
self.assertEqual(g(t), "sequence")
g.register(frozenset, lambda obj: "frozen-set")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
self.assertEqual(g(s), "concrete-set")
self.assertEqual(g(f), "frozen-set")
self.assertEqual(g(t), "sequence")
g.register(tuple, lambda obj: "tuple")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
self.assertEqual(g(s), "concrete-set")
self.assertEqual(g(f), "frozen-set")
self.assertEqual(g(t), "tuple")
def test_mro_conflicts(self):
c = collections
@functools.singledispatch
def g(arg):
return "base"
class O(c.Sized):
def __len__(self):
return 0
o = O()
self.assertEqual(g(o), "base")
g.register(c.Iterable, lambda arg: "iterable")
g.register(c.Container, lambda arg: "container")
g.register(c.Sized, lambda arg: "sized")
g.register(c.Set, lambda arg: "set")
self.assertEqual(g(o), "sized")
c.Iterable.register(O)
self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
c.Container.register(O)
self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
class P:
pass
p = P()
self.assertEqual(g(p), "base")
c.Iterable.register(P)
self.assertEqual(g(p), "iterable")
c.Container.register(P)
with self.assertRaises(RuntimeError) as re:
g(p)
self.assertEqual(
str(re),
("Ambiguous dispatch: <class 'collections.abc.Container'> "
"or <class 'collections.abc.Iterable'>"),
)
class Q(c.Sized):
def __len__(self):
return 0
q = Q()
self.assertEqual(g(q), "sized")
c.Iterable.register(Q)
self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
c.Set.register(Q)
self.assertEqual(g(q), "set") # because c.Set is a subclass of
# c.Sized which is explicitly in
# __mro__
def test_cache_invalidation(self):
from collections import UserDict
class TracingDict(UserDict):
def __init__(self, *args, **kwargs):
super(TracingDict, self).__init__(*args, **kwargs)
self.set_ops = []
self.get_ops = []
def __getitem__(self, key):
result = self.data[key]
self.get_ops.append(key)
return result
def __setitem__(self, key, value):
self.set_ops.append(key)
self.data[key] = value
def clear(self):
self.data.clear()
_orig_wkd = functools.WeakKeyDictionary
td = TracingDict()
functools.WeakKeyDictionary = lambda: td
c = collections
@functools.singledispatch
def g(arg):
return "base"
d = {}
l = []
self.assertEqual(len(td), 0)
self.assertEqual(g(d), "base")
self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [])
self.assertEqual(td.set_ops, [dict])
self.assertEqual(td.data[dict], g.registry[object])
self.assertEqual(g(l), "base")
self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [])
self.assertEqual(td.set_ops, [dict, list])
self.assertEqual(td.data[dict], g.registry[object])
self.assertEqual(td.data[list], g.registry[object])
self.assertEqual(td.data[dict], td.data[list])
self.assertEqual(g(l), "base")
self.assertEqual(g(d), "base")
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list])
g.register(list, lambda arg: "list")
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(len(td), 0)
self.assertEqual(g(d), "base")
self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict])
self.assertEqual(td.data[dict],
functools._find_impl(dict, g.registry))
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list])
self.assertEqual(td.data[list],
functools._find_impl(list, g.registry))
class X:
pass
c.MutableMapping.register(X) # Will not invalidate the cache,
# not using ABCs yet.
self.assertEqual(g(d), "base")
self.assertEqual(g(l), "list")
self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list])
g.register(c.Sized, lambda arg: "sized")
self.assertEqual(len(td), 0)
self.assertEqual(g(d), "sized")
self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
self.assertEqual(g(l), "list")
self.assertEqual(g(d), "sized")
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
g.dispatch(list)
g.dispatch(dict)
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
c.MutableSet.register(X) # Will invalidate the cache.
self.assertEqual(len(td), 2) # Stale cache.
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 1)
g.register(c.MutableMapping, lambda arg: "mutablemapping")
self.assertEqual(len(td), 0)
self.assertEqual(g(d), "mutablemapping")
self.assertEqual(len(td), 1)
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2)
g.register(dict, lambda arg: "dict")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
g._clear_cache()
self.assertEqual(len(td), 0)
functools.WeakKeyDictionary = _orig_wkd
def test_main(verbose=None):
test_classes = (
TestPartialC,
@ -846,6 +1211,7 @@ def test_main(verbose=None):
TestWraps,
TestReduce,
TestLRU,
TestSingleDispatch,
)
support.run_unittest(*test_classes)