mirror of
https://github.com/python/cpython.git
synced 2025-10-10 00:43:41 +00:00
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
This commit is contained in:
parent
976dec9b3b
commit
353e3b2820
2 changed files with 51 additions and 42 deletions
86
Lib/enum.py
86
Lib/enum.py
|
@ -618,6 +618,18 @@ class EnumType(type):
|
||||||
if name not in classdict:
|
if name not in classdict:
|
||||||
setattr(enum_class, name, getattr(first_enum, name))
|
setattr(enum_class, name, getattr(first_enum, name))
|
||||||
#
|
#
|
||||||
|
# for Flag, add __or__, __and__, __xor__, and __invert__
|
||||||
|
if Flag is not None and issubclass(enum_class, Flag):
|
||||||
|
for name in (
|
||||||
|
'__or__', '__and__', '__xor__',
|
||||||
|
'__ror__', '__rand__', '__rxor__',
|
||||||
|
'__invert__'
|
||||||
|
):
|
||||||
|
if name not in classdict:
|
||||||
|
enum_method = getattr(Flag, name)
|
||||||
|
setattr(enum_class, name, enum_method)
|
||||||
|
classdict[name] = enum_method
|
||||||
|
#
|
||||||
# replace any other __new__ with our own (as long as Enum is not None,
|
# replace any other __new__ with our own (as long as Enum is not None,
|
||||||
# anyway) -- again, this is to support pickle
|
# anyway) -- again, this is to support pickle
|
||||||
if Enum is not None:
|
if Enum is not None:
|
||||||
|
@ -1467,19 +1479,34 @@ class Flag(Enum, boundary=STRICT):
|
||||||
return bool(self._value_)
|
return bool(self._value_)
|
||||||
|
|
||||||
def __or__(self, other):
|
def __or__(self, other):
|
||||||
if not isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
|
other = other._value_
|
||||||
|
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||||
|
other = other
|
||||||
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.__class__(self._value_ | other._value_)
|
value = self._value_
|
||||||
|
return self.__class__(value | other)
|
||||||
|
|
||||||
def __and__(self, other):
|
def __and__(self, other):
|
||||||
if not isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
|
other = other._value_
|
||||||
|
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||||
|
other = other
|
||||||
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.__class__(self._value_ & other._value_)
|
value = self._value_
|
||||||
|
return self.__class__(value & other)
|
||||||
|
|
||||||
def __xor__(self, other):
|
def __xor__(self, other):
|
||||||
if not isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
|
other = other._value_
|
||||||
|
elif self._member_type_ is not object and isinstance(other, self._member_type_):
|
||||||
|
other = other
|
||||||
|
else:
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.__class__(self._value_ ^ other._value_)
|
value = self._value_
|
||||||
|
return self.__class__(value ^ other)
|
||||||
|
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
if self._inverted_ is None:
|
if self._inverted_ is None:
|
||||||
|
@ -1493,6 +1520,10 @@ class Flag(Enum, boundary=STRICT):
|
||||||
self._inverted_._inverted_ = self
|
self._inverted_._inverted_ = self
|
||||||
return self._inverted_
|
return self._inverted_
|
||||||
|
|
||||||
|
__rand__ = __and__
|
||||||
|
__ror__ = __or__
|
||||||
|
__rxor__ = __xor__
|
||||||
|
|
||||||
|
|
||||||
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
|
class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
|
||||||
"""
|
"""
|
||||||
|
@ -1500,42 +1531,6 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __or__(self, other):
|
|
||||||
if isinstance(other, self.__class__):
|
|
||||||
other = other._value_
|
|
||||||
elif isinstance(other, int):
|
|
||||||
other = other
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
value = self._value_
|
|
||||||
return self.__class__(value | other)
|
|
||||||
|
|
||||||
def __and__(self, other):
|
|
||||||
if isinstance(other, self.__class__):
|
|
||||||
other = other._value_
|
|
||||||
elif isinstance(other, int):
|
|
||||||
other = other
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
value = self._value_
|
|
||||||
return self.__class__(value & other)
|
|
||||||
|
|
||||||
def __xor__(self, other):
|
|
||||||
if isinstance(other, self.__class__):
|
|
||||||
other = other._value_
|
|
||||||
elif isinstance(other, int):
|
|
||||||
other = other
|
|
||||||
else:
|
|
||||||
return NotImplemented
|
|
||||||
value = self._value_
|
|
||||||
return self.__class__(value ^ other)
|
|
||||||
|
|
||||||
__ror__ = __or__
|
|
||||||
__rand__ = __and__
|
|
||||||
__rxor__ = __xor__
|
|
||||||
__invert__ = Flag.__invert__
|
|
||||||
|
|
||||||
|
|
||||||
def _high_bit(value):
|
def _high_bit(value):
|
||||||
"""
|
"""
|
||||||
returns index of highest bit, or -1 if value is zero or negative
|
returns index of highest bit, or -1 if value is zero or negative
|
||||||
|
@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
body['_flag_mask_'] = None
|
body['_flag_mask_'] = None
|
||||||
body['_all_bits_'] = None
|
body['_all_bits_'] = None
|
||||||
body['_inverted_'] = None
|
body['_inverted_'] = None
|
||||||
|
body['__or__'] = Flag.__or__
|
||||||
|
body['__xor__'] = Flag.__xor__
|
||||||
|
body['__and__'] = Flag.__and__
|
||||||
|
body['__ror__'] = Flag.__ror__
|
||||||
|
body['__rxor__'] = Flag.__rxor__
|
||||||
|
body['__rand__'] = Flag.__rand__
|
||||||
|
body['__invert__'] = Flag.__invert__
|
||||||
for name, obj in cls.__dict__.items():
|
for name, obj in cls.__dict__.items():
|
||||||
if name in ('__dict__', '__weakref__'):
|
if name in ('__dict__', '__weakref__'):
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -2496,6 +2496,13 @@ class TestSpecial(unittest.TestCase):
|
||||||
self.assertEqual(Some.x.value, 1)
|
self.assertEqual(Some.x.value, 1)
|
||||||
self.assertEqual(Some.y.value, 2)
|
self.assertEqual(Some.y.value, 2)
|
||||||
|
|
||||||
|
def test_custom_flag_bitwise(self):
|
||||||
|
class MyIntFlag(int, Flag):
|
||||||
|
ONE = 1
|
||||||
|
TWO = 2
|
||||||
|
FOUR = 4
|
||||||
|
self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
|
||||||
|
self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
|
||||||
|
|
||||||
class TestOrder(unittest.TestCase):
|
class TestOrder(unittest.TestCase):
|
||||||
"test usage of the `_order_` attribute"
|
"test usage of the `_order_` attribute"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue