bpo-38659: [Enum] add _simple_enum decorator (GH-25285)

add:

_simple_enum decorator to transform a normal class into an enum
_test_simple_enum function to compare
_old_convert_ to enable checking _convert_ generated enums
_simple_enum takes a normal class and converts it into an enum:

@simple_enum(Enum)
class Color:
    RED = 1
    GREEN = 2
    BLUE = 3

_old_convert_ works much like _convert_ does, using the original logic:

# in a test file
import socket, enum
CheckedAddressFamily = enum._old_convert_(
        enum.IntEnum, 'AddressFamily', 'socket',
        lambda C: C.isupper() and C.startswith('AF_'),
        source=_socket,
        )

test_simple_enum takes a traditional enum and a simple enum and
compares the two:

# in the REPL or the same module as Color
class CheckedColor(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

_test_simple_enum(CheckedColor, Color)

_test_simple_enum(CheckedAddressFamily, socket.AddressFamily)

Any important differences will raise a TypeError
This commit is contained in:
Ethan Furman 2021-04-19 18:04:53 -07:00 committed by GitHub
parent 7a04116246
commit dbac8f40e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 871 additions and 34 deletions

View file

@ -391,13 +391,15 @@ class EnumType(type):
)
return enum_dict
def __new__(metacls, cls, bases, classdict, boundary=None, **kwds):
def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **kwds):
# an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting
# class will fail).
#
# remove any keys listed in _ignore_
if _simple:
return super().__new__(metacls, cls, bases, classdict, **kwds)
classdict.setdefault('_ignore_', []).append('_ignore_')
ignore = classdict['_ignore_']
for key in ignore:
@ -695,7 +697,7 @@ class EnumType(type):
"""
member_map = cls.__dict__.get('_member_map_', {})
if name in member_map:
raise AttributeError('Cannot reassign members.')
raise AttributeError('Cannot reassign member %r.' % (name, ))
super().__setattr__(name, value)
def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1, boundary=None):
@ -750,7 +752,8 @@ class EnumType(type):
return metacls.__new__(metacls, class_name, bases, classdict, boundary=boundary)
def _convert_(cls, name, module, filter, source=None, boundary=None):
def _convert_(cls, name, module, filter, source=None, *, boundary=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
@ -777,7 +780,10 @@ class EnumType(type):
except TypeError:
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = cls(name, members, module=module, boundary=boundary or KEEP)
body = {t[0]: t[1] for t in members}
body['__module__'] = module
tmp_cls = type(name, (object, ), body)
cls = _simple_enum(etype=cls, boundary=boundary or KEEP)(tmp_cls)
cls.__reduce_ex__ = _reduce_ex_by_name
global_enum(cls)
module_globals[name] = cls
@ -855,7 +861,7 @@ class EnumType(type):
__new__ = classdict.get('__new__', None)
# should __new__ be saved as __new_member__ later?
save_new = __new__ is not None
save_new = first_enum is not None and __new__ is not None
if __new__ is None:
# check all possibles for __new_member__ before falling back to
@ -879,7 +885,7 @@ class EnumType(type):
# if a non-object.__new__ is used then whatever value/tuple was
# assigned to the enum member name will be passed to __new__ and to the
# new enum member's __init__
if __new__ is object.__new__:
if first_enum is None or __new__ in (Enum.__new__, object.__new__):
use_args = False
else:
use_args = True
@ -1189,7 +1195,7 @@ class Flag(Enum, boundary=STRICT):
pseudo_member = object.__new__(cls)
else:
pseudo_member = (__new__ or cls._member_type_.__new__)(cls, value)
if not hasattr(pseudo_member, 'value'):
if not hasattr(pseudo_member, '_value_'):
pseudo_member._value_ = value
if member_value:
pseudo_member._name_ = '|'.join([
@ -1383,3 +1389,309 @@ def global_enum(cls):
cls.__repr__ = global_enum_repr
sys.modules[cls.__module__].__dict__.update(cls.__members__)
return cls
def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
"""
Class decorator that converts a normal class into an :class:`Enum`. No
safety checks are done, and some advanced behavior (such as
:func:`__init_subclass__`) is not available. Enum creation can be faster
using :func:`simple_enum`.
>>> from enum import Enum, _simple_enum
>>> @_simple_enum(Enum)
... class Color:
... RED = auto()
... GREEN = auto()
... BLUE = auto()
>>> Color
<enum 'Color'>
"""
def convert_class(cls):
nonlocal use_args
cls_name = cls.__name__
if use_args is None:
use_args = etype._use_args_
__new__ = cls.__dict__.get('__new__')
if __new__ is not None:
new_member = __new__.__func__
else:
new_member = etype._member_type_.__new__
attrs = {}
body = {}
if __new__ is not None:
body['__new_member__'] = new_member
body['_new_member_'] = new_member
body['_use_args_'] = use_args
body['_generate_next_value_'] = gnv = etype._generate_next_value_
body['_member_names_'] = member_names = []
body['_member_map_'] = member_map = {}
body['_value2member_map_'] = value2member_map = {}
body['_member_type_'] = member_type = etype._member_type_
if issubclass(etype, Flag):
body['_boundary_'] = boundary or etype._boundary_
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_inverted_'] = None
for name, obj in cls.__dict__.items():
if name in ('__dict__', '__weakref__'):
continue
if _is_dunder(name) or _is_private(cls_name, name) or _is_sunder(name) or _is_descriptor(obj):
body[name] = obj
else:
attrs[name] = obj
if cls.__dict__.get('__doc__') is None:
body['__doc__'] = 'An enumeration.'
#
# double check that repr and friends are not the mixin's or various
# things break (such as pickle)
# however, if the method is defined in the Enum itself, don't replace
# it
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
if name in body:
continue
class_method = getattr(enum_class, name)
obj_method = getattr(member_type, name, None)
enum_method = getattr(etype, name, None)
if obj_method is not None and obj_method is class_method:
setattr(enum_class, name, enum_method)
gnv_last_values = []
if issubclass(enum_class, Flag):
# Flag / IntFlag
single_bits = multi_bits = 0
for name, value in attrs.items():
if isinstance(value, auto) and auto.value is _auto_null:
value = gnv(name, 1, len(member_names), gnv_last_values)
if value in value2member_map:
# an alias to an existing member
redirect = property()
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = value2member_map[value]
else:
# create the member
if use_args:
if not isinstance(value, tuple):
value = (value, )
member = new_member(enum_class, *value)
value = value[0]
else:
member = new_member(enum_class)
if __new__ is None:
member._value_ = value
member._name_ = name
member.__objclass__ = enum_class
member.__init__(value)
redirect = property()
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = member
member._sort_order_ = len(member_names)
value2member_map[value] = member
if _is_single_bit(value):
# not a multi-bit alias, record in _member_names_ and _flag_mask_
member_names.append(name)
single_bits |= value
else:
multi_bits |= value
gnv_last_values.append(value)
enum_class._flag_mask_ = single_bits
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
# set correct __iter__
member_list = [m._value_ for m in enum_class]
if member_list != sorted(member_list):
enum_class._iter_member_ = enum_class._iter_member_by_def_
else:
# Enum / IntEnum / StrEnum
for name, value in attrs.items():
if isinstance(value, auto):
if value.value is _auto_null:
value.value = gnv(name, 1, len(member_names), gnv_last_values)
value = value.value
if value in value2member_map:
# an alias to an existing member
redirect = property()
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = value2member_map[value]
else:
# create the member
if use_args:
if not isinstance(value, tuple):
value = (value, )
member = new_member(enum_class, *value)
value = value[0]
else:
member = new_member(enum_class)
if __new__ is None:
member._value_ = value
member._name_ = name
member.__objclass__ = enum_class
member.__init__(value)
member._sort_order_ = len(member_names)
redirect = property()
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = member
value2member_map[value] = member
member_names.append(name)
gnv_last_values.append(value)
if '__new__' in body:
enum_class.__new_member__ = enum_class.__new__
enum_class.__new__ = Enum.__new__
return enum_class
return convert_class
def _test_simple_enum(checked_enum, simple_enum):
"""
A function that can be used to test an enum created with :func:`_simple_enum`
against the version created by subclassing :class:`Enum`::
>>> from enum import Enum, _simple_enum, _test_simple_enum
>>> @_simple_enum(Enum)
... class Color:
... RED = auto()
... GREEN = auto()
... BLUE = auto()
>>> class CheckedColor(Enum):
... RED = auto()
... GREEN = auto()
... BLUE = auto()
>>> _test_simple_enum(CheckedColor, Color)
If differences are found, a :exc:`TypeError` is raised.
"""
failed = []
if checked_enum.__dict__ != simple_enum.__dict__:
checked_dict = checked_enum.__dict__
checked_keys = list(checked_dict.keys())
simple_dict = simple_enum.__dict__
simple_keys = list(simple_dict.keys())
member_names = set(
list(checked_enum._member_map_.keys())
+ list(simple_enum._member_map_.keys())
)
for key in set(checked_keys + simple_keys):
if key in ('__module__', '_member_map_', '_value2member_map_'):
# keys known to be different
continue
elif key in member_names:
# members are checked below
continue
elif key not in simple_keys:
failed.append("missing key: %r" % (key, ))
elif key not in checked_keys:
failed.append("extra key: %r" % (key, ))
else:
checked_value = checked_dict[key]
simple_value = simple_dict[key]
if callable(checked_value):
continue
if key == '__doc__':
# remove all spaces/tabs
compressed_checked_value = checked_value.replace(' ','').replace('\t','')
compressed_simple_value = simple_value.replace(' ','').replace('\t','')
if compressed_checked_value != compressed_simple_value:
failed.append("%r:\n %s\n %s" % (
key,
"checked -> %r" % (checked_value, ),
"simple -> %r" % (simple_value, ),
))
elif checked_value != simple_value:
failed.append("%r:\n %s\n %s" % (
key,
"checked -> %r" % (checked_value, ),
"simple -> %r" % (simple_value, ),
))
failed.sort()
for name in member_names:
failed_member = []
if name not in simple_keys:
failed.append('missing member from simple enum: %r' % name)
elif name not in checked_keys:
failed.append('extra member in simple enum: %r' % name)
else:
checked_member_dict = checked_enum[name].__dict__
checked_member_keys = list(checked_member_dict.keys())
simple_member_dict = simple_enum[name].__dict__
simple_member_keys = list(simple_member_dict.keys())
for key in set(checked_member_keys + simple_member_keys):
if key in ('__module__', '__objclass__'):
# keys known to be different
continue
elif key not in simple_member_keys:
failed_member.append("missing key %r not in the simple enum member %r" % (key, name))
elif key not in checked_member_keys:
failed_member.append("extra key %r in simple enum member %r" % (key, name))
else:
checked_value = checked_member_dict[key]
simple_value = simple_member_dict[key]
if checked_value != simple_value:
failed_member.append("%r:\n %s\n %s" % (
key,
"checked member -> %r" % (checked_value, ),
"simple member -> %r" % (simple_value, ),
))
if failed_member:
failed.append('%r member mismatch:\n %s' % (
name, '\n '.join(failed_member),
))
for method in (
'__str__', '__repr__', '__reduce_ex__', '__format__',
'__getnewargs_ex__', '__getnewargs__', '__reduce_ex__', '__reduce__'
):
if method in simple_keys and method in checked_keys:
# cannot compare functions, and it exists in both, so we're good
continue
elif method not in simple_keys and method not in checked_keys:
# method is inherited -- check it out
checked_method = getattr(checked_enum, method, None)
simple_method = getattr(simple_enum, method, None)
if hasattr(checked_method, '__func__'):
checked_method = checked_method.__func__
simple_method = simple_method.__func__
if checked_method != simple_method:
failed.append("%r: %-30s %s" % (
method,
"checked -> %r" % (checked_method, ),
"simple -> %r" % (simple_method, ),
))
else:
# if the method existed in only one of the enums, it will have been caught
# in the first checks above
pass
if failed:
raise TypeError('enum mismatch:\n %s' % '\n '.join(failed))
def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
"""
Create a new Enum subclass that replaces a collection of global constants
"""
# convert all constants from source (or module) that pass filter() to
# a new Enum called name, and export the enum and its members back to
# module;
# also, replace the __reduce_ex__ method so unpickling works in
# previous Python versions
module_globals = sys.modules[module].__dict__
if source:
source = source.__dict__
else:
source = module_globals
# _value2member_map_ is populated in the same order every time
# for a consistent reverse mapping of number to name when there
# are multiple names for the same number.
members = [
(name, value)
for name, value in source.items()
if filter(name)]
try:
# sort by value
members.sort(key=lambda t: (t[1], t[0]))
except TypeError:
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = etype(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_name
cls.__repr__ = global_enum_repr
return cls