Generalize filter(f, seq) to work with iterators. This also generalizes

filter() to no longer insist that len(seq) be defined.
NEEDS DOC CHANGES.
This commit is contained in:
Tim Peters 2001-05-02 07:39:38 +00:00
parent 6ad22c41c2
commit 0e57abf0cd
3 changed files with 109 additions and 50 deletions

View file

@ -275,4 +275,48 @@ class TestCase(unittest.TestCase):
except OSError: except OSError:
pass pass
# Test filter()'s use of iterators.
def test_builtin_filter(self):
self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
self.assertEqual(filter(None, SequenceClass(0)), [])
self.assertEqual(filter(None, ()), ())
self.assertEqual(filter(None, "abc"), "abc")
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(filter(None, d), d.keys())
self.assertRaises(TypeError, filter, None, list)
self.assertRaises(TypeError, filter, None, 42)
class Boolean:
def __init__(self, truth):
self.truth = truth
def __nonzero__(self):
return self.truth
True = Boolean(1)
False = Boolean(0)
class Seq:
def __init__(self, *args):
self.vals = args
def __iter__(self):
class SeqIter:
def __init__(self, vals):
self.vals = vals
self.i = 0
def __iter__(self):
return self
def next(self):
i = self.i
self.i = i + 1
if i < len(self.vals):
return self.vals[i]
else:
raise StopIteration
return SeqIter(self.vals)
seq = Seq(*([True, False] * 25))
self.assertEqual(filter(lambda x: not x, seq), [False]*25)
self.assertEqual(filter(lambda x: not x, iter(seq)), [False]*25)
run_unittest(TestCase) run_unittest(TestCase)

View file

@ -17,6 +17,7 @@ Core
- The following functions were generalized to work nicely with iterator - The following functions were generalized to work nicely with iterator
arguments: arguments:
filter()
list() list()

View file

@ -162,53 +162,65 @@ Note that classes are callable, as are instances with a __call__() method.";
static PyObject * static PyObject *
builtin_filter(PyObject *self, PyObject *args) builtin_filter(PyObject *self, PyObject *args)
{ {
PyObject *func, *seq, *result; PyObject *func, *seq, *result, *it;
PySequenceMethods *sqf; int len; /* guess for result list size */
int len;
register int i, j; register int i, j;
if (!PyArg_ParseTuple(args, "OO:filter", &func, &seq)) if (!PyArg_ParseTuple(args, "OO:filter", &func, &seq))
return NULL; return NULL;
if (PyString_Check(seq)) { /* Strings and tuples return a result of the same type. */
PyObject *r = filterstring(func, seq); if (PyString_Check(seq))
return r; return filterstring(func, seq);
if (PyTuple_Check(seq))
return filtertuple(func, seq);
/* Get iterator. */
it = PyObject_GetIter(seq);
if (it == NULL)
return NULL;
/* Guess a result list size. */
len = -1; /* unknown */
if (PySequence_Check(seq) &&
seq->ob_type->tp_as_sequence->sq_length) {
len = PySequence_Size(seq);
if (len < 0)
PyErr_Clear();
} }
if (len < 0)
len = 8; /* arbitrary */
if (PyTuple_Check(seq)) { /* Get a result list. */
PyObject *r = filtertuple(func, seq);
return r;
}
sqf = seq->ob_type->tp_as_sequence;
if (sqf == NULL || sqf->sq_length == NULL || sqf->sq_item == NULL) {
PyErr_SetString(PyExc_TypeError,
"filter() arg 2 must be a sequence");
goto Fail_2;
}
if ((len = (*sqf->sq_length)(seq)) < 0)
goto Fail_2;
if (PyList_Check(seq) && seq->ob_refcnt == 1) { if (PyList_Check(seq) && seq->ob_refcnt == 1) {
/* Eww - can modify the list in-place. */
Py_INCREF(seq); Py_INCREF(seq);
result = seq; result = seq;
} }
else { else {
if ((result = PyList_New(len)) == NULL) result = PyList_New(len);
goto Fail_2; if (result == NULL)
goto Fail_it;
} }
/* Build the result list. */
for (i = j = 0; ; ++i) { for (i = j = 0; ; ++i) {
PyObject *item, *good; PyObject *item, *good;
int ok; int ok;
if ((item = (*sqf->sq_item)(seq, i)) == NULL) { item = PyIter_Next(it);
if (PyErr_ExceptionMatches(PyExc_IndexError)) { if (item == NULL) {
/* We're out of here in any case, but if this is a
* StopIteration exception it's expected, but if
* any other kind of exception it's an error.
*/
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_StopIteration))
PyErr_Clear(); PyErr_Clear();
break; else
goto Fail_result_it;
} }
goto Fail_1; break;
} }
if (func == Py_None) { if (func == Py_None) {
@ -217,43 +229,45 @@ builtin_filter(PyObject *self, PyObject *args)
} }
else { else {
PyObject *arg = Py_BuildValue("(O)", item); PyObject *arg = Py_BuildValue("(O)", item);
if (arg == NULL) if (arg == NULL) {
goto Fail_1; Py_DECREF(item);
goto Fail_result_it;
}
good = PyEval_CallObject(func, arg); good = PyEval_CallObject(func, arg);
Py_DECREF(arg); Py_DECREF(arg);
if (good == NULL) { if (good == NULL) {
Py_DECREF(item); Py_DECREF(item);
goto Fail_1; goto Fail_result_it;
} }
} }
ok = PyObject_IsTrue(good); ok = PyObject_IsTrue(good);
Py_DECREF(good); Py_DECREF(good);
if (ok) { if (ok) {
if (j < len) { if (j < len)
if (PyList_SetItem(result, j++, item) < 0) PyList_SET_ITEM(result, j, item);
goto Fail_1;
}
else { else {
int status = PyList_Append(result, item); int status = PyList_Append(result, item);
j++;
Py_DECREF(item); Py_DECREF(item);
if (status < 0) if (status < 0)
goto Fail_1; goto Fail_result_it;
} }
} else { ++j;
}
else
Py_DECREF(item); Py_DECREF(item);
} }
}
/* Cut back result list if len is too big. */
if (j < len && PyList_SetSlice(result, j, len, NULL) < 0) if (j < len && PyList_SetSlice(result, j, len, NULL) < 0)
goto Fail_1; goto Fail_result_it;
return result; return result;
Fail_1: Fail_result_it:
Py_DECREF(result); Py_DECREF(result);
Fail_2: Fail_it:
Py_DECREF(it);
return NULL; return NULL;
} }