[3.10] gh-89301: Fix regression with bound values in traced SQLite statements (#92147)

(cherry picked from commit 721aa96540)
This commit is contained in:
Erlend Egeberg Aasland 2022-05-02 10:21:13 -06:00 committed by GitHub
parent 6712022447
commit 178d79ae67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 17 deletions

View file

@ -20,6 +20,7 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import contextlib
import unittest
import sqlite3 as sqlite
@ -200,6 +201,16 @@ class ProgressTests(unittest.TestCase):
self.assertEqual(action, 0, "progress handler was not cleared")
class TraceCallbackTests(unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
traced = []
cx.set_trace_callback(lambda stmt: traced.append(stmt))
yield
finally:
self.assertEqual(traced, expected)
cx.set_trace_callback(None)
def test_trace_callback_used(self):
"""
Test that the trace callback is invoked once it is set.
@ -261,6 +272,21 @@ class TraceCallbackTests(unittest.TestCase):
cur.execute(queries[1])
self.assertEqual(traced_statements, queries)
def test_trace_expanded_sql(self):
expected = [
"create table t(t)",
"BEGIN ",
"insert into t values(0)",
"insert into t values(1)",
"insert into t values(2)",
"COMMIT",
]
cx = sqlite.connect(":memory:")
with self.check_stmt_trace(cx, expected):
with cx:
cx.execute("create table t(t)")
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
def suite():
tests = [

View file

@ -0,0 +1,3 @@
Fix a regression in the :mod:`sqlite3` trace callback where bound parameters
were not expanded in the passed statement string. The regression was introduced
in Python 3.10 by :issue:`40318`. Patch by Erlend E. Aasland.

View file

@ -1050,33 +1050,65 @@ static int _progress_handler(void* user_arg)
* may change in future releases. Callback implementations should return zero
* to ensure future compatibility.
*/
static int _trace_callback(unsigned int type, void* user_arg, void* prepared_statement, void* statement_string)
static int
_trace_callback(unsigned int type, void *callable, void *stmt, void *sql)
#else
static void _trace_callback(void* user_arg, const char* statement_string)
static void
_trace_callback(void *callable, const char *sql)
#endif
{
PyObject *py_statement = NULL;
PyObject *ret = NULL;
PyGILState_STATE gilstate;
#ifdef HAVE_TRACE_V2
if (type != SQLITE_TRACE_STMT) {
return 0;
}
#endif
gilstate = PyGILState_Ensure();
py_statement = PyUnicode_DecodeUTF8(statement_string,
strlen(statement_string), "replace");
if (py_statement) {
ret = PyObject_CallOneArg((PyObject*)user_arg, py_statement);
Py_DECREF(py_statement);
PyGILState_STATE gilstate = PyGILState_Ensure();
PyObject *py_statement = NULL;
#ifdef HAVE_TRACE_V2
const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
if (expanded_sql == NULL) {
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
(void)PyErr_NoMemory();
goto exit;
}
if (ret) {
Py_DECREF(ret);
PyErr_SetString(pysqlite_DataError,
"Expanded SQL string exceeds the maximum string length");
if (_pysqlite_enable_callback_tracebacks) {
PyErr_Print();
} else {
PyErr_Clear();
}
// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
}
else {
py_statement = PyUnicode_FromString(expanded_sql);
sqlite3_free((void *)expanded_sql);
}
#else
if (sql == NULL) {
PyErr_SetString(pysqlite_DataError,
"Expanded SQL string exceeds the maximum string length");
if (_pysqlite_enable_callback_tracebacks) {
PyErr_Print();
} else {
PyErr_Clear();
}
goto exit;
}
py_statement = PyUnicode_FromString(sql);
#endif
if (py_statement) {
PyObject *ret = PyObject_CallOneArg((PyObject *)callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
if (PyErr_Occurred()) {
if (_pysqlite_enable_callback_tracebacks) {
PyErr_Print();
} else {
@ -1084,6 +1116,7 @@ static void _trace_callback(void* user_arg, const char* statement_string)
}
}
exit:
PyGILState_Release(gilstate);
#ifdef HAVE_TRACE_V2
return 0;