gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)

This commit is contained in:
Ethan Furman 2022-07-17 18:51:04 -07:00 committed by GitHub
parent 07aeb7405e
commit c961d14f85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 10 deletions

View file

@ -247,7 +247,10 @@ class _proto_member:
if not enum_class._use_args_:
enum_member = enum_class._new_member_(enum_class)
if not hasattr(enum_member, '_value_'):
enum_member._value_ = value
try:
enum_member._value_ = enum_class._member_type_(*args)
except Exception as exc:
enum_member._value_ = value
else:
enum_member = enum_class._new_member_(enum_class, *args)
if not hasattr(enum_member, '_value_'):
@ -562,7 +565,13 @@ class EnumType(type):
classdict['__str__'] = enum_class.__str__
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
if name not in classdict:
setattr(enum_class, name, getattr(first_enum, name))
# check for mixin overrides before replacing
enum_method = getattr(first_enum, name)
found_method = getattr(enum_class, name)
object_method = getattr(object, name)
data_type_method = getattr(member_type, name)
if found_method in (data_type_method, object_method):
setattr(enum_class, name, enum_method)
#
# for Flag, add __or__, __and__, __xor__, and __invert__
if Flag is not None and issubclass(enum_class, Flag):
@ -937,16 +946,18 @@ class EnumType(type):
@classmethod
def _find_data_type_(mcls, class_name, bases):
data_types = set()
base_chain = set()
for chain in bases:
candidate = None
for base in chain.__mro__:
base_chain.add(base)
if base is object:
continue
elif issubclass(base, Enum):
if base._member_type_ is not object:
data_types.add(base._member_type_)
break
elif '__new__' in base.__dict__:
elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
if issubclass(base, Enum):
continue
data_types.add(candidate or base)
@ -1658,7 +1669,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
if name not in body:
setattr(enum_class, name, getattr(etype, name))
# check for mixin overrides before replacing
enum_method = getattr(etype, name)
found_method = getattr(enum_class, name)
object_method = getattr(object, name)
data_type_method = getattr(member_type, name)
if found_method in (data_type_method, object_method):
setattr(enum_class, name, enum_method)
gnv_last_values = []
if issubclass(enum_class, Flag):
# Flag / IntFlag
@ -1989,7 +2006,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
members.sort(key=lambda t: t[0])
cls = etype(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_global_name
cls.__repr__ = global_enum_repr
return cls
_stdlib_enums = IntEnum, StrEnum, IntFlag