mirror of
https://github.com/python/cpython.git
synced 2025-08-31 22:18:28 +00:00
gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)
This commit is contained in:
parent
07aeb7405e
commit
c961d14f85
2 changed files with 42 additions and 10 deletions
26
Lib/enum.py
26
Lib/enum.py
|
@ -247,7 +247,10 @@ class _proto_member:
|
||||||
if not enum_class._use_args_:
|
if not enum_class._use_args_:
|
||||||
enum_member = enum_class._new_member_(enum_class)
|
enum_member = enum_class._new_member_(enum_class)
|
||||||
if not hasattr(enum_member, '_value_'):
|
if not hasattr(enum_member, '_value_'):
|
||||||
enum_member._value_ = value
|
try:
|
||||||
|
enum_member._value_ = enum_class._member_type_(*args)
|
||||||
|
except Exception as exc:
|
||||||
|
enum_member._value_ = value
|
||||||
else:
|
else:
|
||||||
enum_member = enum_class._new_member_(enum_class, *args)
|
enum_member = enum_class._new_member_(enum_class, *args)
|
||||||
if not hasattr(enum_member, '_value_'):
|
if not hasattr(enum_member, '_value_'):
|
||||||
|
@ -562,7 +565,13 @@ class EnumType(type):
|
||||||
classdict['__str__'] = enum_class.__str__
|
classdict['__str__'] = enum_class.__str__
|
||||||
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
|
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
|
||||||
if name not in classdict:
|
if name not in classdict:
|
||||||
setattr(enum_class, name, getattr(first_enum, name))
|
# check for mixin overrides before replacing
|
||||||
|
enum_method = getattr(first_enum, name)
|
||||||
|
found_method = getattr(enum_class, name)
|
||||||
|
object_method = getattr(object, name)
|
||||||
|
data_type_method = getattr(member_type, name)
|
||||||
|
if found_method in (data_type_method, object_method):
|
||||||
|
setattr(enum_class, name, enum_method)
|
||||||
#
|
#
|
||||||
# for Flag, add __or__, __and__, __xor__, and __invert__
|
# for Flag, add __or__, __and__, __xor__, and __invert__
|
||||||
if Flag is not None and issubclass(enum_class, Flag):
|
if Flag is not None and issubclass(enum_class, Flag):
|
||||||
|
@ -937,16 +946,18 @@ class EnumType(type):
|
||||||
@classmethod
|
@classmethod
|
||||||
def _find_data_type_(mcls, class_name, bases):
|
def _find_data_type_(mcls, class_name, bases):
|
||||||
data_types = set()
|
data_types = set()
|
||||||
|
base_chain = set()
|
||||||
for chain in bases:
|
for chain in bases:
|
||||||
candidate = None
|
candidate = None
|
||||||
for base in chain.__mro__:
|
for base in chain.__mro__:
|
||||||
|
base_chain.add(base)
|
||||||
if base is object:
|
if base is object:
|
||||||
continue
|
continue
|
||||||
elif issubclass(base, Enum):
|
elif issubclass(base, Enum):
|
||||||
if base._member_type_ is not object:
|
if base._member_type_ is not object:
|
||||||
data_types.add(base._member_type_)
|
data_types.add(base._member_type_)
|
||||||
break
|
break
|
||||||
elif '__new__' in base.__dict__:
|
elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
|
||||||
if issubclass(base, Enum):
|
if issubclass(base, Enum):
|
||||||
continue
|
continue
|
||||||
data_types.add(candidate or base)
|
data_types.add(candidate or base)
|
||||||
|
@ -1658,7 +1669,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
|
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
|
||||||
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
|
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
|
||||||
if name not in body:
|
if name not in body:
|
||||||
setattr(enum_class, name, getattr(etype, name))
|
# check for mixin overrides before replacing
|
||||||
|
enum_method = getattr(etype, name)
|
||||||
|
found_method = getattr(enum_class, name)
|
||||||
|
object_method = getattr(object, name)
|
||||||
|
data_type_method = getattr(member_type, name)
|
||||||
|
if found_method in (data_type_method, object_method):
|
||||||
|
setattr(enum_class, name, enum_method)
|
||||||
gnv_last_values = []
|
gnv_last_values = []
|
||||||
if issubclass(enum_class, Flag):
|
if issubclass(enum_class, Flag):
|
||||||
# Flag / IntFlag
|
# Flag / IntFlag
|
||||||
|
@ -1989,7 +2006,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
|
||||||
members.sort(key=lambda t: t[0])
|
members.sort(key=lambda t: t[0])
|
||||||
cls = etype(name, members, module=module, boundary=boundary or KEEP)
|
cls = etype(name, members, module=module, boundary=boundary or KEEP)
|
||||||
cls.__reduce_ex__ = _reduce_ex_by_global_name
|
cls.__reduce_ex__ = _reduce_ex_by_global_name
|
||||||
cls.__repr__ = global_enum_repr
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
_stdlib_enums = IntEnum, StrEnum, IntFlag
|
_stdlib_enums = IntEnum, StrEnum, IntFlag
|
||||||
|
|
|
@ -2693,12 +2693,15 @@ class TestSpecial(unittest.TestCase):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Foo:
|
class Foo:
|
||||||
__qualname__ = 'Foo'
|
__qualname__ = 'Foo'
|
||||||
a: int = 0
|
a: int
|
||||||
class Entries(Foo, Enum):
|
class Entries(Foo, Enum):
|
||||||
ENTRY1 = Foo(1)
|
ENTRY1 = 1
|
||||||
|
self.assertTrue(isinstance(Entries.ENTRY1, Foo))
|
||||||
|
self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_)
|
||||||
|
self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value)
|
||||||
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
|
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
|
||||||
|
|
||||||
def test_repr_with_non_data_type_mixin(self):
|
def test_repr_with_init_data_type_mixin(self):
|
||||||
# non-data_type is a mixin that doesn't define __new__
|
# non-data_type is a mixin that doesn't define __new__
|
||||||
class Foo:
|
class Foo:
|
||||||
def __init__(self, a):
|
def __init__(self, a):
|
||||||
|
@ -2706,10 +2709,23 @@ class TestSpecial(unittest.TestCase):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'Foo(a={self.a!r})'
|
return f'Foo(a={self.a!r})'
|
||||||
class Entries(Foo, Enum):
|
class Entries(Foo, Enum):
|
||||||
ENTRY1 = Foo(1)
|
ENTRY1 = 1
|
||||||
|
#
|
||||||
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
|
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
|
||||||
|
|
||||||
|
def test_repr_and_str_with_non_data_type_mixin(self):
|
||||||
|
# non-data_type is a mixin that doesn't define __new__
|
||||||
|
class Foo:
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Foo'
|
||||||
|
def __str__(self):
|
||||||
|
return 'ooF'
|
||||||
|
class Entries(Foo, Enum):
|
||||||
|
ENTRY1 = 1
|
||||||
|
#
|
||||||
|
self.assertEqual(repr(Entries.ENTRY1), 'Foo')
|
||||||
|
self.assertEqual(str(Entries.ENTRY1), 'ooF')
|
||||||
|
|
||||||
def test_value_backup_assign(self):
|
def test_value_backup_assign(self):
|
||||||
# check that enum will add missing values when custom __new__ does not
|
# check that enum will add missing values when custom __new__ does not
|
||||||
class Some(Enum):
|
class Some(Enum):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue