mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
[3.13] gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735) (GH-125851)
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735)
(cherry picked from commit aaed91cabc
)
Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
This commit is contained in:
parent
e52095a0c1
commit
5bb0538f6e
3 changed files with 28 additions and 6 deletions
26
Lib/enum.py
26
Lib/enum.py
|
@ -328,6 +328,8 @@ class _proto_member:
|
||||||
# to the map, and by-value lookups for this value will be
|
# to the map, and by-value lookups for this value will be
|
||||||
# linear.
|
# linear.
|
||||||
enum_class._value2member_map_.setdefault(value, enum_member)
|
enum_class._value2member_map_.setdefault(value, enum_member)
|
||||||
|
if value not in enum_class._hashable_values_:
|
||||||
|
enum_class._hashable_values_.append(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# keep track of the value in a list so containment checks are quick
|
# keep track of the value in a list so containment checks are quick
|
||||||
enum_class._unhashable_values_.append(value)
|
enum_class._unhashable_values_.append(value)
|
||||||
|
@ -545,7 +547,8 @@ class EnumType(type):
|
||||||
classdict['_member_names_'] = []
|
classdict['_member_names_'] = []
|
||||||
classdict['_member_map_'] = {}
|
classdict['_member_map_'] = {}
|
||||||
classdict['_value2member_map_'] = {}
|
classdict['_value2member_map_'] = {}
|
||||||
classdict['_unhashable_values_'] = []
|
classdict['_hashable_values_'] = [] # for comparing with non-hashable types
|
||||||
|
classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
|
||||||
classdict['_unhashable_values_map_'] = {}
|
classdict['_unhashable_values_map_'] = {}
|
||||||
classdict['_member_type_'] = member_type
|
classdict['_member_type_'] = member_type
|
||||||
# now set the __repr__ for the value
|
# now set the __repr__ for the value
|
||||||
|
@ -755,7 +758,10 @@ class EnumType(type):
|
||||||
try:
|
try:
|
||||||
return value in cls._value2member_map_
|
return value in cls._value2member_map_
|
||||||
except TypeError:
|
except TypeError:
|
||||||
return value in cls._unhashable_values_
|
return (
|
||||||
|
value in cls._unhashable_values_ # both structures are lists
|
||||||
|
or value in cls._hashable_values_
|
||||||
|
)
|
||||||
|
|
||||||
def __delattr__(cls, attr):
|
def __delattr__(cls, attr):
|
||||||
# nicer error message when someone tries to delete an attribute
|
# nicer error message when someone tries to delete an attribute
|
||||||
|
@ -1165,8 +1171,11 @@ class Enum(metaclass=EnumType):
|
||||||
pass
|
pass
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# not there, now do long search -- O(n) behavior
|
# not there, now do long search -- O(n) behavior
|
||||||
for name, values in cls._unhashable_values_map_.items():
|
for name, unhashable_values in cls._unhashable_values_map_.items():
|
||||||
if value in values:
|
if value in unhashable_values:
|
||||||
|
return cls[name]
|
||||||
|
for name, member in cls._member_map_.items():
|
||||||
|
if value == member._value_:
|
||||||
return cls[name]
|
return cls[name]
|
||||||
# still not found -- verify that members exist, in-case somebody got here mistakenly
|
# still not found -- verify that members exist, in-case somebody got here mistakenly
|
||||||
# (such as via super when trying to override __new__)
|
# (such as via super when trying to override __new__)
|
||||||
|
@ -1232,6 +1241,7 @@ class Enum(metaclass=EnumType):
|
||||||
# to the map, and by-value lookups for this value will be
|
# to the map, and by-value lookups for this value will be
|
||||||
# linear.
|
# linear.
|
||||||
cls._value2member_map_.setdefault(value, self)
|
cls._value2member_map_.setdefault(value, self)
|
||||||
|
cls._hashable_values_.append(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# keep track of the value in a list so containment checks are quick
|
# keep track of the value in a list so containment checks are quick
|
||||||
cls._unhashable_values_.append(value)
|
cls._unhashable_values_.append(value)
|
||||||
|
@ -1762,6 +1772,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
body['_member_names_'] = member_names = []
|
body['_member_names_'] = member_names = []
|
||||||
body['_member_map_'] = member_map = {}
|
body['_member_map_'] = member_map = {}
|
||||||
body['_value2member_map_'] = value2member_map = {}
|
body['_value2member_map_'] = value2member_map = {}
|
||||||
|
body['_hashable_values_'] = hashable_values = []
|
||||||
body['_unhashable_values_'] = unhashable_values = []
|
body['_unhashable_values_'] = unhashable_values = []
|
||||||
body['_unhashable_values_map_'] = {}
|
body['_unhashable_values_map_'] = {}
|
||||||
body['_member_type_'] = member_type = etype._member_type_
|
body['_member_type_'] = member_type = etype._member_type_
|
||||||
|
@ -1825,7 +1836,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
contained = value2member_map.get(member._value_)
|
contained = value2member_map.get(member._value_)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
contained = None
|
contained = None
|
||||||
if member._value_ in unhashable_values:
|
if member._value_ in unhashable_values or member.value in hashable_values:
|
||||||
for m in enum_class:
|
for m in enum_class:
|
||||||
if m._value_ == member._value_:
|
if m._value_ == member._value_:
|
||||||
contained = m
|
contained = m
|
||||||
|
@ -1845,6 +1856,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
else:
|
else:
|
||||||
enum_class._add_member_(name, member)
|
enum_class._add_member_(name, member)
|
||||||
value2member_map[value] = member
|
value2member_map[value] = member
|
||||||
|
hashable_values.append(value)
|
||||||
if _is_single_bit(value):
|
if _is_single_bit(value):
|
||||||
# not a multi-bit alias, record in _member_names_ and _flag_mask_
|
# not a multi-bit alias, record in _member_names_ and _flag_mask_
|
||||||
member_names.append(name)
|
member_names.append(name)
|
||||||
|
@ -1881,7 +1893,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
contained = value2member_map.get(member._value_)
|
contained = value2member_map.get(member._value_)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
contained = None
|
contained = None
|
||||||
if member._value_ in unhashable_values:
|
if member._value_ in unhashable_values or member._value_ in hashable_values:
|
||||||
for m in enum_class:
|
for m in enum_class:
|
||||||
if m._value_ == member._value_:
|
if m._value_ == member._value_:
|
||||||
contained = m
|
contained = m
|
||||||
|
@ -1907,6 +1919,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
|
||||||
# to the map, and by-value lookups for this value will be
|
# to the map, and by-value lookups for this value will be
|
||||||
# linear.
|
# linear.
|
||||||
enum_class._value2member_map_.setdefault(value, member)
|
enum_class._value2member_map_.setdefault(value, member)
|
||||||
|
if value not in hashable_values:
|
||||||
|
hashable_values.append(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# keep track of the value in a list so containment checks are quick
|
# keep track of the value in a list so containment checks are quick
|
||||||
enum_class._unhashable_values_.append(value)
|
enum_class._unhashable_values_.append(value)
|
||||||
|
|
|
@ -3474,6 +3474,13 @@ class TestSpecial(unittest.TestCase):
|
||||||
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
|
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
|
||||||
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
|
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)
|
||||||
|
|
||||||
|
def test_nonhashable_matches_hashable(self): # issue 125710
|
||||||
|
class Directions(Enum):
|
||||||
|
DOWN_ONLY = frozenset({"sc"})
|
||||||
|
UP_ONLY = frozenset({"cs"})
|
||||||
|
UNRESTRICTED = frozenset({"sc", "cs"})
|
||||||
|
self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)
|
||||||
|
|
||||||
|
|
||||||
class TestOrder(unittest.TestCase):
|
class TestOrder(unittest.TestCase):
|
||||||
"test usage of the `_order_` attribute"
|
"test usage of the `_order_` attribute"
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
[Enum] fix hashable<->nonhashable comparisons for member values
|
Loading…
Add table
Add a link
Reference in a new issue