bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)

Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies to the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder method is defined in the class, it will not overwrite it.
This commit is contained in:
Eric V. Smith 2018-01-27 19:07:40 -05:00 committed by GitHub
parent 2a2247ce5e
commit ea8fc52e75
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 679 additions and 295 deletions

View file

@ -9,6 +9,7 @@ import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
from collections import deque, OrderedDict, namedtuple
from functools import total_ordering
# Just any custom exception we can catch.
class CustomError(Exception): pass
@ -82,68 +83,12 @@ class TestCase(unittest.TestCase):
class C(B):
x: int = 0
def test_overwriting_init(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __init__ '
'in C'):
@dataclass
class C:
x: int
def __init__(self, x):
self.x = 2 * x
@dataclass(init=False)
class C:
x: int
def __init__(self, x):
self.x = 2 * x
self.assertEqual(C(5).x, 10)
def test_overwriting_repr(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __repr__ '
'in C'):
@dataclass
class C:
x: int
def __repr__(self):
pass
@dataclass(repr=False)
class C:
x: int
def __repr__(self):
return 'x'
self.assertEqual(repr(C(0)), 'x')
def test_overwriting_cmp(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __eq__ '
'in C'):
# This will generate the comparison functions, make sure we can't
# overwrite them.
@dataclass(hash=False, frozen=False)
class C:
x: int
def __eq__(self):
pass
@dataclass(order=False, eq=False)
class C:
x: int
def __eq__(self, other):
return True
self.assertEqual(C(0), 'x')
def test_overwriting_hash(self):
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __hash__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __hash__(self):
pass
@dataclass(frozen=True)
class C:
x: int
def __hash__(self):
pass
@dataclass(frozen=True,hash=False)
class C:
@ -152,14 +97,11 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __hash__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __hash__(self):
pass
@dataclass(frozen=True)
class C:
x: int
def __hash__(self):
pass
@dataclass(frozen=True, hash=False)
class C:
@ -168,33 +110,6 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
def test_overwriting_frozen(self):
# frozen uses __setattr__ and __delattr__
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __setattr__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __setattr__(self):
pass
with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __delattr__ '
'in C'):
@dataclass(frozen=True)
class C:
x: int
def __delattr__(self):
pass
@dataclass(frozen=False)
class C:
x: int
def __setattr__(self, name, value):
self.__dict__['x'] = value * 2
self.assertEqual(C(10).x, 20)
def test_overwrite_fields_in_derived_class(self):
# Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base.
@ -239,34 +154,6 @@ class TestCase(unittest.TestCase):
first = next(iter(sig.parameters))
self.assertEqual('self', first)
def test_repr(self):
@dataclass
class B:
x: int
@dataclass
class C(B):
y: int = 10
o = C(4)
self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
@dataclass
class D(C):
x: int = 20
self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
@dataclass
class C:
@dataclass
class D:
i: int
@dataclass
class E:
pass
self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
def test_0_field_compare(self):
# Ensure that order=False is the default.
@dataclass
@ -420,80 +307,8 @@ class TestCase(unittest.TestCase):
self.assertEqual(hash(C(4)), hash((4,)))
self.assertEqual(hash(C(42)), hash((42,)))
def test_hash(self):
@dataclass(hash=True)
class C:
x: int
y: str
self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
def test_no_hash(self):
@dataclass(hash=None)
class C:
x: int
with self.assertRaisesRegex(TypeError,
"unhashable type: 'C'"):
hash(C(1))
def test_hash_rules(self):
# There are 24 cases of:
# hash=True/False/None
# eq=True/False
# order=True/False
# frozen=True/False
for (hash, eq, order, frozen, result ) in [
(False, False, False, False, 'absent'),
(False, False, False, True, 'absent'),
(False, False, True, False, 'exception'),
(False, False, True, True, 'exception'),
(False, True, False, False, 'absent'),
(False, True, False, True, 'absent'),
(False, True, True, False, 'absent'),
(False, True, True, True, 'absent'),
(True, False, False, False, 'fn'),
(True, False, False, True, 'fn'),
(True, False, True, False, 'exception'),
(True, False, True, True, 'exception'),
(True, True, False, False, 'fn'),
(True, True, False, True, 'fn'),
(True, True, True, False, 'fn'),
(True, True, True, True, 'fn'),
(None, False, False, False, 'absent'),
(None, False, False, True, 'absent'),
(None, False, True, False, 'exception'),
(None, False, True, True, 'exception'),
(None, True, False, False, 'none'),
(None, True, False, True, 'fn'),
(None, True, True, False, 'none'),
(None, True, True, True, 'fn'),
]:
with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
if result == 'exception':
with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
class C:
pass
else:
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
class C:
pass
# See if the result matches what's expected.
if result == 'fn':
# __hash__ contains the function we generated.
self.assertIn('__hash__', C.__dict__)
self.assertIsNotNone(C.__dict__['__hash__'])
elif result == 'absent':
# __hash__ is not present in our class.
self.assertNotIn('__hash__', C.__dict__)
elif result == 'none':
# __hash__ is set to None.
self.assertIn('__hash__', C.__dict__)
self.assertIsNone(C.__dict__['__hash__'])
else:
assert False, f'unknown result {result!r}'
def test_eq_order(self):
# Test combining eq and order.
for (eq, order, result ) in [
(False, False, 'neither'),
(False, True, 'exception'),
@ -513,21 +328,18 @@ class TestCase(unittest.TestCase):
if result == 'neither':
self.assertNotIn('__eq__', C.__dict__)
self.assertNotIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__)
elif result == 'both':
self.assertIn('__eq__', C.__dict__)
self.assertIn('__ne__', C.__dict__)
self.assertIn('__lt__', C.__dict__)
self.assertIn('__le__', C.__dict__)
self.assertIn('__gt__', C.__dict__)
self.assertIn('__ge__', C.__dict__)
elif result == 'eq_only':
self.assertIn('__eq__', C.__dict__)
self.assertIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
@ -811,19 +623,6 @@ class TestCase(unittest.TestCase):
y: int
self.assertNotEqual(Point(1, 3), C(1, 3))
def test_base_has_init(self):
class B:
def __init__(self):
pass
# Make sure that declaring this class doesn't raise an error.
# The issue is that we can't override __init__ in our class,
# but it should be okay to add __init__ to us if our base has
# an __init__.
@dataclass
class C(B):
x: int = 0
def test_frozen(self):
@dataclass(frozen=True)
class C:
@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase):
'y': int,
'z': 'typing.Any'})
class TestDocString(unittest.TestCase):
def assertDocStrEqual(self, a, b):
# Because 3.6 and 3.7 differ in how inspect.signature work
@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
class TestInit(unittest.TestCase):
    # Which __init__ a dataclass ends up with: the generated one, one
    # written in the class body, or one inherited from a base class.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        # An __init__ that only exists on the base class does not stop
        # an __init__ from being added to the derived class.  The added
        # one takes the field as an argument and B.__init__ is not
        # run, so 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # With init=False and no __init__ in the class, no arguments
        # are accepted and the field is read from its class-level
        # default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # With init=False and an __init__ in the class, that __init__
        # is the one that runs.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.
        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
class TestRepr(unittest.TestCase):
    # Which __repr__ a dataclass ends up with: the generated one, one
    # written in the class body, or the default object repr.

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redefining x in a subclass keeps it in its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested classes show their full qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
        # The default object repr starts with the defining module's
        # name.  Use __name__ rather than a hard-coded module name so
        # the check also holds when this file is run as a script, in
        # which case the module is '__main__'.
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.
        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
class TestFrozen(unittest.TestCase):
    # How frozen= interacts with attribute hooks written in the class.

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__, so asking for
        # frozen=True on a class that already defines either of them
        # has to be rejected.
        setattr_clash = 'Cannot overwrite attribute __setattr__'
        with self.assertRaisesRegex(TypeError, setattr_clash):
            @dataclass(frozen=True)
            class HasSetattr:
                x: int

                def __setattr__(self):
                    pass

        delattr_clash = 'Cannot overwrite attribute __delattr__'
        with self.assertRaisesRegex(TypeError, delattr_clash):
            @dataclass(frozen=True)
            class HasDelattr:
                x: int

                def __delattr__(self):
                    pass

        # With frozen=False the class's own __setattr__ is kept, and
        # constructing an instance routes the field assignment
        # through it.
        @dataclass(frozen=False)
        class Doubler:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        instance = Doubler(10)
        self.assertEqual(instance.x, 20)
class TestEq(unittest.TestCase):
    # Which __eq__ a dataclass ends up with for the different eq=
    # settings, with and without one written in the class body.

    def test_no_eq(self):
        # eq=False and no __eq__ in the class: two separate instances
        # are unequal even with the same field value, while an
        # instance still equals itself.
        @dataclass(eq=False)
        class Plain:
            x: int

        self.assertNotEqual(Plain(0), Plain(0))
        only = Plain(3)
        self.assertEqual(only, only)

        # eq=False and an __eq__ in the class: that method is used.
        @dataclass(eq=False)
        class Custom:
            x: int

            def __eq__(self, other):
                return other == 10

        self.assertEqual(Custom(3), 10)

    def test_overwriting_eq(self):
        # If the class has __eq__, use it no matter the value of eq=
        # (left at its default, True, or False).  Every variant gets
        # its own magic number so the assertions show which __eq__
        # actually answered.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, magic in variants:
            @decorate
            class Candidate:
                x: int

                def __eq__(self, other):
                    return other == magic

            self.assertEqual(Candidate(1), magic)
            self.assertNotEqual(Candidate(1), 1)
class TestOrdering(unittest.TestCase):
    # Ordering methods: when they are generated, when they are left
    # out, and when a clash with one written in the class is an error.

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)

        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that a __lt__ written in the class is kept and is
        # still the one that gets called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        # A field-by-field comparison would say 0 < 1 is true; the
        # hand-written method always answers False.
        self.assertFalse(C(0) < C(1))
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # order=True together with any ordering method written in the
        # class is an error, and the message points the author at
        # functools.total_ordering.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
class TestHash(unittest.TestCase):
    # Tests for the hash= argument of the dataclass decorator and how
    # it combines with eq=, frozen= and a __hash__ or __eq__ written
    # in the class body.

    def test_hash(self):
        # hash=True: an instance hashes like the tuple of its field
        # values.
        @dataclass(hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_false(self):
        # hash=False: an instance must not hash like the tuple of its
        # field values.
        @dataclass(hash=False)
        class C:
            x: int
            y: str
        self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_none(self):
        # hash=None with no other arguments: calling hash() on an
        # instance raises TypeError.
        @dataclass(hash=None)
        class C:
            x: int
        with self.assertRaisesRegex(TypeError,
                                    "unhashable type: 'C'"):
            hash(C(1))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            # None stays None; other false values map to 0.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        # Build a class with the given decorator arguments and check
        # its __hash__ against the expected outcome code:
        #   'fn' / 'fn-x'  __hash__ is in the class dict and not None
        #   ''             __hash__ is not in the class dict
        #   'none'         __hash__ is in the class dict and is None
        # 'case' only labels the subTest.  'with_hash' selects whether
        # the class body defines its own __hash__.
        # NOTE: the 'hash' parameter shadows the builtin inside test().
        def test(case, hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
                if with_hash:
                    @dataclass(hash=hash, eq=eq, frozen=frozen)
                    class C:
                        def __hash__(self):
                            return 0
                else:
                    @dataclass(hash=hash, eq=eq, frozen=frozen)
                    class C:
                        pass

                # See if the result matches what's expected.
                if result in ('fn', 'fn-x'):
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                    if result == 'fn-x':
                        # This is the "auto-hash test" case.  We
                        # should overwrite __hash__ iff there's an
                        # __eq__ and if __hash__=None.

                        # There are two ways of getting __hash__=None:
                        # explicitly, and by defining __eq__. If
                        # __eq__ is defined, python will add __hash__
                        # when the class is created.
                        @dataclass(hash=hash, eq=eq, frozen=frozen)
                        class C:
                            def __eq__(self, other): pass
                            __hash__ = None

                        # Hash should be overwritten (non-None).
                        self.assertIsNotNone(C.__dict__['__hash__'])

                        # Same test as above, but we don't provide
                        # __hash__, it will implicitly be set to None.
                        @dataclass(hash=hash, eq=eq, frozen=frozen)
                        class C:
                            def __eq__(self, other): pass

                        # Hash should be overwritten (non-None).
                        self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    # When the class body defines __hash__ itself
                    # (with_hash), nothing is checked for this outcome.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)
                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                else:
                    assert False, f'unknown result {result!r}'

        # There are 12 cases of:
        #  hash=True/False/None
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        # Table columns: hash, eq, frozen, expected outcome when the
        # class has no __hash__, expected outcome when it defines one.
        # Cases are numbered from 1.
        for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
                  (None, False, False, '', ''),
                  (None, False, True, '', ''),
                  (None, True, False, 'none', ''),
                  (None, True, True, 'fn', 'fn-x'),
                  (False, False, False, '', ''),
                  (False, False, True, '', ''),
                  (False, True, False, '', ''),
                  (False, True, True, '', ''),
                  (True, False, False, 'fn', 'fn-x'),
                  (True, False, True, 'fn', 'fn-x'),
                  (True, True, False, 'fn', 'fn-x'),
                  (True, True, True, 'fn', 'fn-x'),
                  ], 1):
            test(case, hash, eq, frozen, False, result_no)
            test(case, hash, eq, frozen, True, result_yes)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
            test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # hash=True.
        @dataclass(hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()