Guard _fileio.c against other malicious os.close(f.fileno()) attempts.
Add tests to test_fileio.py to verify behaviour.
This commit is contained in:
Kristján Valur Jónsson 2009-03-24 15:27:42 +00:00
parent 649170bd66
commit a8abe86331
2 changed files with 134 additions and 30 deletions

View file

@ -6,6 +6,7 @@ import errno
import unittest import unittest
from array import array from array import array
from weakref import proxy from weakref import proxy
from functools import wraps
from test.support import (TESTFN, findfile, check_warnings, run_unittest, from test.support import (TESTFN, findfile, check_warnings, run_unittest,
make_bad_fd) make_bad_fd)
@ -114,20 +115,106 @@ class AutoFileTests(unittest.TestCase):
else: else:
self.fail("Should have raised IOError") self.fail("Should have raised IOError")
def testErrnoOnClose(self): #A set of functions testing that we get expected behaviour if someone has
# Test that the IOError's `errno` attribute is correctly set when #manually closed the internal file descriptor. First, a decorator:
# close() fails. Here we first close the file descriptor ourselves so def ClosedFD(func):
# that close() fails with EBADF ('Bad file descriptor'). @wraps(func)
def wrapper(self):
#forcibly close the fd before invoking the problem function
f = self.f f = self.f
os.close(f.fileno()) os.close(f.fileno())
self.f = None
try: try:
f.close() func(self, f)
finally:
try:
self.f.close()
except IOError:
pass
return wrapper
def ClosedFDRaises(func):
@wraps(func)
def wrapper(self):
#forcibly close the fd before invoking the problem function
f = self.f
os.close(f.fileno())
try:
func(self, f)
except IOError as e: except IOError as e:
self.assertEqual(e.errno, errno.EBADF) self.assertEqual(e.errno, errno.EBADF)
else: else:
self.fail("Should have raised IOError") self.fail("Should have raised IOError")
finally:
try:
self.f.close()
except IOError:
pass
return wrapper
@ClosedFDRaises
def testErrnoOnClose(self, f):
f.close()
@ClosedFDRaises
def testErrnoOnClosedWrite(self, f):
f.write('a')
@ClosedFDRaises
def testErrnoOnClosedSeek(self, f):
f.seek(0)
@ClosedFDRaises
def testErrnoOnClosedTell(self, f):
f.tell()
@ClosedFDRaises
def testErrnoOnClosedTruncate(self, f):
f.truncate(0)
@ClosedFD
def testErrnoOnClosedSeekable(self, f):
f.seekable()
@ClosedFD
def testErrnoOnClosedReadable(self, f):
f.readable()
@ClosedFD
def testErrnoOnClosedWritable(self, f):
f.writable()
@ClosedFD
def testErrnoOnClosedFileno(self, f):
f.fileno()
@ClosedFD
def testErrnoOnClosedIsatty(self, f):
self.assertEqual(f.isatty(), False)
def ReopenForRead(self):
try:
self.f.close()
except IOError:
pass
self.f = _FileIO(TESTFN, 'r')
os.close(self.f.fileno())
return self.f
@ClosedFDRaises
def testErrnoOnClosedRead(self, f):
f = self.ReopenForRead()
f.read(1)
@ClosedFDRaises
def testErrnoOnClosedReadall(self, f):
f = self.ReopenForRead()
f.readall()
@ClosedFDRaises
def testErrnoOnClosedReadinto(self, f):
f = self.ReopenForRead()
a = array('b', b'x'*10)
f.readinto(a)
class OtherFileTests(unittest.TestCase): class OtherFileTests(unittest.TestCase):

View file

@ -475,10 +475,13 @@ fileio_readinto(PyFileIOObject *self, PyObject *args)
if (!PyArg_ParseTuple(args, "w*", &pbuf)) if (!PyArg_ParseTuple(args, "w*", &pbuf))
return NULL; return NULL;
if (_PyVerify_fd(self->fd)) {
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
errno = 0; errno = 0;
n = read(self->fd, pbuf.buf, pbuf.len); n = read(self->fd, pbuf.buf, pbuf.len);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
} else
n = -1;
PyBuffer_Release(&pbuf); PyBuffer_Release(&pbuf);
if (n < 0) { if (n < 0) {
if (errno == EAGAIN) if (errno == EAGAIN)
@ -522,6 +525,9 @@ fileio_readall(PyFileIOObject *self)
Py_ssize_t total = 0; Py_ssize_t total = 0;
int n; int n;
if (!_PyVerify_fd(self->fd))
return PyErr_SetFromErrno(PyExc_IOError);
result = PyBytes_FromStringAndSize(NULL, SMALLCHUNK); result = PyBytes_FromStringAndSize(NULL, SMALLCHUNK);
if (result == NULL) if (result == NULL)
return NULL; return NULL;
@ -596,10 +602,13 @@ fileio_read(PyFileIOObject *self, PyObject *args)
return NULL; return NULL;
ptr = PyBytes_AS_STRING(bytes); ptr = PyBytes_AS_STRING(bytes);
if (_PyVerify_fd(self->fd)) {
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
errno = 0; errno = 0;
n = read(self->fd, ptr, size); n = read(self->fd, ptr, size);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
} else
n = -1;
if (n < 0) { if (n < 0) {
if (errno == EAGAIN) if (errno == EAGAIN)
@ -632,10 +641,13 @@ fileio_write(PyFileIOObject *self, PyObject *args)
if (!PyArg_ParseTuple(args, "s*", &pbuf)) if (!PyArg_ParseTuple(args, "s*", &pbuf))
return NULL; return NULL;
if (_PyVerify_fd(self->fd)) {
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
errno = 0; errno = 0;
n = write(self->fd, pbuf.buf, pbuf.len); n = write(self->fd, pbuf.buf, pbuf.len);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
} else
n = -1;
PyBuffer_Release(&pbuf); PyBuffer_Release(&pbuf);
@ -688,6 +700,7 @@ portable_lseek(int fd, PyObject *posobj, int whence)
return NULL; return NULL;
} }
if (_PyVerify_fd(fd)) {
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
#if defined(MS_WIN64) || defined(MS_WINDOWS) #if defined(MS_WIN64) || defined(MS_WINDOWS)
res = _lseeki64(fd, pos, whence); res = _lseeki64(fd, pos, whence);
@ -695,6 +708,8 @@ portable_lseek(int fd, PyObject *posobj, int whence)
res = lseek(fd, pos, whence); res = lseek(fd, pos, whence);
#endif #endif
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
} else
res = -1;
if (res < 0) if (res < 0)
return PyErr_SetFromErrno(PyExc_IOError); return PyErr_SetFromErrno(PyExc_IOError);
@ -757,13 +772,15 @@ fileio_truncate(PyFileIOObject *self, PyObject *args)
/* Move to the position to be truncated. */ /* Move to the position to be truncated. */
posobj = portable_lseek(fd, posobj, 0); posobj = portable_lseek(fd, posobj, 0);
} }
if (posobj == NULL)
return NULL;
#if defined(HAVE_LARGEFILE_SUPPORT) #if defined(HAVE_LARGEFILE_SUPPORT)
pos = PyLong_AsLongLong(posobj); pos = PyLong_AsLongLong(posobj);
#else #else
pos = PyLong_AsLong(posobj); pos = PyLong_AsLong(posobj);
#endif #endif
if (PyErr_Occurred()) if (pos == -1 && PyErr_Occurred())
return NULL; return NULL;
#ifdef MS_WINDOWS #ifdef MS_WINDOWS