bpo-35013: Add more type checks for children of Element. (GH-9944)

It is now guarantied that children of xml.etree.ElementTree.Element
are Elements (at least in C implementation). Previously methods
__setitem__(), __setstate__() and __deepcopy__() could be used for
adding non-Element children.
This commit is contained in:
Serhiy Storchaka 2018-10-19 12:12:57 +03:00 committed by GitHub
parent 68def052dc
commit f081fd8303
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 50 deletions

View file

@ -1795,6 +1795,28 @@ class BasicElementTest(ElementTestCase, unittest.TestCase):
self.assertRaises(TypeError, e.append, 'b') self.assertRaises(TypeError, e.append, 'b')
self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo']) self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo'])
self.assertRaises(TypeError, e.insert, 0, 'foo') self.assertRaises(TypeError, e.insert, 0, 'foo')
e[:] = [ET.Element('bar')]
with self.assertRaises(TypeError):
e[0] = 'foo'
with self.assertRaises(TypeError):
e[:] = [ET.Element('bar'), 'foo']
if hasattr(e, '__setstate__'):
state = {
'tag': 'tag',
'_children': [None], # non-Element
'attrib': 'attr',
'tail': 'tail',
'text': 'text',
}
self.assertRaises(TypeError, e.__setstate__, state)
if hasattr(e, '__deepcopy__'):
class E(ET.Element):
def __deepcopy__(self, memo):
return None # non-Element
e[:] = [E('bar')]
self.assertRaises(TypeError, copy.deepcopy, e)
def test_cyclic_gc(self): def test_cyclic_gc(self):
class Dummy: class Dummy:
@ -1981,26 +2003,6 @@ class BadElementTest(ElementTestCase, unittest.TestCase):
elem = b.close() elem = b.close()
self.assertEqual(elem[0].tail, 'ABCDEFGHIJKL') self.assertEqual(elem[0].tail, 'ABCDEFGHIJKL')
def test_element_iter(self):
# Issue #27863
state = {
'tag': 'tag',
'_children': [None], # non-Element
'attrib': 'attr',
'tail': 'tail',
'text': 'text',
}
e = ET.Element('tag')
try:
e.__setstate__(state)
except AttributeError:
e.__dict__ = state
it = e.iter()
self.assertIs(next(it), e)
self.assertRaises(AttributeError, next, it)
def test_subscr(self): def test_subscr(self):
# Issue #27863 # Issue #27863
class X: class X:

View file

@ -217,11 +217,11 @@ class Element:
return self._children[index] return self._children[index]
def __setitem__(self, index, element): def __setitem__(self, index, element):
# if isinstance(index, slice): if isinstance(index, slice):
# for elt in element: for elt in element:
# assert iselement(elt) self._assert_is_element(elt)
# else: else:
# assert iselement(element) self._assert_is_element(element)
self._children[index] = element self._children[index] = element
def __delitem__(self, index): def __delitem__(self, index):

View file

@ -480,11 +480,24 @@ element_resize(ElementObject* self, Py_ssize_t extra)
return -1; return -1;
} }
LOCAL(void)
raise_type_error(PyObject *element)
{
PyErr_Format(PyExc_TypeError,
"expected an Element, not \"%.200s\"",
Py_TYPE(element)->tp_name);
}
LOCAL(int) LOCAL(int)
element_add_subelement(ElementObject* self, PyObject* element) element_add_subelement(ElementObject* self, PyObject* element)
{ {
/* add a child element to a parent */ /* add a child element to a parent */
if (!Element_Check(element)) {
raise_type_error(element);
return -1;
}
if (element_resize(self, 1) < 0) if (element_resize(self, 1) < 0)
return -1; return -1;
@ -803,7 +816,11 @@ _elementtree_Element___deepcopy___impl(ElementObject *self, PyObject *memo)
for (i = 0; i < self->extra->length; i++) { for (i = 0; i < self->extra->length; i++) {
PyObject* child = deepcopy(self->extra->children[i], memo); PyObject* child = deepcopy(self->extra->children[i], memo);
if (!child) { if (!child || !Element_Check(child)) {
if (child) {
raise_type_error(child);
Py_DECREF(child);
}
element->extra->length = i; element->extra->length = i;
goto error; goto error;
} }
@ -1024,8 +1041,15 @@ element_setstate_from_attributes(ElementObject *self,
/* Copy children */ /* Copy children */
for (i = 0; i < nchildren; i++) { for (i = 0; i < nchildren; i++) {
self->extra->children[i] = PyList_GET_ITEM(children, i); PyObject *child = PyList_GET_ITEM(children, i);
Py_INCREF(self->extra->children[i]); if (!Element_Check(child)) {
raise_type_error(child);
self->extra->length = i;
dealloc_extra(oldextra);
return NULL;
}
Py_INCREF(child);
self->extra->children[i] = child;
} }
assert(!self->extra->length); assert(!self->extra->length);
@ -1167,16 +1191,6 @@ _elementtree_Element_extend(ElementObject *self, PyObject *elements)
for (i = 0; i < PySequence_Fast_GET_SIZE(seq); i++) { for (i = 0; i < PySequence_Fast_GET_SIZE(seq); i++) {
PyObject* element = PySequence_Fast_GET_ITEM(seq, i); PyObject* element = PySequence_Fast_GET_ITEM(seq, i);
Py_INCREF(element); Py_INCREF(element);
if (!Element_Check(element)) {
PyErr_Format(
PyExc_TypeError,
"expected an Element, not \"%.200s\"",
Py_TYPE(element)->tp_name);
Py_DECREF(seq);
Py_DECREF(element);
return NULL;
}
if (element_add_subelement(self, element) < 0) { if (element_add_subelement(self, element) < 0) {
Py_DECREF(seq); Py_DECREF(seq);
Py_DECREF(element); Py_DECREF(element);
@ -1219,8 +1233,7 @@ _elementtree_Element_find_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) { for (i = 0; i < self->extra->length; i++) {
PyObject* item = self->extra->children[i]; PyObject* item = self->extra->children[i];
int rc; int rc;
if (!Element_Check(item)) assert(Element_Check(item));
continue;
Py_INCREF(item); Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc > 0) if (rc > 0)
@ -1266,8 +1279,7 @@ _elementtree_Element_findtext_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) { for (i = 0; i < self->extra->length; i++) {
PyObject *item = self->extra->children[i]; PyObject *item = self->extra->children[i];
int rc; int rc;
if (!Element_Check(item)) assert(Element_Check(item));
continue;
Py_INCREF(item); Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc > 0) { if (rc > 0) {
@ -1323,8 +1335,7 @@ _elementtree_Element_findall_impl(ElementObject *self, PyObject *path,
for (i = 0; i < self->extra->length; i++) { for (i = 0; i < self->extra->length; i++) {
PyObject* item = self->extra->children[i]; PyObject* item = self->extra->children[i];
int rc; int rc;
if (!Element_Check(item)) assert(Element_Check(item));
continue;
Py_INCREF(item); Py_INCREF(item);
rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ);
if (rc != 0 && (rc < 0 || PyList_Append(out, item) < 0)) { if (rc != 0 && (rc < 0 || PyList_Append(out, item) < 0)) {
@ -1736,6 +1747,10 @@ element_setitem(PyObject* self_, Py_ssize_t index, PyObject* item)
old = self->extra->children[index]; old = self->extra->children[index];
if (item) { if (item) {
if (!Element_Check(item)) {
raise_type_error(item);
return -1;
}
Py_INCREF(item); Py_INCREF(item);
self->extra->children[index] = item; self->extra->children[index] = item;
} else { } else {
@ -1930,6 +1945,15 @@ element_ass_subscr(PyObject* self_, PyObject* item, PyObject* value)
} }
} }
for (i = 0; i < newlen; i++) {
PyObject *element = PySequence_Fast_GET_ITEM(seq, i);
if (!Element_Check(element)) {
raise_type_error(element);
Py_DECREF(seq);
return -1;
}
}
if (slicelen > 0) { if (slicelen > 0) {
/* to avoid recursive calls to this method (via decref), move /* to avoid recursive calls to this method (via decref), move
old items to the recycle bin here, and get rid of them when old items to the recycle bin here, and get rid of them when
@ -2207,12 +2231,7 @@ elementiter_next(ElementIterObject *it)
continue; continue;
} }
if (!Element_Check(extra->children[child_index])) { assert(Element_Check(extra->children[child_index]));
PyErr_Format(PyExc_AttributeError,
"'%.100s' object has no attribute 'iter'",
Py_TYPE(extra->children[child_index])->tp_name);
return NULL;
}
elem = (ElementObject *)extra->children[child_index]; elem = (ElementObject *)extra->children[child_index];
item->child_index++; item->child_index++;
Py_INCREF(elem); Py_INCREF(elem);