mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-27575: port set intersection logic into dictview intersection (GH-7696)
This commit is contained in:
parent
c3ea41e9bf
commit
998cf1f03a
3 changed files with 93 additions and 4 deletions
|
@ -4169,24 +4169,97 @@ dictviews_sub(PyObject* self, PyObject *other)
|
|||
return result;
|
||||
}
|
||||
|
||||
PyObject*
|
||||
static int
|
||||
dictitems_contains(_PyDictViewObject *dv, PyObject *obj);
|
||||
|
||||
PyObject *
|
||||
_PyDictView_Intersect(PyObject* self, PyObject *other)
|
||||
{
|
||||
PyObject *result = PySet_New(self);
|
||||
PyObject *result;
|
||||
PyObject *it;
|
||||
PyObject *key;
|
||||
Py_ssize_t len_self;
|
||||
int rv;
|
||||
int (*dict_contains)(_PyDictViewObject *, PyObject *);
|
||||
PyObject *tmp;
|
||||
_Py_IDENTIFIER(intersection_update);
|
||||
|
||||
/* Python interpreter swaps parameters when dict view
|
||||
is on right side of & */
|
||||
if (!PyDictViewSet_Check(self)) {
|
||||
PyObject *tmp = other;
|
||||
other = self;
|
||||
self = tmp;
|
||||
}
|
||||
|
||||
len_self = dictview_len((_PyDictViewObject *)self);
|
||||
|
||||
/* if other is a set and self is smaller than other,
|
||||
reuse set intersection logic */
|
||||
if (Py_TYPE(other) == &PySet_Type && len_self <= PyObject_Size(other)) {
|
||||
_Py_IDENTIFIER(intersection);
|
||||
return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL);
|
||||
}
|
||||
|
||||
/* if other is another dict view, and it is bigger than self,
|
||||
swap them */
|
||||
if (PyDictViewSet_Check(other)) {
|
||||
Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other);
|
||||
if (len_other > len_self) {
|
||||
PyObject *tmp = other;
|
||||
other = self;
|
||||
self = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
/* at this point, two things should be true
|
||||
1. self is a dictview
|
||||
2. if other is a dictview then it is smaller than self */
|
||||
result = PySet_New(NULL);
|
||||
if (result == NULL)
|
||||
return NULL;
|
||||
|
||||
it = PyObject_GetIter(other);
|
||||
|
||||
_Py_IDENTIFIER(intersection_update);
|
||||
tmp = _PyObject_CallMethodIdOneArg(result, &PyId_intersection_update, other);
|
||||
if (tmp == NULL) {
|
||||
Py_DECREF(result);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_DECREF(tmp);
|
||||
|
||||
if (PyDictKeys_Check(self)) {
|
||||
dict_contains = dictkeys_contains;
|
||||
}
|
||||
/* else PyDictItems_Check(self) */
|
||||
else {
|
||||
dict_contains = dictitems_contains;
|
||||
}
|
||||
|
||||
while ((key = PyIter_Next(it)) != NULL) {
|
||||
rv = dict_contains((_PyDictViewObject *)self, key);
|
||||
if (rv < 0) {
|
||||
goto error;
|
||||
}
|
||||
if (rv) {
|
||||
if (PySet_Add(result, key)) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
Py_DECREF(key);
|
||||
}
|
||||
Py_DECREF(it);
|
||||
if (PyErr_Occurred()) {
|
||||
Py_DECREF(result);
|
||||
return NULL;
|
||||
}
|
||||
return result;
|
||||
|
||||
error:
|
||||
Py_DECREF(it);
|
||||
Py_DECREF(result);
|
||||
Py_DECREF(key);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue