Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by Torsten Landschoff.

This commit is contained in:
Antoine Pitrou 2011-04-04 00:12:04 +02:00
parent d7edf3b82d
commit 5bfa0622ec
4 changed files with 128 additions and 1 deletions

View file

@ -175,10 +175,56 @@ class ProgressTests(unittest.TestCase):
con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared")
class TraceCallbackTests(unittest.TestCase):
def CheckTraceCallbackUsed(self):
"""
Test that the trace callback is invoked once it is set.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
def CheckClearTraceCallback(self):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.set_trace_callback(None)
con.execute("create table foo(a, b)")
self.assertFalse(traced_statements, "trace callback was not cleared")
def CheckUnicodeContent(self):
"""
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(x)")
con.execute("insert into foo(x) values (?)", (unicode_value,))
con.commit()
self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
"Unicode data garbled in trace callback")
def suite():
collation_suite = unittest.makeSuite(CollationTests, "Check")
progress_suite = unittest.makeSuite(ProgressTests, "Check")
return unittest.TestSuite((collation_suite, progress_suite))
trace_suite = unittest.makeSuite(TraceCallbackTests, "Check")
return unittest.TestSuite((collation_suite, progress_suite, trace_suite))
def test():
runner = unittest.TextTestRunner()