From 9d1fff1fcd5332f0ba7f72d0e0f9f66b47ec4e8d Mon Sep 17 00:00:00 2001 From: Ethan Furman Date: Mon, 14 Dec 2020 18:41:34 -0800 Subject: [PATCH] [3.9] bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714) (GH-23772) When creating an Enum, `type.__new__` calls `__init_subclass__`, but at that point the members have not been added. This patch suppresses the initial call, then manually calls the ancestor `__init_subclass__` before returning the new Enum class. (cherry picked from commit 6bd94de168b58ac9358277ed6f200490ab26c174) --- Lib/enum.py | 32 +++++++- Lib/test/test_enum.py | 73 ++++++++++++++++++- .../2020-12-08-22-43-35.bpo-42678.ba9ktU.rst | 1 + 3 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst diff --git a/Lib/enum.py b/Lib/enum.py index 46b5435b7c8..88c951f4f12 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -9,6 +9,14 @@ __all__ = [ ] +class _NoInitSubclass: + """ + temporary base class to suppress calling __init_subclass__ + """ + @classmethod + def __init_subclass__(cls, **kwds): + pass + def _is_descriptor(obj): """ Returns True if obj is a descriptor, False otherwise. @@ -176,7 +184,7 @@ class EnumMeta(type): ) return enum_dict - def __new__(metacls, cls, bases, classdict): + def __new__(metacls, cls, bases, classdict, **kwds): # an Enum class is final once enumeration items have been defined; it # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting @@ -211,8 +219,22 @@ class EnumMeta(type): if '__doc__' not in classdict: classdict['__doc__'] = 'An enumeration.' + # postpone calling __init_subclass__ + if '__init_subclass__' in classdict and classdict['__init_subclass__'] is None: + raise TypeError('%s.__init_subclass__ cannot be None') + # remove current __init_subclass__ so previous one can be found with getattr + new_init_subclass = classdict.pop('__init_subclass__', None) # create our new Enum type - enum_class = super().__new__(metacls, cls, bases, classdict) + if bases: + bases = (_NoInitSubclass, ) + bases + enum_class = type.__new__(metacls, cls, bases, classdict) + enum_class.__bases__ = enum_class.__bases__[1:] #or (object, ) + else: + enum_class = type.__new__(metacls, cls, bases, classdict) + old_init_subclass = getattr(enum_class, '__init_subclass__', None) + # and restore the new one (if there was one) + if new_init_subclass is not None: + enum_class.__init_subclass__ = classmethod(new_init_subclass) enum_class._member_names_ = [] # names in definition order enum_class._member_map_ = {} # name->value map enum_class._member_type_ = member_type @@ -324,6 +346,9 @@ class EnumMeta(type): if _order_ != enum_class._member_names_: raise TypeError('member order does not match _order_') + # finally, call parents' __init_subclass__ + if Enum is not None and old_init_subclass is not None: + old_init_subclass(**kwds) return enum_class def __bool__(self): @@ -701,6 +726,9 @@ class Enum(metaclass=EnumMeta): else: return start + def __init_subclass__(cls, **kwds): + super().__init_subclass__(**kwds) + @classmethod def _missing_(cls, value): return None diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index f7f8522bda4..0868c303472 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2049,7 +2049,6 @@ class TestEnum(unittest.TestCase): local_ls = {} exec(code, global_ns, local_ls) - @unittest.skipUnless( sys.version_info[:2] == (3, 9), 'private variables are now normal attributes', @@ -2066,6 +2065,42 @@ class TestEnum(unittest.TestCase): except ValueError: pass + def test_init_subclass(self): + class MyEnum(Enum): + def __init_subclass__(cls, **kwds): + super(MyEnum, cls).__init_subclass__(**kwds) + self.assertFalse(cls.__dict__.get('_test', False)) + cls._test1 = 'MyEnum' + # + class TheirEnum(MyEnum): + def __init_subclass__(cls, **kwds): + super().__init_subclass__(**kwds) + cls._test2 = 'TheirEnum' + class WhoseEnum(TheirEnum): + def __init_subclass__(cls, **kwds): + pass + class NoEnum(WhoseEnum): + ONE = 1 + self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum') + self.assertFalse(NoEnum.__dict__.get('_test1', False)) + self.assertFalse(NoEnum.__dict__.get('_test2', False)) + # + class OurEnum(MyEnum): + def __init_subclass__(cls, **kwds): + cls._test2 = 'OurEnum' + class WhereEnum(OurEnum): + def __init_subclass__(cls, **kwds): + pass + class NeverEnum(WhereEnum): + ONE = 'one' + self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum') + self.assertFalse(WhereEnum.__dict__.get('_test1', False)) + self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum') + self.assertFalse(NeverEnum.__dict__.get('_test1', False)) + self.assertFalse(NeverEnum.__dict__.get('_test2', False)) + class TestOrder(unittest.TestCase): @@ -2516,6 +2551,42 @@ class TestFlag(unittest.TestCase): 'at least one thread failed while creating composite members') self.assertEqual(256, len(seen), 'too many composite members created') + def test_init_subclass(self): + class MyEnum(Flag): + def __init_subclass__(cls, **kwds): + super().__init_subclass__(**kwds) + self.assertFalse(cls.__dict__.get('_test', False)) + cls._test1 = 'MyEnum' + # + class TheirEnum(MyEnum): + def __init_subclass__(cls, **kwds): + super(TheirEnum, cls).__init_subclass__(**kwds) + cls._test2 = 'TheirEnum' + class WhoseEnum(TheirEnum): + def __init_subclass__(cls, **kwds): + pass + class NoEnum(WhoseEnum): + ONE = 1 + self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum') + self.assertFalse(NoEnum.__dict__.get('_test1', False)) + self.assertFalse(NoEnum.__dict__.get('_test2', False)) + # + class OurEnum(MyEnum): + def __init_subclass__(cls, **kwds): + cls._test2 = 'OurEnum' + class WhereEnum(OurEnum): + def __init_subclass__(cls, **kwds): + pass + class NeverEnum(WhereEnum): + ONE = 1 + self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum') + self.assertFalse(WhereEnum.__dict__.get('_test1', False)) + self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum') + self.assertFalse(NeverEnum.__dict__.get('_test1', False)) + self.assertFalse(NeverEnum.__dict__.get('_test2', False)) + class TestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" diff --git a/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst b/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst new file mode 100644 index 00000000000..7c94cdf40dd --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst @@ -0,0 +1 @@ +`Enum`: call `__init_subclass__` after members have been added