[3.10] gh-79579: Improve DML query detection in sqlite3 (GH-93623) (#93801)

The fix uses pysqlite_check_remaining_sql() not only to check for
multiple statements, but also to strip leading comments and whitespace
from SQL statements, which improves DML query detection.

pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more
accurately reflect its function, and hardened to handle more SQL comment
corner cases.

(cherry picked from commit 46740073ef)
This commit is contained in:
Erlend Egeberg Aasland 2022-06-14 15:05:36 +02:00 committed by GitHub
parent f9585e2adc
commit 2229d34a6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 1939 additions and 75 deletions

View file

@ -29,16 +29,7 @@
#include "util.h"
/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);
typedef enum {
LINECOMMENT_1,
IN_LINECOMMENT,
COMMENTSTART_1,
IN_COMMENT,
COMMENTEND_1,
NORMAL
} parse_remaining_sql_state;
static const char *lstrip_sql(const char *sql);
typedef enum {
TYPE_LONG,
@ -55,7 +46,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
int rc;
const char* sql_cstr;
Py_ssize_t sql_cstr_len;
const char* p;
assert(PyUnicode_Check(sql));
@ -87,20 +77,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
for (p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}
const char *p = lstrip_sql(sql_cstr);
if (p != NULL) {
self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}
Py_BEGIN_ALLOW_THREADS
@ -118,7 +100,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
goto error;
}
if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) {
if (rc == SQLITE_OK && lstrip_sql(tail)) {
(void)sqlite3_finalize(self->st);
self->st = NULL;
PyErr_SetString(pysqlite_Warning,
@ -431,73 +413,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
}
/*
* Checks if there is anything left in an SQL string after SQLite compiled it.
* This is used to check if somebody tried to execute more than one SQL command
* with one execute()/executemany() command, which the DB-API and we don't
* allow.
* Strip leading whitespace and comments from incoming SQL (null-terminated C
* string) and return a pointer to the first non-whitespace, non-comment
* character, or NULL if the string contains nothing but whitespace and
* comments.
*
* Returns 1 if there is more left than should be. 0 if ok.
* This is used to check if somebody tries to execute more than one SQL query
* with one execute()/executemany() command, which the DB-API doesn't allow.
*
* It is also used to harden DML query detection.
*/
static int pysqlite_check_remaining_sql(const char* tail)
static inline const char *
lstrip_sql(const char *sql)
{
const char* pos = tail;
parse_remaining_sql_state state = NORMAL;
for (;;) {
// This loop is borrowed from the SQLite source code.
for (const char *pos = sql; *pos; pos++) {
switch (*pos) {
case 0:
return 0;
case '-':
if (state == NORMAL) {
state = LINECOMMENT_1;
} else if (state == LINECOMMENT_1) {
state = IN_LINECOMMENT;
}
break;
case ' ':
case '\t':
break;
case '\f':
case '\n':
case 13:
if (state == IN_LINECOMMENT) {
state = NORMAL;
}
case '\r':
// Skip whitespace.
break;
case '-':
// Skip line comments.
if (pos[1] == '-') {
pos += 2;
while (pos[0] && pos[0] != '\n') {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
continue;
}
return pos;
case '/':
if (state == NORMAL) {
state = COMMENTSTART_1;
} else if (state == COMMENTEND_1) {
state = NORMAL;
} else if (state == COMMENTSTART_1) {
return 1;
// Skip C style comments.
if (pos[1] == '*') {
pos += 2;
while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
pos++;
continue;
}
break;
case '*':
if (state == NORMAL) {
return 1;
} else if (state == LINECOMMENT_1) {
return 1;
} else if (state == COMMENTSTART_1) {
state = IN_COMMENT;
} else if (state == IN_COMMENT) {
state = COMMENTEND_1;
}
break;
return pos;
default:
if (state == COMMENTEND_1) {
state = IN_COMMENT;
} else if (state == IN_LINECOMMENT) {
} else if (state == IN_COMMENT) {
} else {
return 1;
}
return pos;
}
pos++;
}
return 0;
return NULL;
}
static PyMemberDef stmt_members[] = {