mirror of
https://github.com/python/cpython.git
synced 2025-08-22 17:55:18 +00:00
gh-103365: [Enum] STRICT boundary corrections (GH-103494)
STRICT boundary: - fix bitwise operations - make default for Flag
This commit is contained in:
parent
efb8a2553c
commit
2194071540
4 changed files with 82 additions and 38 deletions
67
Lib/enum.py
67
Lib/enum.py
|
@ -275,6 +275,13 @@ class _proto_member:
|
|||
enum_member.__objclass__ = enum_class
|
||||
enum_member.__init__(*args)
|
||||
enum_member._sort_order_ = len(enum_class._member_names_)
|
||||
|
||||
if Flag is not None and issubclass(enum_class, Flag):
|
||||
enum_class._flag_mask_ |= value
|
||||
if _is_single_bit(value):
|
||||
enum_class._singles_mask_ |= value
|
||||
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
|
||||
|
||||
# If another member with the same value was already defined, the
|
||||
# new member becomes an alias to the existing one.
|
||||
try:
|
||||
|
@ -532,12 +539,8 @@ class EnumType(type):
|
|||
classdict['_use_args_'] = use_args
|
||||
#
|
||||
# convert future enum members into temporary _proto_members
|
||||
# and record integer values in case this will be a Flag
|
||||
flag_mask = 0
|
||||
for name in member_names:
|
||||
value = classdict[name]
|
||||
if isinstance(value, int):
|
||||
flag_mask |= value
|
||||
classdict[name] = _proto_member(value)
|
||||
#
|
||||
# house-keeping structures
|
||||
|
@ -554,8 +557,9 @@ class EnumType(type):
|
|||
boundary
|
||||
or getattr(first_enum, '_boundary_', None)
|
||||
)
|
||||
classdict['_flag_mask_'] = flag_mask
|
||||
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
|
||||
classdict['_flag_mask_'] = 0
|
||||
classdict['_singles_mask_'] = 0
|
||||
classdict['_all_bits_'] = 0
|
||||
classdict['_inverted_'] = None
|
||||
try:
|
||||
exc = None
|
||||
|
@ -644,21 +648,10 @@ class EnumType(type):
|
|||
):
|
||||
delattr(enum_class, '_boundary_')
|
||||
delattr(enum_class, '_flag_mask_')
|
||||
delattr(enum_class, '_singles_mask_')
|
||||
delattr(enum_class, '_all_bits_')
|
||||
delattr(enum_class, '_inverted_')
|
||||
elif Flag is not None and issubclass(enum_class, Flag):
|
||||
# ensure _all_bits_ is correct and there are no missing flags
|
||||
single_bit_total = 0
|
||||
multi_bit_total = 0
|
||||
for flag in enum_class._member_map_.values():
|
||||
flag_value = flag._value_
|
||||
if _is_single_bit(flag_value):
|
||||
single_bit_total |= flag_value
|
||||
else:
|
||||
# multi-bit flags are considered aliases
|
||||
multi_bit_total |= flag_value
|
||||
enum_class._flag_mask_ = single_bit_total
|
||||
#
|
||||
# set correct __iter__
|
||||
member_list = [m._value_ for m in enum_class]
|
||||
if member_list != sorted(member_list):
|
||||
|
@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
|
|||
class FlagBoundary(StrEnum):
|
||||
"""
|
||||
control how out of range values are handled
|
||||
"strict" -> error is raised
|
||||
"conform" -> extra bits are discarded [default for Flag]
|
||||
"strict" -> error is raised [default for Flag]
|
||||
"conform" -> extra bits are discarded
|
||||
"eject" -> lose flag status
|
||||
"keep" -> keep flag status and all bits [default for IntFlag]
|
||||
"""
|
||||
|
@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
|
|||
STRICT, CONFORM, EJECT, KEEP = FlagBoundary
|
||||
|
||||
|
||||
class Flag(Enum, boundary=CONFORM):
|
||||
class Flag(Enum, boundary=STRICT):
|
||||
"""
|
||||
Support for flags
|
||||
"""
|
||||
|
@ -1394,6 +1387,7 @@ class Flag(Enum, boundary=CONFORM):
|
|||
# - value must not include any skipped flags (e.g. if bit 2 is not
|
||||
# defined, then 0d10 is invalid)
|
||||
flag_mask = cls._flag_mask_
|
||||
singles_mask = cls._singles_mask_
|
||||
all_bits = cls._all_bits_
|
||||
neg_value = None
|
||||
if (
|
||||
|
@ -1425,7 +1419,8 @@ class Flag(Enum, boundary=CONFORM):
|
|||
value = all_bits + 1 + value
|
||||
# get members and unknown
|
||||
unknown = value & ~flag_mask
|
||||
member_value = value & flag_mask
|
||||
aliases = value & ~singles_mask
|
||||
member_value = value & singles_mask
|
||||
if unknown and cls._boundary_ is not KEEP:
|
||||
raise ValueError(
|
||||
'%s(%r) --> unknown values %r [%s]'
|
||||
|
@ -1439,11 +1434,25 @@ class Flag(Enum, boundary=CONFORM):
|
|||
pseudo_member = cls._member_type_.__new__(cls, value)
|
||||
if not hasattr(pseudo_member, '_value_'):
|
||||
pseudo_member._value_ = value
|
||||
if member_value:
|
||||
pseudo_member._name_ = '|'.join([
|
||||
m._name_ for m in cls._iter_member_(member_value)
|
||||
])
|
||||
if unknown:
|
||||
if member_value or aliases:
|
||||
members = []
|
||||
combined_value = 0
|
||||
for m in cls._iter_member_(member_value):
|
||||
members.append(m)
|
||||
combined_value |= m._value_
|
||||
if aliases:
|
||||
value = member_value | aliases
|
||||
for n, pm in cls._member_map_.items():
|
||||
if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
|
||||
members.append(pm)
|
||||
combined_value |= pm._value_
|
||||
unknown = value ^ combined_value
|
||||
pseudo_member._name_ = '|'.join([m._name_ for m in members])
|
||||
if not combined_value:
|
||||
pseudo_member._name_ = None
|
||||
elif unknown and cls._boundary_ is STRICT:
|
||||
raise ValueError('%r: no members with value %r' % (cls, unknown))
|
||||
elif unknown:
|
||||
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
|
||||
else:
|
||||
pseudo_member._name_ = None
|
||||
|
@ -1675,6 +1684,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
|||
body['_boundary_'] = boundary or etype._boundary_
|
||||
body['_flag_mask_'] = None
|
||||
body['_all_bits_'] = None
|
||||
body['_singles_mask_'] = None
|
||||
body['_inverted_'] = None
|
||||
body['__or__'] = Flag.__or__
|
||||
body['__xor__'] = Flag.__xor__
|
||||
|
@ -1750,7 +1760,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
|||
else:
|
||||
multi_bits |= value
|
||||
gnv_last_values.append(value)
|
||||
enum_class._flag_mask_ = single_bits
|
||||
enum_class._flag_mask_ = single_bits | multi_bits
|
||||
enum_class._singles_mask_ = single_bits
|
||||
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
|
||||
# set correct __iter__
|
||||
member_list = [m._value_ for m in enum_class]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue