bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)

Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same applies to the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder method is defined in the class, it will not overwrite it.
This commit is contained in:
Eric V. Smith 2018-01-27 19:07:40 -05:00 committed by GitHub
parent 2a2247ce5e
commit ea8fc52e75
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 679 additions and 295 deletions

View file

@ -18,6 +18,142 @@ __all__ = ['dataclass',
'is_dataclass',
]
# Conditions for adding methods. The boxes indicate what action the
# dataclass decorator takes. For all of these tables, when I talk
# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
# to the arguments to the @dataclass decorator. When checking if a
# dunder method already exists, I mean check for an entry in the
# class's __dict__. I never check to see if an attribute is defined
# in a base class.
# Key:
# +=========+=========================================+
# + Value | Meaning |
# +=========+=========================================+
# | <blank> | No action: no method is added. |
# +---------+-----------------------------------------+
# | add | Generated method is added. |
# +---------+-----------------------------------------+
# | add* | Generated method is added only if the |
# | | existing attribute is None and if the |
# | | user supplied a __eq__ method in the |
# | | class definition. |
# +---------+-----------------------------------------+
# | raise | TypeError is raised. |
# +---------+-----------------------------------------+
# | None | Attribute is set to None. |
# +=========+=========================================+
# __init__
#
# +--- init= parameter
# |
# v | | |
# | no | yes | <--- class has __init__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __repr__
#
# +--- repr= parameter
# |
# v | | |
# | no | yes | <--- class has __repr__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __setattr__
# __delattr__
#
# +--- frozen= parameter
# |
# v | | |
# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
# +=======+=======+=======+
# | False | | | <- the default
# +-------+-------+-------+
# | True | add | raise |
# +=======+=======+=======+
# Raise because not adding these methods would break the "frozen-ness"
# of the class.
# __eq__
#
# +--- eq= parameter
# |
# v | | |
# | no | yes | <--- class has __eq__ in __dict__?
# +=======+=======+=======+
# | False | | |
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
# __lt__
# __le__
# __gt__
# __ge__
#
# +--- order= parameter
# |
# v | | |
# | no | yes | <--- class has any comparison method in __dict__?
# +=======+=======+=======+
# | False | | | <- the default
# +-------+-------+-------+
# | True | add | raise |
# +=======+=======+=======+
# Raise because to allow this case would interfere with using
# functools.total_ordering.
# __hash__
# +------------------- hash= parameter
# | +----------- eq= parameter
# | | +--- frozen= parameter
# | | |
# v v v | | |
# | no | yes | <--- class has __hash__ in __dict__?
# +=========+=======+=======+========+========+
# | 1 None | False | False | | | No __eq__, use the base class __hash__
# +---------+-------+-------+--------+--------+
# | 2 None | False | True | | | No __eq__, use the base class __hash__
# +---------+-------+-------+--------+--------+
# | 3 None | True | False | None | | <-- the default, not hashable
# +---------+-------+-------+--------+--------+
# | 4 None | True | True | add | add* | Frozen, so hashable
# +---------+-------+-------+--------+--------+
# | 5 False | False | False | | |
# +---------+-------+-------+--------+--------+
# | 6 False | False | True | | |
# +---------+-------+-------+--------+--------+
# | 7 False | True | False | | |
# +---------+-------+-------+--------+--------+
# | 8 False | True | True | | |
# +---------+-------+-------+--------+--------+
# | 9 True | False | False | add | add* | Has no __eq__, but hashable
# +---------+-------+-------+--------+--------+
# |10 True | False | True | add | add* | Has no __eq__, but hashable
# +---------+-------+-------+--------+--------+
# |11 True | True | False | add | add* | Not frozen, but hashable
# +---------+-------+-------+--------+--------+
# |12 True | True | True | add | add* | Frozen, so hashable
# +=========+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class. If the base is object, then
# id-based hashing is used.
# Note that a class may already have __hash__=None if it specified an
# __eq__ method in the class body (not one that was created by
# @dataclass).
class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""
@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
# return "(self.x,self.y)".
# Special case for the 0-tuple.
if len(fields) == 0:
if not fields:
return '()'
# Note the trailing comma, needed if this turns out to be a 1-tuple.
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
def _create_fn(name, args, body, globals=None, locals=None,
def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING):
# Note that we mutate locals when exec() is called. Caller beware!
if locals is None:
@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
# If no body lines, use 'pass'.
if len(body_lines) == 0:
if not body_lines:
body_lines = ['pass']
locals = {f'_type_{f.name}': f.type for f in fields}
@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
'return NotImplemented'])
def _set_eq_fns(cls, fields):
# Create and set the equality comparison methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple = _tuple_str('self', fields)
other_tuple = _tuple_str('other', fields)
for name, op in [('__eq__', '=='),
('__ne__', '!='),
]:
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
def _set_order_fns(cls, fields):
# Create and set the ordering methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple = _tuple_str('self', fields)
other_tuple = _tuple_str('other', fields)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
def _hash_fn(fields):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
@ -431,20 +541,20 @@ def _find_fields(cls):
# a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that
# they're not real fields. That's deal with later.
# they're not real fields. That's dealt with later.
annotations = getattr(cls, '__annotations__', {})
return [_get_field(cls, a_name, a_type)
for a_name, a_type in annotations.items()]
def _set_attribute(cls, name, value):
# Raise TypeError if an attribute by this name already exists.
def _set_new_attribute(cls, name, value):
# Never overwrites an existing attribute. Returns True if the
# attribute already exists.
if name in cls.__dict__:
raise TypeError(f'Cannot overwrite attribute {name} '
f'in {cls.__name__}')
return True
setattr(cls, name, value)
return False
def _process_class(cls, repr, eq, order, hash, init, frozen):
@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
# be inherited down.
is_frozen = frozen or cls.__setattr__ is _frozen_setattr
# Was this class defined with an __eq__ and a __hash__ of None in its
# own __dict__?  Used in the __hash__ logic below.
# The lookup must go through cls.__dict__.get(): getattr() on
# cls.__dict__ would fetch the mapping object's own __hash__ attribute
# rather than the class's '__hash__' entry, so the test could never
# be true.
auto_hash_test = ('__eq__' in cls.__dict__ and
                  cls.__dict__.get('__hash__', MISSING) is None)
# If we're generating ordering methods, we must be generating
# the eq methods.
if order and not eq:
@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
has_post_init = hasattr(cls, _POST_INIT_NAME)
# Include InitVars and regular fields (so, not ClassVars).
_set_attribute(cls, '__init__',
_init_fn(list(filter(lambda f: f._field_type
in (_FIELD, _FIELD_INITVAR),
fields.values())),
is_frozen,
has_post_init,
# The name to use for the "self" param
# in __init__. Use "self" if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
))
flds = [f for f in fields.values()
if f._field_type in (_FIELD, _FIELD_INITVAR)]
_set_new_attribute(cls, '__init__',
_init_fn(flds,
is_frozen,
has_post_init,
# The name to use for the "self" param
# in __init__. Use "self" if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
))
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
field_list = list(filter(lambda f: f._field_type is _FIELD,
fields.values()))
field_list = [f for f in fields.values() if f._field_type is _FIELD]
if repr:
_set_attribute(cls, '__repr__',
_repr_fn(list(filter(lambda f: f.repr, field_list))))
if is_frozen:
_set_attribute(cls, '__setattr__', _frozen_setattr)
_set_attribute(cls, '__delattr__', _frozen_delattr)
generate_hash = False
if hash is None:
if eq and frozen:
# Generate a hash function.
generate_hash = True
elif eq and not frozen:
# Not hashable.
_set_attribute(cls, '__hash__', None)
elif not eq:
# Otherwise, use the base class definition of hash(). That is,
# don't set anything on this class.
pass
else:
assert "can't get here"
else:
generate_hash = hash
if generate_hash:
_set_attribute(cls, '__hash__',
_hash_fn(list(filter(lambda f: f.compare
if f.hash is None
else f.hash,
field_list))))
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
if eq:
# Create and __eq__ and __ne__ methods.
_set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
# Create __eq__ method. There's no need for a __ne__ method,
# since Python will call __eq__ and negate it.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
if order:
# Create and __lt__, __le__, __gt__, and __ge__ methods.
# Create and set the comparison functions.
_set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
# Create and set the ordering methods.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in {cls.__name__}. Consider using '
'functools.total_ordering')
if is_frozen:
for name, fn in [('__setattr__', _frozen_setattr),
('__delattr__', _frozen_delattr)]:
if _set_new_attribute(cls, name, fn):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in {cls.__name__}')
# Decide if/how we're going to create a hash function.
# TODO: Move this table to module scope, so it's not recreated
# all the time.
generate_hash = {(None, False, False): ('', ''),
(None, False, True): ('', ''),
(None, True, False): ('none', ''),
(None, True, True): ('fn', 'fn-x'),
(False, False, False): ('', ''),
(False, False, True): ('', ''),
(False, True, False): ('', ''),
(False, True, True): ('', ''),
(True, False, False): ('fn', 'fn-x'),
(True, False, True): ('fn', 'fn-x'),
(True, True, False): ('fn', 'fn-x'),
(True, True, True): ('fn', 'fn-x'),
}[None if hash is None else bool(hash), # Force bool() if not None.
bool(eq),
bool(frozen)]['__hash__' in cls.__dict__]
# No need to call _set_new_attribute here, since we already know if
# we're overwriting a __hash__ or not.
if generate_hash == '':
# Do nothing.
pass
elif generate_hash == 'none':
cls.__hash__ = None
elif generate_hash in ('fn', 'fn-x'):
if generate_hash == 'fn' or auto_hash_test:
flds = [f for f in field_list
if (f.compare if f.hash is None else f.hash)]
cls.__hash__ = _hash_fn(flds)
else:
assert False, f"can't get here: {generate_hash}"
if not getattr(cls, '__doc__'):
# Create a class doc-string.