Make C helper function more closely match the pure python version, and add tests.

This commit is contained in:
Raymond Hettinger 2011-01-03 02:12:02 +00:00
parent 23eaa70057
commit 426e052a4f
2 changed files with 63 additions and 23 deletions

View file

@ -3,7 +3,7 @@
import unittest, doctest, operator import unittest, doctest, operator
import inspect import inspect
from test import support from test import support
from collections import namedtuple, Counter, OrderedDict from collections import namedtuple, Counter, OrderedDict, _count_elements
from test import mapping_tests from test import mapping_tests
import pickle, copy import pickle, copy
from random import randrange, shuffle from random import randrange, shuffle
@ -775,6 +775,19 @@ class TestCounter(unittest.TestCase):
c.subtract('aaaabbcce') c.subtract('aaaabbcce')
self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1)) self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1))
def test_helper_function(self):
# two paths, one for real dicts and one for other mappings
elems = list('abracadabra')
d = dict()
_count_elements(d, elems)
self.assertEqual(d, {'a': 5, 'r': 2, 'b': 2, 'c': 1, 'd': 1})
m = OrderedDict()
_count_elements(m, elems)
self.assertEqual(m,
OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
class TestOrderedDict(unittest.TestCase): class TestOrderedDict(unittest.TestCase):
def test_init(self): def test_init(self):

View file

@ -1536,25 +1536,23 @@ _count_elements(PyObject *self, PyObject *args)
if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
return NULL; return NULL;
if (!PyDict_Check(mapping)) {
PyErr_SetString(PyExc_TypeError,
"Expected mapping argument to be a dictionary");
return NULL;
}
it = PyObject_GetIter(iterable); it = PyObject_GetIter(iterable);
if (it == NULL) if (it == NULL)
return NULL; return NULL;
one = PyLong_FromLong(1); one = PyLong_FromLong(1);
if (one == NULL) { if (one == NULL) {
Py_DECREF(it); Py_DECREF(it);
return NULL; return NULL;
} }
if (PyDict_CheckExact(mapping)) {
while (1) { while (1) {
key = PyIter_Next(it); key = PyIter_Next(it);
if (key == NULL) { if (key == NULL) {
if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
PyErr_Clear(); PyErr_Clear();
else
break; break;
} }
oldval = PyDict_GetItem(mapping, key); oldval = PyDict_GetItem(mapping, key);
@ -1571,6 +1569,35 @@ _count_elements(PyObject *self, PyObject *args)
} }
Py_DECREF(key); Py_DECREF(key);
} }
} else {
while (1) {
key = PyIter_Next(it);
if (key == NULL) {
if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
PyErr_Clear();
else
break;
}
oldval = PyObject_GetItem(mapping, key);
if (oldval == NULL) {
if (!PyErr_Occurred() || !PyErr_ExceptionMatches(PyExc_KeyError))
break;
PyErr_Clear();
Py_INCREF(one);
newval = one;
} else {
newval = PyNumber_Add(oldval, one);
Py_DECREF(oldval);
if (newval == NULL)
break;
}
if (PyObject_SetItem(mapping, key, newval) == -1)
break;
Py_CLEAR(newval);
Py_DECREF(key);
}
}
Py_DECREF(it); Py_DECREF(it);
Py_XDECREF(key); Py_XDECREF(key);
Py_XDECREF(newval); Py_XDECREF(newval);