bpo-45828: Use unraisable exceptions within sqlite3 callbacks (FH-29591)

This commit is contained in:
Erlend Egeberg Aasland 2021-11-29 16:22:32 +01:00 committed by GitHub
parent 6ac3c8a314
commit c4a69a4ad0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 34 deletions

View file

@ -197,7 +197,7 @@ class ProgressTests(unittest.TestCase):
con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared")
@with_tracebacks(['bad_progress', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
@ -208,7 +208,7 @@ class ProgressTests(unittest.TestCase):
create table foo(a, b)
""")
@with_tracebacks(['__bool__', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:

View file

@ -25,46 +25,52 @@ import contextlib
import functools
import gc
import io
import re
import sys
import unittest
import unittest.mock
import sqlite3 as sqlite
from test.support import bigmemtest
from test.support import bigmemtest, catch_unraisable_exception
from .test_dbapi import cx_limit
def with_tracebacks(strings, traceback=True):
def with_tracebacks(exc, regex="", name=""):
"""Convenience decorator for testing callback tracebacks."""
if traceback:
strings.append('Traceback')
def decorator(func):
_regex = re.compile(regex) if regex else None
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled.
with check_tracebacks(self, strings):
func(self, *args, **kwargs)
with catch_unraisable_exception() as cm:
# First, run the test with traceback enabled.
with check_tracebacks(self, cm, exc, _regex, name):
func(self, *args, **kwargs)
# Then run the test with traceback disabled.
func(self, *args, **kwargs)
return wrapper
return decorator
@contextlib.contextmanager
def check_tracebacks(self, strings):
def check_tracebacks(self, cm, exc, regex, obj_name):
"""Convenience context manager for testing callback tracebacks."""
sqlite.enable_callback_tracebacks(True)
try:
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
yield
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
self.assertEqual(cm.unraisable.exc_type, exc)
if regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(regex.search(msg))
if obj_name:
self.assertEqual(cm.unraisable.object.__name__, obj_name)
finally:
sqlite.enable_callback_tracebacks(False)
def func_returntext():
return "foo"
def func_returntextwithnull():
@ -299,7 +305,7 @@ class FunctionTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(val, 1<<31)
@with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="func_raiseexception")
def test_func_exception(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@ -307,14 +313,14 @@ class FunctionTests(unittest.TestCase):
cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception')
@with_tracebacks(['func_memoryerror', 'MemoryError'])
@with_tracebacks(MemoryError, name="func_memoryerror")
def test_func_memory_error(self):
cur = self.con.cursor()
with self.assertRaises(MemoryError):
cur.execute("select memoryerror()")
cur.fetchone()
@with_tracebacks(['func_overflowerror', 'OverflowError'])
@with_tracebacks(OverflowError, name="func_overflowerror")
def test_func_overflow_error(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.DataError):
@ -426,22 +432,21 @@ class FunctionTests(unittest.TestCase):
del x,y
gc.collect()
@with_tracebacks(OverflowError)
def test_func_return_too_large_int(self):
cur = self.con.cursor()
for value in 2**63, -2**63-1, 2**64:
self.con.create_function("largeint", 0, lambda value=value: value)
with check_tracebacks(self, ['OverflowError']):
with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()")
with self.assertRaises(sqlite.DataError):
cur.execute("select largeint()")
@with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
def test_func_return_text_with_surrogates(self):
cur = self.con.cursor()
self.con.create_function("pychr", 1, chr)
for value in 0xd8ff, 0xdcff:
with check_tracebacks(self,
['UnicodeEncodeError', 'surrogates not allowed']):
with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,))
with self.assertRaises(sqlite.OperationalError):
cur.execute("select pychr(?)", (value,))
@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
@ -510,7 +515,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
@with_tracebacks(['__init__', '5/0', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
def test_aggr_exception_in_init(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@ -518,7 +523,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
@with_tracebacks(['step', '5/0', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
def test_aggr_exception_in_step(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@ -526,7 +531,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
@with_tracebacks(['finalize', '5/0', 'ZeroDivisionError'])
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
def test_aggr_exception_in_finalize(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
@ -643,11 +648,11 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests):
raise ValueError
return sqlite.SQLITE_OK
@with_tracebacks(['authorizer_cb', 'ValueError'])
@with_tracebacks(ValueError, name="authorizer_cb")
def test_table_access(self):
super().test_table_access()
@with_tracebacks(['authorizer_cb', 'ValueError'])
@with_tracebacks(ValueError, name="authorizer_cb")
def test_column_access(self):
super().test_table_access()