bpo-43553: Improve sqlite3 test coverage (GH-26886)

This commit is contained in:
Erlend Egeberg Aasland 2021-06-24 13:56:56 +02:00 committed by GitHub
parent 9049ea51ec
commit 2c1ae09764
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 2 deletions

View file

@ -26,9 +26,8 @@ import sys
import threading import threading
import unittest import unittest
from test.support import check_disallow_instantiation from test.support import check_disallow_instantiation, threading_helper
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from test.support import threading_helper
# Helper for tests using TESTFN # Helper for tests using TESTFN
@ -110,6 +109,10 @@ class ModuleTests(unittest.TestCase):
cx = sqlite.connect(":memory:") cx = sqlite.connect(":memory:")
check_disallow_instantiation(self, type(cx("select 1"))) check_disallow_instantiation(self, type(cx("select 1")))
def test_complete_statement(self):
self.assertFalse(sqlite.complete_statement("select t"))
self.assertTrue(sqlite.complete_statement("create table t(t);"))
class ConnectionTests(unittest.TestCase): class ConnectionTests(unittest.TestCase):
@ -225,6 +228,20 @@ class ConnectionTests(unittest.TestCase):
self.assertTrue(hasattr(self.cx, exc)) self.assertTrue(hasattr(self.cx, exc))
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc)) self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
def test_interrupt_on_closed_db(self):
cx = sqlite.connect(":memory:")
cx.close()
with self.assertRaises(sqlite.ProgrammingError):
cx.interrupt()
def test_interrupt(self):
self.assertIsNone(self.cx.interrupt())
def test_drop_unused_refs(self):
for n in range(500):
cu = self.cx.execute(f"select {n}")
self.assertEqual(cu.fetchone()[0], n)
class OpenTests(unittest.TestCase): class OpenTests(unittest.TestCase):
_sql = "create table test(id integer)" _sql = "create table test(id integer)"
@ -594,6 +611,11 @@ class CursorTests(unittest.TestCase):
new_count = len(res.description) new_count = len(res.description)
self.assertEqual(new_count - old_count, 1) self.assertEqual(new_count - old_count, 1)
def test_same_query_in_multiple_cursors(self):
cursors = [self.cx.execute("select 1") for _ in range(3)]
for cu in cursors:
self.assertEqual(cu.fetchall(), [(1,)])
class ThreadTests(unittest.TestCase): class ThreadTests(unittest.TestCase):
def setUp(self): def setUp(self):

View file

@ -123,6 +123,8 @@ class RowFactoryTests(unittest.TestCase):
row[-3] row[-3]
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
row[2**1000] row[2**1000]
with self.assertRaises(IndexError):
row[complex()] # index must be int or string
def test_sqlite_row_index_unicode(self): def test_sqlite_row_index_unicode(self):
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row

View file

@ -381,6 +381,43 @@ class ObjectAdaptationTests(unittest.TestCase):
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
self.assertEqual(type(val), float) self.assertEqual(type(val), float)
def test_missing_adapter(self):
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1.) # No float adapter registered
def test_missing_protocol(self):
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1, None)
def test_defect_proto(self):
class DefectProto():
def __adapt__(self):
return None
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1., DefectProto)
def test_defect_self_adapt(self):
class DefectSelfAdapt(float):
def __conform__(self, _):
return None
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(DefectSelfAdapt(1.))
def test_custom_proto(self):
class CustomProto():
def __adapt__(self):
return "adapted"
self.assertEqual(sqlite.adapt(1., CustomProto), "adapted")
def test_adapt(self):
val = 42
self.assertEqual(float(val), sqlite.adapt(val))
def test_adapt_alt(self):
alt = "other"
self.assertEqual(alt, sqlite.adapt(1., None, alt))
@unittest.skipUnless(zlib, "requires zlib") @unittest.skipUnless(zlib, "requires zlib")
class BinaryConverterTests(unittest.TestCase): class BinaryConverterTests(unittest.TestCase):
def convert(s): def convert(s):

View file

@ -21,11 +21,36 @@
# misrepresented as being the original software. # misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import contextlib
import functools
import io
import unittest import unittest
import unittest.mock import unittest.mock
import gc import gc
import sqlite3 as sqlite import sqlite3 as sqlite
def with_tracebacks(strings):
"""Convenience decorator for testing callback tracebacks."""
strings.append('Traceback')
def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled.
sqlite.enable_callback_tracebacks(True)
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
func(self, *args, **kwargs)
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)
# Then run the test with traceback disabled.
sqlite.enable_callback_tracebacks(False)
func(self, *args, **kwargs)
return wrapper
return decorator
def func_returntext(): def func_returntext():
return "foo" return "foo"
def func_returnunicode(): def func_returnunicode():
@ -228,6 +253,7 @@ class FunctionTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1<<31) self.assertEqual(val, 1<<31)
@with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError'])
def test_func_exception(self): def test_func_exception(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -387,6 +413,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
@with_tracebacks(['__init__', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_init(self): def test_aggr_exception_in_init(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -394,6 +421,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
@with_tracebacks(['step', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_step(self): def test_aggr_exception_in_step(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -401,6 +429,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
@with_tracebacks(['finalize', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_finalize(self): def test_aggr_exception_in_finalize(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -502,6 +531,14 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests):
raise ValueError raise ValueError
return sqlite.SQLITE_OK return sqlite.SQLITE_OK
@with_tracebacks(['authorizer_cb', 'ValueError'])
def test_table_access(self):
super().test_table_access()
@with_tracebacks(['authorizer_cb', 'ValueError'])
def test_column_access(self):
super().test_table_access()
class AuthorizerIllegalTypeTests(AuthorizerTests): class AuthorizerIllegalTypeTests(AuthorizerTests):
@staticmethod @staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source): def authorizer_cb(action, arg1, arg2, dbname, source):