Backport r87613 to make OrderedDict subclassing match dict subclassing.

This commit is contained in:
Raymond Hettinger 2011-01-04 20:57:19 +00:00
parent db0ef2b5e5
commit 1d879f6852
2 changed files with 49 additions and 5 deletions

View file

@ -21,7 +21,7 @@ from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
class _Link(object): class _Link(object):
__slots__ = 'prev', 'next', 'key', '__weakref__' __slots__ = 'prev', 'next', 'key', '__weakref__'
class OrderedDict(dict, MutableMapping): class OrderedDict(dict):
'Dictionary that remembers insertion order' 'Dictionary that remembers insertion order'
# An inherited dict maps keys to values. # An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get. # The inherited dict provides __getitem__, __len__, __contains__, and get.
@ -50,7 +50,7 @@ class OrderedDict(dict, MutableMapping):
self.__root = root = _Link() # sentinel node for the doubly linked list self.__root = root = _Link() # sentinel node for the doubly linked list
root.prev = root.next = root root.prev = root.next = root
self.__map = {} self.__map = {}
self.update(*args, **kwds) self.__update(*args, **kwds)
def clear(self): def clear(self):
'od.clear() -> None. Remove all items from od.' 'od.clear() -> None. Remove all items from od.'
@ -109,13 +109,29 @@ class OrderedDict(dict, MutableMapping):
return (self.__class__, (items,), inst_dict) return (self.__class__, (items,), inst_dict)
return self.__class__, (items,) return self.__class__, (items,)
setdefault = MutableMapping.setdefault update = __update = MutableMapping.update
update = MutableMapping.update
pop = MutableMapping.pop
keys = MutableMapping.keys keys = MutableMapping.keys
values = MutableMapping.values values = MutableMapping.values
items = MutableMapping.items items = MutableMapping.items
__marker = object()
def pop(self, key, default=__marker):
if key in self:
result = self[key]
del self[key]
return result
if default is self.__marker:
raise KeyError(key)
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
if key in self:
return self[key]
self[key] = default
return default
def popitem(self, last=True): def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair. '''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false. Pairs are returned in LIFO order if last is true or FIFO order if false.

View file

@ -792,6 +792,10 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(list(d.items()), self.assertEqual(list(d.items()),
[('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)])
def test_abc(self):
self.assertTrue(isinstance(OrderedDict(), MutableMapping))
self.assertTrue(issubclass(OrderedDict, MutableMapping))
def test_clear(self): def test_clear(self):
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
shuffle(pairs) shuffle(pairs)
@ -850,6 +854,17 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(len(od), 0) self.assertEqual(len(od), 0)
self.assertEqual(od.pop(k, 12345), 12345) self.assertEqual(od.pop(k, 12345), 12345)
# make sure pop still works when __missing__ is defined
class Missing(OrderedDict):
def __missing__(self, key):
return 0
m = Missing(a=1)
self.assertEqual(m.pop('b', 5), 5)
self.assertEqual(m.pop('a', 6), 1)
self.assertEqual(m.pop('a', 6), 6)
with self.assertRaises(KeyError):
m.pop('a')
def test_equality(self): def test_equality(self):
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
shuffle(pairs) shuffle(pairs)
@ -934,6 +949,12 @@ class TestOrderedDict(unittest.TestCase):
# make sure 'x' is added to the end # make sure 'x' is added to the end
self.assertEqual(list(od.items())[-1], ('x', 10)) self.assertEqual(list(od.items())[-1], ('x', 10))
# make sure setdefault still works when __missing__ is defined
class Missing(OrderedDict):
def __missing__(self, key):
return 0
self.assertEqual(Missing().setdefault(5, 9), 9)
def test_reinsert(self): def test_reinsert(self):
# Given insert a, insert b, delete a, re-insert a, # Given insert a, insert b, delete a, re-insert a,
# verify that a is now later than b. # verify that a is now later than b.
@ -945,6 +966,13 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) self.assertEqual(list(od.items()), [('b', 2), ('a', 1)])
def test_override_update(self):
# Verify that subclasses can override update() without breaking __init__()
class MyOD(OrderedDict):
def update(self, *args, **kwds):
raise Exception()
items = [('a', 1), ('c', 3), ('b', 2)]
self.assertEqual(list(MyOD(items).items()), items)
class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
type2test = OrderedDict type2test = OrderedDict