bpo-45126: Harden sqlite3 connection initialisation (GH-28227)

This commit is contained in:
Erlend Egeberg Aasland 2021-11-16 15:53:35 +01:00 committed by GitHub
parent 6a84d61c55
commit 9d6215a54c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 62 deletions

View file

@ -523,6 +523,44 @@ class ConnectionTests(unittest.TestCase):
with memory_database(isolation_level=level) as cx: with memory_database(isolation_level=level) as cx:
cx.execute("select 'ok'") cx.execute("select 'ok'")
def test_connection_reinit(self):
db = ":memory:"
cx = sqlite.connect(db)
cx.text_factory = bytes
cx.row_factory = sqlite.Row
cu = cx.cursor()
cu.execute("create table foo (bar)")
cu.executemany("insert into foo (bar) values (?)",
((str(v),) for v in range(4)))
cu.execute("select bar from foo")
rows = [r for r in cu.fetchmany(2)]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
cx.__init__(db)
cx.execute("create table foo (bar)")
cx.executemany("insert into foo (bar) values (?)",
((v,) for v in ("a", "b", "c", "d")))
# This uses the old database, old row factory, but new text factory
rows = [r for r in cu.fetchall()]
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])
def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:")
with cx:
cx.execute("create table t(t)")
with temp_dir() as db:
self.assertRaisesRegex(sqlite.OperationalError,
"unable to open database file",
cx.__init__, db)
self.assertRaisesRegex(sqlite.ProgrammingError,
"Base Connection.__init__ not called",
cx.executemany, "insert into t values(?)",
((v,) for v in range(3)))
class UninitialisedConnectionTests(unittest.TestCase): class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self): def setUp(self):

View file

@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout, const char *database, double timeout,
int detect_types, const char *isolation_level, int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory, int check_same_thread, PyObject *factory,
int cached_statements, int uri); int cache_size, int uri);
static int static int
pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs) pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
const char *isolation_level = ""; const char *isolation_level = "";
int check_same_thread = 1; int check_same_thread = 1;
PyObject *factory = (PyObject*)clinic_state()->ConnectionType; PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
int cached_statements = 128; int cache_size = 128;
int uri = 0; int uri = 0;
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf); fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
} }
} }
if (fastargs[6]) { if (fastargs[6]) {
cached_statements = _PyLong_AsInt(fastargs[6]); cache_size = _PyLong_AsInt(fastargs[6]);
if (cached_statements == -1 && PyErr_Occurred()) { if (cache_size == -1 && PyErr_Occurred()) {
goto exit; goto exit;
} }
if (!--noptargs) { if (!--noptargs) {
@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
goto exit; goto exit;
} }
skip_optional_pos: skip_optional_pos:
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri); return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);
exit: exit:
/* Cleanup for database */ /* Cleanup for database */
@ -851,4 +851,4 @@ exit:
#ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF #ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF #define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */ #endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/ /*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/

View file

@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
static void free_callback_context(callback_context *ctx); static void free_callback_context(callback_context *ctx);
static void set_callback_context(callback_context **ctx_pp, static void set_callback_context(callback_context **ctx_pp,
callback_context *ctx); callback_context *ctx);
static void connection_close(pysqlite_Connection *self);
static PyObject * static PyObject *
new_statement_cache(pysqlite_Connection *self, int maxsize) new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
int maxsize)
{ {
PyObject *args[] = { NULL, PyLong_FromLong(maxsize), }; PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
if (args[1] == NULL) { if (args[1] == NULL) {
return NULL; return NULL;
} }
PyObject *lru_cache = self->state->lru_cache; PyObject *lru_cache = state->lru_cache;
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET; size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL); PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
Py_DECREF(args[1]); Py_DECREF(args[1]);
@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
isolation_level: str(accept={str, NoneType}) = "" isolation_level: str(accept={str, NoneType}) = ""
check_same_thread: bool(accept={int}) = True check_same_thread: bool(accept={int}) = True
factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
cached_statements: int = 128 cached_statements as cache_size: int = 128
uri: bool = False uri: bool = False
[clinic start generated code]*/ [clinic start generated code]*/
@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout, const char *database, double timeout,
int detect_types, const char *isolation_level, int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory, int check_same_thread, PyObject *factory,
int cached_statements, int uri) int cache_size, int uri)
/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/ /*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
{ {
int rc;
if (PySys_Audit("sqlite3.connect", "s", database) < 0) { if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
return -1; return -1;
} }
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self)); if (self->initialized) {
self->state = state; PyTypeObject *tp = Py_TYPE(self);
tp->tp_clear((PyObject *)self);
Py_CLEAR(self->statement_cache); connection_close(self);
Py_CLEAR(self->cursors); self->initialized = 0;
}
Py_INCREF(Py_None);
Py_XSETREF(self->row_factory, Py_None);
Py_INCREF(&PyUnicode_Type);
Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
// Create and configure SQLite database object.
sqlite3 *db;
int rc;
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
rc = sqlite3_open_v2(database, &self->db, rc = sqlite3_open_v2(database, &db,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
(uri ? SQLITE_OPEN_URI : 0), NULL); (uri ? SQLITE_OPEN_URI : 0), NULL);
if (rc == SQLITE_OK) {
(void)sqlite3_busy_timeout(db, (int)(timeout*1000));
}
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
if (self->db == NULL && rc == SQLITE_NOMEM) { if (db == NULL && rc == SQLITE_NOMEM) {
PyErr_NoMemory(); PyErr_NoMemory();
return -1; return -1;
} }
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
_pysqlite_seterror(state, self->db); _pysqlite_seterror(state, db);
return -1; return -1;
} }
if (isolation_level) { // Convert isolation level to begin statement.
const char *stmt = get_begin_statement(isolation_level); const char *begin_statement = NULL;
if (stmt == NULL) { if (isolation_level != NULL) {
begin_statement = get_begin_statement(isolation_level);
if (begin_statement == NULL) {
return -1; return -1;
} }
self->begin_statement = stmt;
}
else {
self->begin_statement = NULL;
} }
self->statement_cache = new_statement_cache(self, cached_statements); // Create LRU statement cache; returns a new reference.
if (self->statement_cache == NULL) { PyObject *statement_cache = new_statement_cache(self, state, cache_size);
return -1; if (statement_cache == NULL) {
}
if (PyErr_Occurred()) {
return -1; return -1;
} }
self->created_cursors = 0; // Create list of weak references to cursors.
PyObject *cursors = PyList_New(0);
/* Create list of weak references to cursors */ if (cursors == NULL) {
self->cursors = PyList_New(0); Py_DECREF(statement_cache);
if (self->cursors == NULL) {
return -1; return -1;
} }
// Init connection state members.
self->db = db;
self->state = state;
self->detect_types = detect_types; self->detect_types = detect_types;
(void)sqlite3_busy_timeout(self->db, (int)(timeout*1000)); self->begin_statement = begin_statement;
self->thread_ident = PyThread_get_thread_ident();
self->check_same_thread = check_same_thread; self->check_same_thread = check_same_thread;
self->thread_ident = PyThread_get_thread_ident();
self->statement_cache = statement_cache;
self->cursors = cursors;
self->created_cursors = 0;
self->row_factory = Py_NewRef(Py_None);
self->text_factory = Py_NewRef(&PyUnicode_Type);
self->trace_ctx = NULL;
self->progress_ctx = NULL;
self->authorizer_ctx = NULL;
set_callback_context(&self->trace_ctx, NULL); // Borrowed refs
set_callback_context(&self->progress_ctx, NULL);
set_callback_context(&self->authorizer_ctx, NULL);
self->Warning = state->Warning; self->Warning = state->Warning;
self->Error = state->Error; self->Error = state->Error;
self->InterfaceError = state->InterfaceError; self->InterfaceError = state->InterfaceError;
@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
} }
self->initialized = 1; self->initialized = 1;
return 0; return 0;
} }
@ -321,16 +326,6 @@ connection_clear(pysqlite_Connection *self)
return 0; return 0;
} }
static void
connection_close(pysqlite_Connection *self)
{
if (self->db) {
int rc = sqlite3_close_v2(self->db);
assert(rc == SQLITE_OK), (void)rc;
self->db = NULL;
}
}
static void static void
free_callback_contexts(pysqlite_Connection *self) free_callback_contexts(pysqlite_Connection *self)
{ {
@ -339,6 +334,22 @@ free_callback_contexts(pysqlite_Connection *self)
set_callback_context(&self->authorizer_ctx, NULL); set_callback_context(&self->authorizer_ctx, NULL);
} }
static void
connection_close(pysqlite_Connection *self)
{
if (self->db) {
free_callback_contexts(self);
sqlite3 *db = self->db;
self->db = NULL;
Py_BEGIN_ALLOW_THREADS
int rc = sqlite3_close_v2(db);
assert(rc == SQLITE_OK), (void)rc;
Py_END_ALLOW_THREADS
}
}
static void static void
connection_dealloc(pysqlite_Connection *self) connection_dealloc(pysqlite_Connection *self)
{ {
@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)
/* Clean up if user has not called .close() explicitly. */ /* Clean up if user has not called .close() explicitly. */
connection_close(self); connection_close(self);
free_callback_contexts(self);
tp->tp_free(self); tp->tp_free(self);
Py_DECREF(tp); Py_DECREF(tp);