bpo-44822: Don't truncate strs with embedded NULL chars returned by sqlite3 UDF callbacks (GH-27588)

This commit is contained in:
Erlend Egeberg Aasland 2021-08-05 09:22:08 +02:00 committed by GitHub
parent 3e4cb7f40f
commit 8f010dc920
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 3 deletions

View file

@ -53,6 +53,8 @@ def with_tracebacks(strings):
def func_returntext():
return "foo"
def func_returntextwithnull():
return "1\x002"
def func_returnunicode():
return "bar"
def func_returnint():
@ -163,11 +165,21 @@ class AggrSum:
def finalize(self):
return self.val
class AggrText:
def __init__(self):
self.txt = ""
def step(self, txt):
self.txt = self.txt + txt
def finalize(self):
return self.txt
class FunctionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.create_function("returntext", 0, func_returntext)
self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
self.con.create_function("returnunicode", 0, func_returnunicode)
self.con.create_function("returnint", 0, func_returnint)
self.con.create_function("returnfloat", 0, func_returnfloat)
@ -211,6 +223,12 @@ class FunctionTests(unittest.TestCase):
self.assertEqual(type(val), str)
self.assertEqual(val, "foo")
def test_func_return_text_with_null_char(self):
cur = self.con.cursor()
res = cur.execute("select returntextwithnull()").fetchone()[0]
self.assertEqual(type(res), str)
self.assertEqual(res, "1\x002")
def test_func_return_unicode(self):
cur = self.con.cursor()
cur.execute("select returnunicode()")
@ -390,6 +408,7 @@ class AggregateTests(unittest.TestCase):
self.con.create_aggregate("checkType", 2, AggrCheckType)
self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
self.con.create_aggregate("mysum", 1, AggrSum)
self.con.create_aggregate("aggtxt", 1, AggrText)
def tearDown(self):
#self.cur.close()
@ -486,6 +505,15 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertIsNone(val)
def test_aggr_text(self):
cur = self.con.cursor()
for txt in ["foo", "1\x002"]:
with self.subTest(txt=txt):
cur.execute("select aggtxt(?) from test", (txt,))
val = cur.fetchone()[0]
self.assertEqual(val, txt)
class AuthorizerTests(unittest.TestCase):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):