bpo-44859: Improve error handling in sqlite3 and raise more accurate exceptions. (GH-27654)

* MemoryError is now raised instead of sqlite3.Warning when
  memory is not enough for encoding a statement to UTF-8
  in Connection.__call__() and Cursor.execute().
* UnicodeEncodeError is now raised instead of sqlite3.Warning when
  the statement contains surrogate characters
  in Connection.__call__() and Cursor.execute().
* TypeError is now raised instead of ValueError for non-string
  script argument in Cursor.executescript().
* ValueError is now raised for script containing the null
  character instead of truncating it in Cursor.executescript().
* Correctly handle exceptions raised when getting boolean value
  of the result of the progress handler.
* Add many tests covering different corner cases.

Co-authored-by: Erlend Egeberg Aasland <erlend.aasland@innova.no>
This commit is contained in:
Serhiy Storchaka 2021-08-08 08:49:44 +03:00 committed by GitHub
parent ebecffdb6d
commit 0eec6276fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 226 additions and 52 deletions

View file

@ -26,7 +26,7 @@ import sys
import threading import threading
import unittest import unittest
from test.support import check_disallow_instantiation, threading_helper from test.support import check_disallow_instantiation, threading_helper, bigmemtest
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
@ -758,9 +758,35 @@ class ExtensionTests(unittest.TestCase):
def test_cursor_executescript_as_bytes(self): def test_cursor_executescript_as_bytes(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
cur = con.cursor() cur = con.cursor()
with self.assertRaises(ValueError) as cm: with self.assertRaises(TypeError):
cur.executescript(b"create table test(foo); insert into test(foo) values (5);") cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
self.assertEqual(str(cm.exception), 'script argument must be unicode.')
def test_cursor_executescript_with_null_characters(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(ValueError):
cur.executescript("""
create table a(i);\0
insert into a(i) values (5);
""")
def test_cursor_executescript_with_surrogates(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
with self.assertRaises(UnicodeEncodeError):
cur.executescript("""
create table a(s);
insert into a(s) values ('\ud8ff');
""")
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_cursor_executescript_too_large_script(self, maxsize):
con = sqlite.connect(":memory:")
cur = con.cursor()
for size in 2**31-1, 2**31:
with self.assertRaises(sqlite.DataError):
cur.executescript("create table a(s);".ljust(size))
def test_connection_execute(self): def test_connection_execute(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
@ -969,6 +995,7 @@ def suite():
CursorTests, CursorTests,
ExtensionTests, ExtensionTests,
ModuleTests, ModuleTests,
OpenTests,
SqliteOnConflictTests, SqliteOnConflictTests,
ThreadTests, ThreadTests,
UninitialisedConnectionTests, UninitialisedConnectionTests,

View file

@ -24,7 +24,7 @@ import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from .userfunctions import with_tracebacks
class CollationTests(unittest.TestCase): class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self): def test_create_collation_not_string(self):
@ -145,7 +145,6 @@ class ProgressTests(unittest.TestCase):
""") """)
self.assertTrue(progress_calls) self.assertTrue(progress_calls)
def test_opcode_count(self): def test_opcode_count(self):
""" """
Test that the opcode argument is respected. Test that the opcode argument is respected.
@ -198,6 +197,32 @@ class ProgressTests(unittest.TestCase):
con.execute("select 1 union select 2 union select 3").fetchall() con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared") self.assertEqual(action, 0, "progress handler was not cleared")
@with_tracebacks(['bad_progress', 'ZeroDivisionError'])
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")
@with_tracebacks(['__bool__', 'ZeroDivisionError'])
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
create table foo(a, b)
""")
class TraceCallbackTests(unittest.TestCase): class TraceCallbackTests(unittest.TestCase):
def test_trace_callback_used(self): def test_trace_callback_used(self):
""" """

View file

@ -21,6 +21,7 @@
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import datetime import datetime
import sys
import unittest import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
import weakref import weakref
@ -273,7 +274,7 @@ class RegressionTests(unittest.TestCase):
Call a connection with a non-string SQL request: check error handling Call a connection with a non-string SQL request: check error handling
of the statement constructor. of the statement constructor.
""" """
self.assertRaises(TypeError, self.con, 1) self.assertRaises(TypeError, self.con, b"select 1")
def test_collation(self): def test_collation(self):
def collation_cb(a, b): def collation_cb(a, b):
@ -344,6 +345,26 @@ class RegressionTests(unittest.TestCase):
self.assertRaises(ValueError, cur.execute, " \0select 2") self.assertRaises(ValueError, cur.execute, " \0select 2")
self.assertRaises(ValueError, cur.execute, "select 2\0") self.assertRaises(ValueError, cur.execute, "select 2\0")
def test_surrogates(self):
con = sqlite.connect(":memory:")
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
cur = con.cursor()
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'")
self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'")
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_large_sql(self, maxsize):
# Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX.
for size in (2**31, 2**31-2):
con = sqlite.connect(":memory:")
sql = "select 1".ljust(size)
self.assertRaises(sqlite.DataError, con, sql)
cur = con.cursor()
self.assertRaises(sqlite.DataError, cur.execute, sql)
del sql
def test_commit_cursor_reset(self): def test_commit_cursor_reset(self):
""" """
Connection.commit() did reset cursors, which made sqlite3 Connection.commit() did reset cursors, which made sqlite3

View file

@ -23,11 +23,14 @@
import datetime import datetime
import unittest import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
import sys
try: try:
import zlib import zlib
except ImportError: except ImportError:
zlib = None zlib = None
from test import support
class SqliteTypeTests(unittest.TestCase): class SqliteTypeTests(unittest.TestCase):
def setUp(self): def setUp(self):
@ -45,6 +48,12 @@ class SqliteTypeTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich") self.assertEqual(row[0], "Österreich")
def test_string_with_null_character(self):
self.cur.execute("insert into test(s) values (?)", ("a\0b",))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertEqual(row[0], "a\0b")
def test_small_int(self): def test_small_int(self):
self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
@ -52,7 +61,7 @@ class SqliteTypeTests(unittest.TestCase):
self.assertEqual(row[0], 42) self.assertEqual(row[0], 42)
def test_large_int(self): def test_large_int(self):
num = 2**40 num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
row = self.cur.fetchone() row = self.cur.fetchone()
@ -78,6 +87,45 @@ class SqliteTypeTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich") self.assertEqual(row[0], "Österreich")
def test_too_large_int(self):
for value in 2**63, -2**63-1, 2**64:
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(i) values (?)", (value,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertIsNone(row)
def test_string_with_surrogates(self):
for value in 0xd8ff, 0xdcff:
with self.assertRaises(UnicodeEncodeError):
self.cur.execute("insert into test(s) values (?)", (chr(value),))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertIsNone(row)
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=4, dry_run=False)
def test_too_large_string(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@support.bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_too_large_blob(self, maxsize):
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),))
with self.assertRaises(OverflowError):
self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),))
self.cur.execute("select 1 from test")
row = self.cur.fetchone()
self.assertIsNone(row)
class DeclTypesTests(unittest.TestCase): class DeclTypesTests(unittest.TestCase):
class Foo: class Foo:
def __init__(self, _val): def __init__(self, _val):
@ -163,7 +211,7 @@ class DeclTypesTests(unittest.TestCase):
def test_large_int(self): def test_large_int(self):
# default # default
num = 2**40 num = 123456789123456789
self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
row = self.cur.fetchone() row = self.cur.fetchone()

View file

@ -33,28 +33,37 @@ import sqlite3 as sqlite
from test.support import bigmemtest from test.support import bigmemtest
def with_tracebacks(strings): def with_tracebacks(strings, traceback=True):
"""Convenience decorator for testing callback tracebacks.""" """Convenience decorator for testing callback tracebacks."""
strings.append('Traceback') if traceback:
strings.append('Traceback')
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled. # First, run the test with traceback enabled.
sqlite.enable_callback_tracebacks(True) with check_tracebacks(self, strings):
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
func(self, *args, **kwargs) func(self, *args, **kwargs)
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
# Then run the test with traceback disabled. # Then run the test with traceback disabled.
sqlite.enable_callback_tracebacks(False)
func(self, *args, **kwargs) func(self, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@contextlib.contextmanager
def check_tracebacks(self, strings):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
finally:
sqlite.enable_callback_tracebacks(False)
def func_returntext(): def func_returntext():
return "foo" return "foo"
def func_returntextwithnull(): def func_returntextwithnull():
@ -408,9 +417,26 @@ class FunctionTests(unittest.TestCase):
del x,y del x,y
gc.collect() gc.collect()
def test_func_return_too_large_int(self):
cur = self.con.cursor()
for value in 2**63, -2**63-1, 2**64:
self.con.create_function("largeint", 0, lambda value=value: value)
with check_tracebacks(self, ['OverflowError']):
with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()")
def test_func_return_text_with_surrogates(self):
cur = self.con.cursor()
self.con.create_function("pychr", 1, chr)
for value in 0xd8ff, 0xdcff:
with check_tracebacks(self,
['UnicodeEncodeError', 'surrogates not allowed']):
with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,))
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False) @bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_large_text(self, size): def test_func_return_too_large_text(self, size):
cur = self.con.cursor() cur = self.con.cursor()
for size in 2**31-1, 2**31: for size in 2**31-1, 2**31:
self.con.create_function("largetext", 0, lambda size=size: "b" * size) self.con.create_function("largetext", 0, lambda size=size: "b" * size)
@ -419,7 +445,7 @@ class FunctionTests(unittest.TestCase):
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=2, dry_run=False) @bigmemtest(size=2**31, memuse=2, dry_run=False)
def test_large_blob(self, size): def test_func_return_too_large_blob(self, size):
cur = self.con.cursor() cur = self.con.cursor()
for size in 2**31-1, 2**31: for size in 2**31-1, 2**31:
self.con.create_function("largeblob", 0, lambda size=size: b"b" * size) self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)

View file

@ -0,0 +1,8 @@
Improve error handling in :mod:`sqlite3` and raise more accurate exceptions.
* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`UnicodeEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``.
* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``.
* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``.
* Correctly handle exceptions raised when getting boolean value of the result of the progress handler.
* Add many tests covering different corner cases.

View file

@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__,
#define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \ #define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF \
{"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__}, {"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__},
static PyObject *
pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
const char *sql_script);
static PyObject *
pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg)
{
PyObject *return_value = NULL;
const char *sql_script;
if (!PyUnicode_Check(arg)) {
_PyArg_BadArgument("executescript", "argument", "str", arg);
goto exit;
}
Py_ssize_t sql_script_length;
sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length);
if (sql_script == NULL) {
goto exit;
}
if (strlen(sql_script) != (size_t)sql_script_length) {
PyErr_SetString(PyExc_ValueError, "embedded null character");
goto exit;
}
return_value = pysqlite_cursor_executescript_impl(self, sql_script);
exit:
return return_value;
}
PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__, PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__,
"fetchone($self, /)\n" "fetchone($self, /)\n"
"--\n" "--\n"
@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const
exit: exit:
return return_value; return return_value;
} }
/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/ /*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/

View file

@ -997,6 +997,14 @@ static int _progress_handler(void* user_arg)
ret = _PyObject_CallNoArg((PyObject*)user_arg); ret = _PyObject_CallNoArg((PyObject*)user_arg);
if (!ret) { if (!ret) {
/* abort query if error occurred */
rc = -1;
}
else {
rc = PyObject_IsTrue(ret);
Py_DECREF(ret);
}
if (rc < 0) {
pysqlite_state *state = pysqlite_get_state(NULL); pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) { if (state->enable_callback_tracebacks) {
PyErr_Print(); PyErr_Print();
@ -1004,12 +1012,6 @@ static int _progress_handler(void* user_arg)
else { else {
PyErr_Clear(); PyErr_Clear();
} }
/* abort query if error occurred */
rc = 1;
} else {
rc = (int)PyObject_IsTrue(ret);
Py_DECREF(ret);
} }
PyGILState_Release(gilstate); PyGILState_Release(gilstate);

View file

@ -728,21 +728,21 @@ pysqlite_cursor_executemany_impl(pysqlite_Cursor *self, PyObject *sql,
/*[clinic input] /*[clinic input]
_sqlite3.Cursor.executescript as pysqlite_cursor_executescript _sqlite3.Cursor.executescript as pysqlite_cursor_executescript
sql_script as script_obj: object sql_script: str
/ /
Executes multiple SQL statements at once. Non-standard. Executes multiple SQL statements at once. Non-standard.
[clinic start generated code]*/ [clinic start generated code]*/
static PyObject * static PyObject *
pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
/*[clinic end generated code: output=115a8132b0f200fe input=ba3ec59df205e362]*/ const char *sql_script)
/*[clinic end generated code: output=8fd726dde1c65164 input=1ac0693dc8db02a8]*/
{ {
_Py_IDENTIFIER(commit); _Py_IDENTIFIER(commit);
const char* script_cstr;
sqlite3_stmt* statement; sqlite3_stmt* statement;
int rc; int rc;
Py_ssize_t sql_len; size_t sql_len;
PyObject* result; PyObject* result;
if (!check_cursor(self)) { if (!check_cursor(self)) {
@ -751,21 +751,12 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
self->reset = 0; self->reset = 0;
if (PyUnicode_Check(script_obj)) { sql_len = strlen(sql_script);
script_cstr = PyUnicode_AsUTF8AndSize(script_obj, &sql_len); int max_length = sqlite3_limit(self->connection->db,
if (!script_cstr) { SQLITE_LIMIT_LENGTH, -1);
return NULL; if (sql_len >= (unsigned)max_length) {
} PyErr_SetString(self->connection->DataError,
"query string is too large");
int max_length = sqlite3_limit(self->connection->db,
SQLITE_LIMIT_LENGTH, -1);
if (sql_len >= max_length) {
PyErr_SetString(self->connection->DataError,
"query string is too large");
return NULL;
}
} else {
PyErr_SetString(PyExc_ValueError, "script argument must be unicode.");
return NULL; return NULL;
} }
@ -782,7 +773,7 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare_v2(self->connection->db, rc = sqlite3_prepare_v2(self->connection->db,
script_cstr, sql_script,
(int)sql_len + 1, (int)sql_len + 1,
&statement, &statement,
&tail); &tail);
@ -816,8 +807,8 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
if (*tail == (char)0) { if (*tail == (char)0) {
break; break;
} }
sql_len -= (tail - script_cstr); sql_len -= (tail - sql_script);
script_cstr = tail; sql_script = tail;
} }
error: error:

View file

@ -56,9 +56,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
Py_ssize_t size; Py_ssize_t size;
const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size); const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
if (sql_cstr == NULL) { if (sql_cstr == NULL) {
PyErr_Format(connection->Warning,
"SQL is of wrong type ('%s'). Must be string.",
Py_TYPE(sql)->tp_name);
return NULL; return NULL;
} }