sqlite3: Handle strings with embedded zeros correctly

Closes #13676.
This commit is contained in:
Petri Lehtinen 2012-02-01 22:18:19 +02:00
parent fc3ba6b8fc
commit 023fe334bb
5 changed files with 54 additions and 9 deletions

View file

@ -225,6 +225,13 @@ class CursorTests(unittest.TestCase):
def CheckExecuteArgString(self):
self.cu.execute("insert into test(name) values (?)", ("Hugo",))
def CheckExecuteArgStringWithZeroByte(self):
self.cu.execute("insert into test(name) values (?)", ("Hu\x00go",))
self.cu.execute("select name from test where id=?", (self.cu.lastrowid,))
row = self.cu.fetchone()
self.assertEqual(row[0], "Hu\x00go")
def CheckExecuteWrongNoOfArgs1(self):
# too many parameters
try:

View file

@ -189,13 +189,48 @@ class TextFactoryTests(unittest.TestCase):
def tearDown(self):
self.con.close()
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def CheckString(self):
# text_factory defaults to str
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), str)
self.assertEqual(row[0], "a\x00b")
def CheckBytes(self):
self.con.text_factory = bytes
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def CheckBytearray(self):
self.con.text_factory = bytearray
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytearray)
self.assertEqual(row[0], b"a\x00b")
def CheckCustom(self):
# A custom factory should receive a bytes argument
self.con.text_factory = lambda x: x
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def tearDown(self):
self.con.close()
def suite():
connection_suite = unittest.makeSuite(ConnectionFactoryTests, "Check")
cursor_suite = unittest.makeSuite(CursorFactoryTests, "Check")
row_suite_compat = unittest.makeSuite(RowFactoryTestsBackwardsCompat, "Check")
row_suite = unittest.makeSuite(RowFactoryTests, "Check")
text_suite = unittest.makeSuite(TextFactoryTests, "Check")
return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite))
text_zero_bytes_suite = unittest.makeSuite(TextFactoryTestsWithEmbeddedZeroBytes, "Check")
return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite, text_zero_bytes_suite))
def test():
runner = unittest.TextTestRunner()