#4759: allow None as first argument of bytearray.translate(), for consistency with bytes.translate().

Also fix segfault for bytearray.translate(x, None) -- will backport this part to 3.0 and 2.6.
This commit is contained in:
Georg Brandl 2008-12-28 11:44:14 +00:00
parent 15fafbe6f2
commit ccc47b6eee
3 changed files with 38 additions and 16 deletions

View file

@ -888,11 +888,21 @@ class AssortedBytesTest(unittest.TestCase):
def test_translate(self): def test_translate(self):
b = b'hello' b = b'hello'
ba = bytearray(b)
rosetta = bytearray(range(0, 256)) rosetta = bytearray(range(0, 256))
rosetta[ord('o')] = ord('e') rosetta[ord('o')] = ord('e')
c = b.translate(rosetta, b'l') c = b.translate(rosetta, b'l')
self.assertEqual(b, b'hello') self.assertEqual(b, b'hello')
self.assertEqual(c, b'hee') self.assertEqual(c, b'hee')
c = ba.translate(rosetta, b'l')
self.assertEqual(ba, b'hello')
self.assertEqual(c, b'hee')
c = b.translate(None, b'e')
self.assertEqual(c, b'hllo')
c = ba.translate(None, b'e')
self.assertEqual(c, b'hllo')
self.assertRaises(TypeError, b.translate, None, None)
self.assertRaises(TypeError, ba.translate, None, None)
def test_split_bytearray(self): def test_split_bytearray(self):
self.assertEqual(b'a b'.split(memoryview(b' ')), [b'a', b'b']) self.assertEqual(b'a b'.split(memoryview(b' ')), [b'a', b'b'])

View file

@ -12,6 +12,9 @@ What's New in Python 3.1 alpha 0
Core and Builtins Core and Builtins
----------------- -----------------
- Issue #4759: None is now allowed as the first argument of
bytearray.translate(). It was always allowed for bytes.translate().
- Added test case to ensure attempts to read from a file opened for writing - Added test case to ensure attempts to read from a file opened for writing
fail. fail.

View file

@ -1371,28 +1371,32 @@ bytes_translate(PyByteArrayObject *self, PyObject *args)
PyObject *input_obj = (PyObject*)self; PyObject *input_obj = (PyObject*)self;
const char *output_start; const char *output_start;
Py_ssize_t inlen; Py_ssize_t inlen;
PyObject *result; PyObject *result = NULL;
int trans_table[256]; int trans_table[256];
PyObject *tableobj, *delobj = NULL; PyObject *tableobj = NULL, *delobj = NULL;
Py_buffer vtable, vdel; Py_buffer vtable, vdel;
if (!PyArg_UnpackTuple(args, "translate", 1, 2, if (!PyArg_UnpackTuple(args, "translate", 1, 2,
&tableobj, &delobj)) &tableobj, &delobj))
return NULL; return NULL;
if (_getbuffer(tableobj, &vtable) < 0) if (tableobj == Py_None) {
table = NULL;
tableobj = NULL;
} else if (_getbuffer(tableobj, &vtable) < 0) {
return NULL; return NULL;
} else {
if (vtable.len != 256) { if (vtable.len != 256) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"translation table must be 256 characters long"); "translation table must be 256 characters long");
result = NULL; goto done;
goto done; }
table = (const char*)vtable.buf;
} }
if (delobj != NULL) { if (delobj != NULL) {
if (_getbuffer(delobj, &vdel) < 0) { if (_getbuffer(delobj, &vdel) < 0) {
result = NULL; delobj = NULL; /* don't try to release vdel buffer on exit */
goto done; goto done;
} }
} }
@ -1401,7 +1405,6 @@ bytes_translate(PyByteArrayObject *self, PyObject *args)
vdel.len = 0; vdel.len = 0;
} }
table = (const char *)vtable.buf;
inlen = PyByteArray_GET_SIZE(input_obj); inlen = PyByteArray_GET_SIZE(input_obj);
result = PyByteArray_FromStringAndSize((char *)NULL, inlen); result = PyByteArray_FromStringAndSize((char *)NULL, inlen);
if (result == NULL) if (result == NULL)
@ -1409,7 +1412,7 @@ bytes_translate(PyByteArrayObject *self, PyObject *args)
output_start = output = PyByteArray_AsString(result); output_start = output = PyByteArray_AsString(result);
input = PyByteArray_AS_STRING(input_obj); input = PyByteArray_AS_STRING(input_obj);
if (vdel.len == 0) { if (vdel.len == 0 && table != NULL) {
/* If no deletions are required, use faster code */ /* If no deletions are required, use faster code */
for (i = inlen; --i >= 0; ) { for (i = inlen; --i >= 0; ) {
c = Py_CHARMASK(*input++); c = Py_CHARMASK(*input++);
@ -1418,8 +1421,13 @@ bytes_translate(PyByteArrayObject *self, PyObject *args)
goto done; goto done;
} }
for (i = 0; i < 256; i++) if (table == NULL) {
trans_table[i] = Py_CHARMASK(table[i]); for (i = 0; i < 256; i++)
trans_table[i] = Py_CHARMASK(i);
} else {
for (i = 0; i < 256; i++)
trans_table[i] = Py_CHARMASK(table[i]);
}
for (i = 0; i < vdel.len; i++) for (i = 0; i < vdel.len; i++)
trans_table[(int) Py_CHARMASK( ((unsigned char*)vdel.buf)[i] )] = -1; trans_table[(int) Py_CHARMASK( ((unsigned char*)vdel.buf)[i] )] = -1;
@ -1435,7 +1443,8 @@ bytes_translate(PyByteArrayObject *self, PyObject *args)
PyByteArray_Resize(result, output - output_start); PyByteArray_Resize(result, output - output_start);
done: done:
PyBuffer_Release(&vtable); if (tableobj != NULL)
PyBuffer_Release(&vtable);
if (delobj != NULL) if (delobj != NULL)
PyBuffer_Release(&vdel); PyBuffer_Release(&vdel);
return result; return result;