bpo-45828: Use unraisable exceptions within sqlite3 callbacks (FH-29591)

This commit is contained in:
Erlend Egeberg Aasland 2021-11-29 16:22:32 +01:00 committed by GitHub
parent 6ac3c8a314
commit c4a69a4ad0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 34 deletions

View file

@ -329,9 +329,27 @@ Module functions and constants
By default you will not get any tracebacks in user-defined functions, By default you will not get any tracebacks in user-defined functions,
aggregates, converters, authorizer callbacks etc. If you want to debug them, aggregates, converters, authorizer callbacks etc. If you want to debug them,
you can call this function with *flag* set to ``True``. Afterwards, you will you can call this function with *flag* set to :const:`True`. Afterwards, you
get tracebacks from callbacks on ``sys.stderr``. Use :const:`False` to will get tracebacks from callbacks on :data:`sys.stderr`. Use :const:`False`
disable the feature again. to disable the feature again.
Register an :func:`unraisable hook handler <sys.unraisablehook>` for an
improved debug experience::
>>> import sqlite3
>>> sqlite3.enable_callback_tracebacks(True)
>>> cx = sqlite3.connect(":memory:")
>>> cx.set_trace_callback(lambda stmt: 5/0)
>>> cx.execute("select 1")
Exception ignored in: <function <lambda> at 0x10b4e3ee0>
Traceback (most recent call last):
File "<stdin>", line 1, in <lambda>
ZeroDivisionError: division by zero
>>> import sys
>>> sys.unraisablehook = lambda unraisable: print(unraisable)
>>> cx.execute("select 1")
UnraisableHookArgs(exc_type=<class 'ZeroDivisionError'>, exc_value=ZeroDivisionError('division by zero'), exc_traceback=<traceback object at 0x10b559900>, err_msg=None, object=<function <lambda> at 0x10b4e3ee0>)
<sqlite3.Cursor object at 0x10b1fe840>
.. _sqlite3-connection-objects: .. _sqlite3-connection-objects:

View file

@ -248,7 +248,6 @@ sqlite3
(Contributed by Aviv Palivoda, Daniel Shahaf, and Erlend E. Aasland in (Contributed by Aviv Palivoda, Daniel Shahaf, and Erlend E. Aasland in
:issue:`16379` and :issue:`24139`.) :issue:`16379` and :issue:`24139`.)
* Add :meth:`~sqlite3.Connection.setlimit` and * Add :meth:`~sqlite3.Connection.setlimit` and
:meth:`~sqlite3.Connection.getlimit` to :class:`sqlite3.Connection` for :meth:`~sqlite3.Connection.getlimit` to :class:`sqlite3.Connection` for
setting and getting SQLite limits by connection basis. setting and getting SQLite limits by connection basis.
@ -258,6 +257,12 @@ sqlite3
threading mode the underlying SQLite library has been compiled with. threading mode the underlying SQLite library has been compiled with.
(Contributed by Erlend E. Aasland in :issue:`45613`.) (Contributed by Erlend E. Aasland in :issue:`45613`.)
* :mod:`sqlite3` C callbacks now use unraisable exceptions if callback
tracebacks are enabled. Users can now register an
:func:`unraisable hook handler <sys.unraisablehook>` to improve their debug
experience.
(Contributed by Erlend E. Aasland in :issue:`45828`.)
threading threading
--------- ---------

View file

@ -197,7 +197,7 @@ class ProgressTests(unittest.TestCase):
con.execute("select 1 union select 2 union select 3").fetchall() con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared") self.assertEqual(action, 0, "progress handler was not cleared")
@with_tracebacks(['bad_progress', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self): def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
def bad_progress(): def bad_progress():
@ -208,7 +208,7 @@ class ProgressTests(unittest.TestCase):
create table foo(a, b) create table foo(a, b)
""") """)
@with_tracebacks(['__bool__', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self): def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
class BadBool: class BadBool:

View file

@ -25,25 +25,25 @@ import contextlib
import functools import functools
import gc import gc
import io import io
import re
import sys import sys
import unittest import unittest
import unittest.mock import unittest.mock
import sqlite3 as sqlite import sqlite3 as sqlite
from test.support import bigmemtest from test.support import bigmemtest, catch_unraisable_exception
from .test_dbapi import cx_limit from .test_dbapi import cx_limit
def with_tracebacks(strings, traceback=True): def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks.""" """Convenience decorator for testing callback tracebacks."""
if traceback:
strings.append('Traceback')
def decorator(func): def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func) @functools.wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
with catch_unraisable_exception() as cm:
# First, run the test with traceback enabled. # First, run the test with traceback enabled.
with check_tracebacks(self, strings): with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs) func(self, *args, **kwargs)
# Then run the test with traceback disabled. # Then run the test with traceback disabled.
@ -51,20 +51,26 @@ def with_tracebacks(strings, traceback=True):
return wrapper return wrapper
return decorator return decorator
@contextlib.contextmanager @contextlib.contextmanager
def check_tracebacks(self, strings): def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks.""" """Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True) sqlite.enable_callback_tracebacks(True)
try: try:
buf = io.StringIO() buf = io.StringIO()
with contextlib.redirect_stderr(buf): with contextlib.redirect_stderr(buf):
yield yield
tb = buf.getvalue()
for s in strings: self.assertEqual(cm.unraisable.exc_type, exc)
self.assertIn(s, tb) if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally: finally:
sqlite.enable_callback_tracebacks(False) sqlite.enable_callback_tracebacks(False)
def func_returntext(): def func_returntext():
return "foo" return "foo"
def func_returntextwithnull(): def func_returntextwithnull():
@ -299,7 +305,7 @@ class FunctionTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1<<31) self.assertEqual(val, 1<<31)
@with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
def test_func_exception(self): def test_func_exception(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -307,14 +313,14 @@ class FunctionTests(unittest.TestCase):
cur.fetchone() cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception') self.assertEqual(str(cm.exception), 'user-defined function raised exception')
@with_tracebacks(['func_memoryerror', 'MemoryError']) @with_tracebacks(MemoryError, name="func_memoryerror")
def test_func_memory_error(self): def test_func_memory_error(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(MemoryError): with self.assertRaises(MemoryError):
cur.execute("select memoryerror()") cur.execute("select memoryerror()")
cur.fetchone() cur.fetchone()
@with_tracebacks(['func_overflowerror', 'OverflowError']) @with_tracebacks(OverflowError, name="func_overflowerror")
def test_func_overflow_error(self): def test_func_overflow_error(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.DataError): with self.assertRaises(sqlite.DataError):
@ -426,20 +432,19 @@ class FunctionTests(unittest.TestCase):
del x,y del x,y
gc.collect() gc.collect()
@with_tracebacks(OverflowError)
def test_func_return_too_large_int(self): def test_func_return_too_large_int(self):
cur = self.con.cursor() cur = self.con.cursor()
for value in 2**63, -2**63-1, 2**64: for value in 2**63, -2**63-1, 2**64:
self.con.create_function("largeint", 0, lambda value=value: value) self.con.create_function("largeint", 0, lambda value=value: value)
with check_tracebacks(self, ['OverflowError']):
with self.assertRaises(sqlite.DataError): with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()") cur.execute("select largeint()")
@with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
def test_func_return_text_with_surrogates(self): def test_func_return_text_with_surrogates(self):
cur = self.con.cursor() cur = self.con.cursor()
self.con.create_function("pychr", 1, chr) self.con.create_function("pychr", 1, chr)
for value in 0xd8ff, 0xdcff: for value in 0xd8ff, 0xdcff:
with check_tracebacks(self,
['UnicodeEncodeError', 'surrogates not allowed']):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,)) cur.execute("select pychr(?)", (value,))
@ -510,7 +515,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
@with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
def test_aggr_exception_in_init(self): def test_aggr_exception_in_init(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -518,7 +523,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
@with_tracebacks(['step', '5/0', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
def test_aggr_exception_in_step(self): def test_aggr_exception_in_step(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -526,7 +531,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
@with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
def test_aggr_exception_in_finalize(self): def test_aggr_exception_in_finalize(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -643,11 +648,11 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests):
raise ValueError raise ValueError
return sqlite.SQLITE_OK return sqlite.SQLITE_OK
@with_tracebacks(['authorizer_cb', 'ValueError']) @with_tracebacks(ValueError, name="authorizer_cb")
def test_table_access(self): def test_table_access(self):
super().test_table_access() super().test_table_access()
@with_tracebacks(['authorizer_cb', 'ValueError']) @with_tracebacks(ValueError, name="authorizer_cb")
def test_column_access(self): def test_column_access(self):
super().test_table_access() super().test_table_access()

View file

@ -0,0 +1,2 @@
:mod:`sqlite` C callbacks now use unraisable exceptions if callback
tracebacks are enabled. Patch by Erlend E. Aasland.

View file

@ -691,7 +691,7 @@ print_or_clear_traceback(callback_context *ctx)
assert(ctx != NULL); assert(ctx != NULL);
assert(ctx->state != NULL); assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) { if (ctx->state->enable_callback_tracebacks) {
PyErr_Print(); PyErr_WriteUnraisable(ctx->callable);
} }
else { else {
PyErr_Clear(); PyErr_Clear();