mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-45126: Harden sqlite3
connection initialisation (GH-28227)
This commit is contained in:
parent
6a84d61c55
commit
9d6215a54c
3 changed files with 110 additions and 62 deletions
|
@ -523,6 +523,44 @@ class ConnectionTests(unittest.TestCase):
|
|||
with memory_database(isolation_level=level) as cx:
|
||||
cx.execute("select 'ok'")
|
||||
|
||||
def test_connection_reinit(self):
|
||||
db = ":memory:"
|
||||
cx = sqlite.connect(db)
|
||||
cx.text_factory = bytes
|
||||
cx.row_factory = sqlite.Row
|
||||
cu = cx.cursor()
|
||||
cu.execute("create table foo (bar)")
|
||||
cu.executemany("insert into foo (bar) values (?)",
|
||||
((str(v),) for v in range(4)))
|
||||
cu.execute("select bar from foo")
|
||||
|
||||
rows = [r for r in cu.fetchmany(2)]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], [b"0", b"1"])
|
||||
|
||||
cx.__init__(db)
|
||||
cx.execute("create table foo (bar)")
|
||||
cx.executemany("insert into foo (bar) values (?)",
|
||||
((v,) for v in ("a", "b", "c", "d")))
|
||||
|
||||
# This uses the old database, old row factory, but new text factory
|
||||
rows = [r for r in cu.fetchall()]
|
||||
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
|
||||
self.assertEqual([r[0] for r in rows], ["2", "3"])
|
||||
|
||||
def test_connection_bad_reinit(self):
|
||||
cx = sqlite.connect(":memory:")
|
||||
with cx:
|
||||
cx.execute("create table t(t)")
|
||||
with temp_dir() as db:
|
||||
self.assertRaisesRegex(sqlite.OperationalError,
|
||||
"unable to open database file",
|
||||
cx.__init__, db)
|
||||
self.assertRaisesRegex(sqlite.ProgrammingError,
|
||||
"Base Connection.__init__ not called",
|
||||
cx.executemany, "insert into t values(?)",
|
||||
((v,) for v in range(3)))
|
||||
|
||||
|
||||
class UninitialisedConnectionTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
|
@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
|
|||
const char *database, double timeout,
|
||||
int detect_types, const char *isolation_level,
|
||||
int check_same_thread, PyObject *factory,
|
||||
int cached_statements, int uri);
|
||||
int cache_size, int uri);
|
||||
|
||||
static int
|
||||
pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
|
||||
|
@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
|
|||
const char *isolation_level = "";
|
||||
int check_same_thread = 1;
|
||||
PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
|
||||
int cached_statements = 128;
|
||||
int cache_size = 128;
|
||||
int uri = 0;
|
||||
|
||||
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
|
||||
|
@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
|
|||
}
|
||||
}
|
||||
if (fastargs[6]) {
|
||||
cached_statements = _PyLong_AsInt(fastargs[6]);
|
||||
if (cached_statements == -1 && PyErr_Occurred()) {
|
||||
cache_size = _PyLong_AsInt(fastargs[6]);
|
||||
if (cache_size == -1 && PyErr_Occurred()) {
|
||||
goto exit;
|
||||
}
|
||||
if (!--noptargs) {
|
||||
|
@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
|
|||
goto exit;
|
||||
}
|
||||
skip_optional_pos:
|
||||
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri);
|
||||
return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);
|
||||
|
||||
exit:
|
||||
/* Cleanup for database */
|
||||
|
@ -851,4 +851,4 @@ exit:
|
|||
#ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
|
||||
#define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
|
||||
#endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
|
||||
/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/
|
||||
/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/
|
||||
|
|
|
@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
|
|||
static void free_callback_context(callback_context *ctx);
|
||||
static void set_callback_context(callback_context **ctx_pp,
|
||||
callback_context *ctx);
|
||||
static void connection_close(pysqlite_Connection *self);
|
||||
|
||||
static PyObject *
|
||||
new_statement_cache(pysqlite_Connection *self, int maxsize)
|
||||
new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
|
||||
int maxsize)
|
||||
{
|
||||
PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
|
||||
if (args[1] == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *lru_cache = self->state->lru_cache;
|
||||
PyObject *lru_cache = state->lru_cache;
|
||||
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
|
||||
PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
|
||||
Py_DECREF(args[1]);
|
||||
|
@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
|
|||
isolation_level: str(accept={str, NoneType}) = ""
|
||||
check_same_thread: bool(accept={int}) = True
|
||||
factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
|
||||
cached_statements: int = 128
|
||||
cached_statements as cache_size: int = 128
|
||||
uri: bool = False
|
||||
[clinic start generated code]*/
|
||||
|
||||
|
@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
|
|||
const char *database, double timeout,
|
||||
int detect_types, const char *isolation_level,
|
||||
int check_same_thread, PyObject *factory,
|
||||
int cached_statements, int uri)
|
||||
/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/
|
||||
int cache_size, int uri)
|
||||
/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
|
||||
{
|
||||
int rc;
|
||||
|
||||
if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
|
||||
self->state = state;
|
||||
|
||||
Py_CLEAR(self->statement_cache);
|
||||
Py_CLEAR(self->cursors);
|
||||
|
||||
Py_INCREF(Py_None);
|
||||
Py_XSETREF(self->row_factory, Py_None);
|
||||
|
||||
Py_INCREF(&PyUnicode_Type);
|
||||
Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
|
||||
if (self->initialized) {
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
tp->tp_clear((PyObject *)self);
|
||||
connection_close(self);
|
||||
self->initialized = 0;
|
||||
}
|
||||
|
||||
// Create and configure SQLite database object.
|
||||
sqlite3 *db;
|
||||
int rc;
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
rc = sqlite3_open_v2(database, &self->db,
|
||||
rc = sqlite3_open_v2(database, &db,
|
||||
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
|
||||
(uri ? SQLITE_OPEN_URI : 0), NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
(void)sqlite3_busy_timeout(db, (int)(timeout*1000));
|
||||
}
|
||||
Py_END_ALLOW_THREADS
|
||||
|
||||
if (self->db == NULL && rc == SQLITE_NOMEM) {
|
||||
if (db == NULL && rc == SQLITE_NOMEM) {
|
||||
PyErr_NoMemory();
|
||||
return -1;
|
||||
}
|
||||
|
||||
pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
|
||||
if (rc != SQLITE_OK) {
|
||||
_pysqlite_seterror(state, self->db);
|
||||
_pysqlite_seterror(state, db);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (isolation_level) {
|
||||
const char *stmt = get_begin_statement(isolation_level);
|
||||
if (stmt == NULL) {
|
||||
// Convert isolation level to begin statement.
|
||||
const char *begin_statement = NULL;
|
||||
if (isolation_level != NULL) {
|
||||
begin_statement = get_begin_statement(isolation_level);
|
||||
if (begin_statement == NULL) {
|
||||
return -1;
|
||||
}
|
||||
self->begin_statement = stmt;
|
||||
}
|
||||
else {
|
||||
self->begin_statement = NULL;
|
||||
}
|
||||
|
||||
self->statement_cache = new_statement_cache(self, cached_statements);
|
||||
if (self->statement_cache == NULL) {
|
||||
return -1;
|
||||
}
|
||||
if (PyErr_Occurred()) {
|
||||
// Create LRU statement cache; returns a new reference.
|
||||
PyObject *statement_cache = new_statement_cache(self, state, cache_size);
|
||||
if (statement_cache == NULL) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
self->created_cursors = 0;
|
||||
|
||||
/* Create list of weak references to cursors */
|
||||
self->cursors = PyList_New(0);
|
||||
if (self->cursors == NULL) {
|
||||
// Create list of weak references to cursors.
|
||||
PyObject *cursors = PyList_New(0);
|
||||
if (cursors == NULL) {
|
||||
Py_DECREF(statement_cache);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Init connection state members.
|
||||
self->db = db;
|
||||
self->state = state;
|
||||
self->detect_types = detect_types;
|
||||
(void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
|
||||
self->thread_ident = PyThread_get_thread_ident();
|
||||
self->begin_statement = begin_statement;
|
||||
self->check_same_thread = check_same_thread;
|
||||
self->thread_ident = PyThread_get_thread_ident();
|
||||
self->statement_cache = statement_cache;
|
||||
self->cursors = cursors;
|
||||
self->created_cursors = 0;
|
||||
self->row_factory = Py_NewRef(Py_None);
|
||||
self->text_factory = Py_NewRef(&PyUnicode_Type);
|
||||
self->trace_ctx = NULL;
|
||||
self->progress_ctx = NULL;
|
||||
self->authorizer_ctx = NULL;
|
||||
|
||||
set_callback_context(&self->trace_ctx, NULL);
|
||||
set_callback_context(&self->progress_ctx, NULL);
|
||||
set_callback_context(&self->authorizer_ctx, NULL);
|
||||
|
||||
// Borrowed refs
|
||||
self->Warning = state->Warning;
|
||||
self->Error = state->Error;
|
||||
self->InterfaceError = state->InterfaceError;
|
||||
|
@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
|
|||
}
|
||||
|
||||
self->initialized = 1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -321,16 +326,6 @@ connection_clear(pysqlite_Connection *self)
|
|||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
connection_close(pysqlite_Connection *self)
|
||||
{
|
||||
if (self->db) {
|
||||
int rc = sqlite3_close_v2(self->db);
|
||||
assert(rc == SQLITE_OK), (void)rc;
|
||||
self->db = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
free_callback_contexts(pysqlite_Connection *self)
|
||||
{
|
||||
|
@ -339,6 +334,22 @@ free_callback_contexts(pysqlite_Connection *self)
|
|||
set_callback_context(&self->authorizer_ctx, NULL);
|
||||
}
|
||||
|
||||
static void
|
||||
connection_close(pysqlite_Connection *self)
|
||||
{
|
||||
if (self->db) {
|
||||
free_callback_contexts(self);
|
||||
|
||||
sqlite3 *db = self->db;
|
||||
self->db = NULL;
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
int rc = sqlite3_close_v2(db);
|
||||
assert(rc == SQLITE_OK), (void)rc;
|
||||
Py_END_ALLOW_THREADS
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
connection_dealloc(pysqlite_Connection *self)
|
||||
{
|
||||
|
@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)
|
|||
|
||||
/* Clean up if user has not called .close() explicitly. */
|
||||
connection_close(self);
|
||||
free_callback_contexts(self);
|
||||
|
||||
tp->tp_free(self);
|
||||
Py_DECREF(tp);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue