bpo-44839: Raise more specific errors in sqlite3 (GH-27613)

MemoryError raised in user-defined functions will now preserve
its type. OverflowError will now be converted to DataError.
Previously both were converted to OperationalError.
This commit is contained in:
Serhiy Storchaka 2021-08-06 21:28:47 +03:00 committed by GitHub
parent 738ac00a08
commit 7d747f26e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 37 deletions

View file

@ -23,12 +23,16 @@
import contextlib import contextlib
import functools import functools
import gc
import io import io
import sys
import unittest import unittest
import unittest.mock import unittest.mock
import gc
import sqlite3 as sqlite import sqlite3 as sqlite
from test.support import bigmemtest
def with_tracebacks(strings): def with_tracebacks(strings):
"""Convenience decorator for testing callback tracebacks.""" """Convenience decorator for testing callback tracebacks."""
strings.append('Traceback') strings.append('Traceback')
@ -69,6 +73,10 @@ def func_returnlonglong():
return 1<<31 return 1<<31
def func_raiseexception(): def func_raiseexception():
5/0 5/0
def func_memoryerror():
raise MemoryError
def func_overflowerror():
raise OverflowError
def func_isstring(v): def func_isstring(v):
return type(v) is str return type(v) is str
@ -187,6 +195,8 @@ class FunctionTests(unittest.TestCase):
self.con.create_function("returnblob", 0, func_returnblob) self.con.create_function("returnblob", 0, func_returnblob)
self.con.create_function("returnlonglong", 0, func_returnlonglong) self.con.create_function("returnlonglong", 0, func_returnlonglong)
self.con.create_function("raiseexception", 0, func_raiseexception) self.con.create_function("raiseexception", 0, func_raiseexception)
self.con.create_function("memoryerror", 0, func_memoryerror)
self.con.create_function("overflowerror", 0, func_overflowerror)
self.con.create_function("isstring", 1, func_isstring) self.con.create_function("isstring", 1, func_isstring)
self.con.create_function("isint", 1, func_isint) self.con.create_function("isint", 1, func_isint)
@ -279,6 +289,20 @@ class FunctionTests(unittest.TestCase):
cur.fetchone() cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception') self.assertEqual(str(cm.exception), 'user-defined function raised exception')
@with_tracebacks(['func_memoryerror', 'MemoryError'])
def test_func_memory_error(self):
cur = self.con.cursor()
with self.assertRaises(MemoryError):
cur.execute("select memoryerror()")
cur.fetchone()
@with_tracebacks(['func_overflowerror', 'OverflowError'])
def test_func_overflow_error(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.DataError):
cur.execute("select overflowerror()")
cur.fetchone()
def test_param_string(self): def test_param_string(self):
cur = self.con.cursor() cur = self.con.cursor()
for text in ["foo", str()]: for text in ["foo", str()]:
@ -384,6 +408,25 @@ class FunctionTests(unittest.TestCase):
del x,y del x,y
gc.collect() gc.collect()
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_large_text(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largetext", 0, lambda size=size: "b" * size)
with self.assertRaises(sqlite.DataError):
cur.execute("select largetext()")
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=2, dry_run=False)
def test_large_blob(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
with self.assertRaises(sqlite.DataError):
cur.execute("select largeblob()")
class AggregateTests(unittest.TestCase): class AggregateTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")

View file

@ -0,0 +1,4 @@
:class:`MemoryError` raised in user-defined functions will now produce a
``MemoryError`` in :mod:`sqlite3`. :class:`OverflowError` will now be converted
to :class:`~sqlite3.DataError`. Previously
:class:`~sqlite3.OperationalError` was produced in these cases.

View file

@ -619,6 +619,29 @@ error:
return NULL; return NULL;
} }
// Checks the Python exception and sets the appropriate SQLite error code.
static void
set_sqlite_error(sqlite3_context *context, const char *msg)
{
assert(PyErr_Occurred());
if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
sqlite3_result_error_nomem(context);
}
else if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
sqlite3_result_error_toobig(context);
}
else {
sqlite3_result_error(context, msg, -1);
}
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
}
static void static void
_pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
{ {
@ -645,14 +668,7 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
Py_DECREF(py_retval); Py_DECREF(py_retval);
} }
if (!ok) { if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL); set_sqlite_error(context, "user-defined function raised exception");
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined function raised exception", -1);
} }
PyGILState_Release(threadstate); PyGILState_Release(threadstate);
@ -676,18 +692,9 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
if (*aggregate_instance == NULL) { if (*aggregate_instance == NULL) {
*aggregate_instance = _PyObject_CallNoArg(aggregate_class); *aggregate_instance = _PyObject_CallNoArg(aggregate_class);
if (!*aggregate_instance) {
if (PyErr_Occurred()) { set_sqlite_error(context,
*aggregate_instance = 0; "user-defined aggregate's '__init__' method raised error");
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1);
goto error; goto error;
} }
} }
@ -706,14 +713,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
Py_DECREF(args); Py_DECREF(args);
if (!function_result) { if (!function_result) {
pysqlite_state *state = pysqlite_get_state(NULL); set_sqlite_error(context,
if (state->enable_callback_tracebacks) { "user-defined aggregate's 'step' method raised error");
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1);
} }
error: error:
@ -761,14 +762,8 @@ _pysqlite_final_callback(sqlite3_context *context)
Py_DECREF(function_result); Py_DECREF(function_result);
} }
if (!ok) { if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL); set_sqlite_error(context,
if (state->enable_callback_tracebacks) { "user-defined aggregate's 'finalize' method raised error");
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1);
} }
/* Restore the exception (if any) of the last call to step(), /* Restore the exception (if any) of the last call to step(),