bpo-24275: Don't downgrade unicode-only dicts to mixed on lookups (GH-25186)

This commit is contained in:
Hristo Venev 2021-04-29 05:06:03 +03:00 committed by GitHub
parent 69a733bda3
commit 8557edbfa8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 3 deletions

View file

@ -1471,6 +1471,106 @@ class DictTest(unittest.TestCase):
gc.collect()
self.assertTrue(gc.is_tracked(next(it)))
def test_str_nonstr(self):
# cpython uses a different lookup function if the dict only contains
# `str` keys. Make sure the unoptimized path is used when a non-`str`
# key appears.
class StrSub(str):
pass
eq_count = 0
# This class compares equal to the string 'key3'
class Key3:
def __hash__(self):
return hash('key3')
def __eq__(self, other):
nonlocal eq_count
if isinstance(other, Key3) or isinstance(other, str) and other == 'key3':
eq_count += 1
return True
return False
key3_1 = StrSub('key3')
key3_2 = Key3()
key3_3 = Key3()
dicts = []
# Create dicts of the form `{'key1': 42, 'key2': 43, key3: 44}` in a
# bunch of different ways. In all cases, `key3` is not of type `str`.
# `key3_1` is a `str` subclass and `key3_2` is a completely unrelated
# type.
for key3 in (key3_1, key3_2):
# A literal
dicts.append({'key1': 42, 'key2': 43, key3: 44})
# key3 inserted via `dict.__setitem__`
d = {'key1': 42, 'key2': 43}
d[key3] = 44
dicts.append(d)
# key3 inserted via `dict.setdefault`
d = {'key1': 42, 'key2': 43}
self.assertEqual(d.setdefault(key3, 44), 44)
dicts.append(d)
# key3 inserted via `dict.update`
d = {'key1': 42, 'key2': 43}
d.update({key3: 44})
dicts.append(d)
# key3 inserted via `dict.__ior__`
d = {'key1': 42, 'key2': 43}
d |= {key3: 44}
dicts.append(d)
# `dict(iterable)`
def make_pairs():
yield ('key1', 42)
yield ('key2', 43)
yield (key3, 44)
d = dict(make_pairs())
dicts.append(d)
# `dict.copy`
d = d.copy()
dicts.append(d)
# dict comprehension
d = {key: 42 + i for i,key in enumerate(['key1', 'key2', key3])}
dicts.append(d)
for d in dicts:
with self.subTest(d=d):
self.assertEqual(d.get('key1'), 42)
# Try to make an object that is of type `str` and is equal to
# `'key1'`, but (at least on cpython) is a different object.
noninterned_key1 = 'ke'
noninterned_key1 += 'y1'
if support.check_impl_detail(cpython=True):
# suppress a SyntaxWarning
interned_key1 = 'key1'
self.assertFalse(noninterned_key1 is interned_key1)
self.assertEqual(d.get(noninterned_key1), 42)
self.assertEqual(d.get('key3'), 44)
self.assertEqual(d.get(key3_1), 44)
self.assertEqual(d.get(key3_2), 44)
# `key3_3` itself is definitely not a dict key, so make sure
# that `__eq__` gets called.
#
# Note that this might not hold for `key3_1` and `key3_2`
# because they might be the same object as one of the dict keys,
# in which case implementations are allowed to skip the call to
# `__eq__`.
eq_count = 0
self.assertEqual(d.get(key3_3), 44)
self.assertGreaterEqual(eq_count, 1)
class CAPITest(unittest.TestCase):

View file

@ -857,7 +857,6 @@ lookdict_unicode(PyDictObject *mp, PyObject *key,
unicodes is to override __eq__, and for speed we don't cater to
that here. */
if (!PyUnicode_CheckExact(key)) {
mp->ma_keys->dk_lookup = lookdict;
return lookdict(mp, key, hash, value_addr);
}
@ -900,7 +899,6 @@ lookdict_unicode_nodummy(PyDictObject *mp, PyObject *key,
unicodes is to override __eq__, and for speed we don't cater to
that here. */
if (!PyUnicode_CheckExact(key)) {
mp->ma_keys->dk_lookup = lookdict;
return lookdict(mp, key, hash, value_addr);
}
@ -1084,7 +1082,6 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
if (ix == DKIX_ERROR)
goto Fail;
assert(PyUnicode_CheckExact(key) || mp->ma_keys->dk_lookup == lookdict);
MAINTAIN_TRACKING(mp, key, value);
/* When insertion order is different from shared key, we can't share
@ -1106,6 +1103,9 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
if (insertion_resize(mp) < 0)
goto Fail;
}
if (!PyUnicode_CheckExact(key) && mp->ma_keys->dk_lookup != lookdict) {
mp->ma_keys->dk_lookup = lookdict;
}
Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
ep = &DK_ENTRIES(mp->ma_keys)[mp->ma_keys->dk_nentries];
dictkeys_set_index(mp->ma_keys, hashpos, mp->ma_keys->dk_nentries);
@ -3068,6 +3068,9 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
return NULL;
}
}
if (!PyUnicode_CheckExact(key) && mp->ma_keys->dk_lookup != lookdict) {
mp->ma_keys->dk_lookup = lookdict;
}
Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
ep0 = DK_ENTRIES(mp->ma_keys);
ep = &ep0[mp->ma_keys->dk_nentries];