Fix field name conflicts for named tuples.

This commit is contained in:
Raymond Hettinger 2009-05-27 02:24:45 +00:00
parent 55d8828f98
commit a68cad13ae
3 changed files with 53 additions and 13 deletions

View file

@ -673,8 +673,8 @@ Example:
<BLANKLINE> <BLANKLINE>
_fields = ('x', 'y') _fields = ('x', 'y')
<BLANKLINE> <BLANKLINE>
def __new__(cls, x, y): def __new__(_cls, x, y):
return tuple.__new__(cls, (x, y)) return _tuple.__new__(_cls, (x, y))
<BLANKLINE> <BLANKLINE>
@classmethod @classmethod
def _make(cls, iterable, new=tuple.__new__, len=len): def _make(cls, iterable, new=tuple.__new__, len=len):
@ -691,9 +691,9 @@ Example:
'Return a new OrderedDict which maps field names to their values' 'Return a new OrderedDict which maps field names to their values'
return OrderedDict(zip(self._fields, self)) return OrderedDict(zip(self._fields, self))
<BLANKLINE> <BLANKLINE>
def _replace(self, **kwds): def _replace(_self, **kwds):
'Return a new Point object replacing specified fields with new values' 'Return a new Point object replacing specified fields with new values'
result = self._make(map(kwds.pop, ('x', 'y'), self)) result = _self._make(map(kwds.pop, ('x', 'y'), _self))
if kwds: if kwds:
raise ValueError('Got unexpected field names: %r' % kwds.keys()) raise ValueError('Got unexpected field names: %r' % kwds.keys())
return result return result
@ -701,8 +701,8 @@ Example:
def __getnewargs__(self): def __getnewargs__(self):
return tuple(self) return tuple(self)
<BLANKLINE> <BLANKLINE>
x = property(itemgetter(0)) x = _property(_itemgetter(0))
y = property(itemgetter(1)) y = _property(_itemgetter(1))
>>> p = Point(11, y=22) # instantiate with positional or keyword arguments >>> p = Point(11, y=22) # instantiate with positional or keyword arguments
>>> p[0] + p[1] # indexable like the plain tuple (11, 22) >>> p[0] + p[1] # indexable like the plain tuple (11, 22)

View file

@ -229,8 +229,8 @@ def namedtuple(typename, field_names, verbose=False, rename=False):
'%(typename)s(%(argtxt)s)' \n '%(typename)s(%(argtxt)s)' \n
__slots__ = () \n __slots__ = () \n
_fields = %(field_names)r \n _fields = %(field_names)r \n
def __new__(cls, %(argtxt)s): def __new__(_cls, %(argtxt)s):
return tuple.__new__(cls, (%(argtxt)s)) \n return _tuple.__new__(_cls, (%(argtxt)s)) \n
@classmethod @classmethod
def _make(cls, iterable, new=tuple.__new__, len=len): def _make(cls, iterable, new=tuple.__new__, len=len):
'Make a new %(typename)s object from a sequence or iterable' 'Make a new %(typename)s object from a sequence or iterable'
@ -243,23 +243,23 @@ def namedtuple(typename, field_names, verbose=False, rename=False):
def _asdict(self): def _asdict(self):
'Return a new OrderedDict which maps field names to their values' 'Return a new OrderedDict which maps field names to their values'
return OrderedDict(zip(self._fields, self)) \n return OrderedDict(zip(self._fields, self)) \n
def _replace(self, **kwds): def _replace(_self, **kwds):
'Return a new %(typename)s object replacing specified fields with new values' 'Return a new %(typename)s object replacing specified fields with new values'
result = self._make(map(kwds.pop, %(field_names)r, self)) result = _self._make(map(kwds.pop, %(field_names)r, _self))
if kwds: if kwds:
raise ValueError('Got unexpected field names: %%r' %% kwds.keys()) raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
return result \n return result \n
def __getnewargs__(self): def __getnewargs__(self):
return tuple(self) \n\n''' % locals() return tuple(self) \n\n''' % locals()
for i, name in enumerate(field_names): for i, name in enumerate(field_names):
template += ' %s = property(itemgetter(%d))\n' % (name, i) template += ' %s = _property(_itemgetter(%d))\n' % (name, i)
if verbose: if verbose:
print template print template
# Execute the template string in a temporary namespace and # Execute the template string in a temporary namespace and
# support tracing utilities by setting a value for frame.f_globals['__name__'] # support tracing utilities by setting a value for frame.f_globals['__name__']
namespace = dict(itemgetter=_itemgetter, __name__='namedtuple_%s' % typename, namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
OrderedDict=OrderedDict) OrderedDict=OrderedDict, _property=property, _tuple=tuple)
try: try:
exec template in namespace exec template in namespace
except SyntaxError, e: except SyntaxError, e:

View file

@ -7,6 +7,8 @@ from test import mapping_tests
import pickle, cPickle, copy import pickle, cPickle, copy
from random import randrange, shuffle from random import randrange, shuffle
import operator import operator
import keyword
import re
from collections import Hashable, Iterable, Iterator from collections import Hashable, Iterable, Iterator
from collections import Sized, Container, Callable from collections import Sized, Container, Callable
from collections import Set, MutableSet from collections import Set, MutableSet
@ -170,6 +172,44 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(p, q) self.assertEqual(p, q)
self.assertEqual(p._fields, q._fields) self.assertEqual(p._fields, q._fields)
def test_name_conflicts(self):
# Some names like "self", "cls", "tuple", "itemgetter", and "property"
# failed when used as field names. Test to make sure these now work.
T = namedtuple('T', 'itemgetter property self cls tuple')
t = T(1, 2, 3, 4, 5)
self.assertEqual(t, (1,2,3,4,5))
newt = t._replace(itemgetter=10, property=20, self=30, cls=40, tuple=50)
self.assertEqual(newt, (10,20,30,40,50))
# Broader test of all interesting names in a template
with test_support.captured_stdout() as template:
T = namedtuple('T', 'x', verbose=True)
words = set(re.findall('[A-Za-z]+', template.getvalue()))
words -= set(keyword.kwlist)
T = namedtuple('T', words)
# test __new__
values = tuple(range(len(words)))
t = T(*values)
self.assertEqual(t, values)
t = T(**dict(zip(T._fields, values)))
self.assertEqual(t, values)
# test _make
t = T._make(values)
self.assertEqual(t, values)
# exercise __repr__
repr(t)
# test _asdict
self.assertEqual(t._asdict(), dict(zip(T._fields, values)))
# test _replace
t = T._make(values)
newvalues = tuple(v*10 for v in values)
newt = t._replace(**dict(zip(T._fields, newvalues)))
self.assertEqual(newt, newvalues)
# test _fields
self.assertEqual(T._fields, tuple(words))
# test __getnewargs__
self.assertEqual(t.__getnewargs__(), values)
class ABCTestCase(unittest.TestCase): class ABCTestCase(unittest.TestCase):
def validate_abstract_methods(self, abc, *names): def validate_abstract_methods(self, abc, *names):