mirror of
https://github.com/python/cpython.git
synced 2025-11-02 03:01:58 +00:00
gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)
This commit is contained in:
parent
f45aa8f304
commit
9ebcece82f
10 changed files with 477 additions and 13 deletions
|
|
@ -1084,6 +1084,8 @@ class ThreadTests(unittest.TestCase):
|
|||
if hasattr(sqlite.Connection, "serialize"):
|
||||
fns.append(lambda: self.con.serialize())
|
||||
fns.append(lambda: self.con.deserialize(b""))
|
||||
if sqlite.sqlite_version_info >= (3, 25, 0):
|
||||
fns.append(lambda: self.con.create_window_function("foo", 0, None))
|
||||
|
||||
for fn in fns:
|
||||
with self.subTest(fn=fn):
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@ import io
|
|||
import re
|
||||
import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import sqlite3 as sqlite
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
|
||||
|
||||
from test.test_sqlite3.test_dbapi import cx_limit
|
||||
|
|
@ -393,7 +393,7 @@ class FunctionTests(unittest.TestCase):
|
|||
# indices, which allows testing based on syntax, iso. the query optimizer.
|
||||
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
|
||||
def test_func_non_deterministic(self):
|
||||
mock = unittest.mock.Mock(return_value=None)
|
||||
mock = Mock(return_value=None)
|
||||
self.con.create_function("nondeterministic", 0, mock, deterministic=False)
|
||||
if sqlite.sqlite_version_info < (3, 15, 0):
|
||||
self.con.execute("select nondeterministic() = nondeterministic()")
|
||||
|
|
@ -404,7 +404,7 @@ class FunctionTests(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
|
||||
def test_func_deterministic(self):
|
||||
mock = unittest.mock.Mock(return_value=None)
|
||||
mock = Mock(return_value=None)
|
||||
self.con.create_function("deterministic", 0, mock, deterministic=True)
|
||||
if sqlite.sqlite_version_info < (3, 15, 0):
|
||||
self.con.execute("select deterministic() = deterministic()")
|
||||
|
|
@ -482,6 +482,164 @@ class FunctionTests(unittest.TestCase):
|
|||
self.con.execute, "select badreturn()")
|
||||
|
||||
|
||||
class WindowSumInt:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def step(self, value):
|
||||
self.count += value
|
||||
|
||||
def value(self):
|
||||
return self.count
|
||||
|
||||
def inverse(self, value):
|
||||
self.count -= value
|
||||
|
||||
def finalize(self):
|
||||
return self.count
|
||||
|
||||
class BadWindow(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
|
||||
"Requires SQLite 3.25.0 or newer")
|
||||
class WindowFunctionTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
self.cur = self.con.cursor()
|
||||
|
||||
# Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
|
||||
values = [
|
||||
("a", 4),
|
||||
("b", 5),
|
||||
("c", 3),
|
||||
("d", 8),
|
||||
("e", 1),
|
||||
]
|
||||
with self.con:
|
||||
self.con.execute("create table test(x, y)")
|
||||
self.con.executemany("insert into test values(?, ?)", values)
|
||||
self.expected = [
|
||||
("a", 9),
|
||||
("b", 12),
|
||||
("c", 16),
|
||||
("d", 12),
|
||||
("e", 9),
|
||||
]
|
||||
self.query = """
|
||||
select x, %s(y) over (
|
||||
order by x rows between 1 preceding and 1 following
|
||||
) as sum_y
|
||||
from test order by x
|
||||
"""
|
||||
self.con.create_window_function("sumint", 1, WindowSumInt)
|
||||
|
||||
def test_win_sum_int(self):
|
||||
self.cur.execute(self.query % "sumint")
|
||||
self.assertEqual(self.cur.fetchall(), self.expected)
|
||||
|
||||
def test_win_error_on_create(self):
|
||||
self.assertRaises(sqlite.ProgrammingError,
|
||||
self.con.create_window_function,
|
||||
"shouldfail", -100, WindowSumInt)
|
||||
|
||||
@with_tracebacks(BadWindow)
|
||||
def test_win_exception_in_method(self):
|
||||
for meth in "__init__", "step", "value", "inverse":
|
||||
with self.subTest(meth=meth):
|
||||
with patch.object(WindowSumInt, meth, side_effect=BadWindow):
|
||||
name = f"exc_{meth}"
|
||||
self.con.create_window_function(name, 1, WindowSumInt)
|
||||
msg = f"'{meth}' method raised error"
|
||||
with self.assertRaisesRegex(sqlite.OperationalError, msg):
|
||||
self.cur.execute(self.query % name)
|
||||
self.cur.fetchall()
|
||||
|
||||
@with_tracebacks(BadWindow)
|
||||
def test_win_exception_in_finalize(self):
|
||||
# Note: SQLite does not (as of version 3.38.0) propagate finalize
|
||||
# callback errors to sqlite3_step(); this implies that OperationalError
|
||||
# is _not_ raised.
|
||||
with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
|
||||
name = f"exception_in_finalize"
|
||||
self.con.create_window_function(name, 1, WindowSumInt)
|
||||
self.cur.execute(self.query % name)
|
||||
self.cur.fetchall()
|
||||
|
||||
@with_tracebacks(AttributeError)
|
||||
def test_win_missing_method(self):
|
||||
class MissingValue:
|
||||
def step(self, x): pass
|
||||
def inverse(self, x): pass
|
||||
def finalize(self): return 42
|
||||
|
||||
class MissingInverse:
|
||||
def step(self, x): pass
|
||||
def value(self): return 42
|
||||
def finalize(self): return 42
|
||||
|
||||
class MissingStep:
|
||||
def value(self): return 42
|
||||
def inverse(self, x): pass
|
||||
def finalize(self): return 42
|
||||
|
||||
dataset = (
|
||||
("step", MissingStep),
|
||||
("value", MissingValue),
|
||||
("inverse", MissingInverse),
|
||||
)
|
||||
for meth, cls in dataset:
|
||||
with self.subTest(meth=meth, cls=cls):
|
||||
name = f"exc_{meth}"
|
||||
self.con.create_window_function(name, 1, cls)
|
||||
with self.assertRaisesRegex(sqlite.OperationalError,
|
||||
f"'{meth}' method not defined"):
|
||||
self.cur.execute(self.query % name)
|
||||
self.cur.fetchall()
|
||||
|
||||
@with_tracebacks(AttributeError)
|
||||
def test_win_missing_finalize(self):
|
||||
# Note: SQLite does not (as of version 3.38.0) propagate finalize
|
||||
# callback errors to sqlite3_step(); this implies that OperationalError
|
||||
# is _not_ raised.
|
||||
class MissingFinalize:
|
||||
def step(self, x): pass
|
||||
def value(self): return 42
|
||||
def inverse(self, x): pass
|
||||
|
||||
name = "missing_finalize"
|
||||
self.con.create_window_function(name, 1, MissingFinalize)
|
||||
self.cur.execute(self.query % name)
|
||||
self.cur.fetchall()
|
||||
|
||||
def test_win_clear_function(self):
|
||||
self.con.create_window_function("sumint", 1, None)
|
||||
self.assertRaises(sqlite.OperationalError, self.cur.execute,
|
||||
self.query % "sumint")
|
||||
|
||||
def test_win_redefine_function(self):
|
||||
# Redefine WindowSumInt; adjust the expected results accordingly.
|
||||
class Redefined(WindowSumInt):
|
||||
def step(self, value): self.count += value * 2
|
||||
def inverse(self, value): self.count -= value * 2
|
||||
expected = [(v[0], v[1]*2) for v in self.expected]
|
||||
|
||||
self.con.create_window_function("sumint", 1, Redefined)
|
||||
self.cur.execute(self.query % "sumint")
|
||||
self.assertEqual(self.cur.fetchall(), expected)
|
||||
|
||||
def test_win_error_value_return(self):
|
||||
class ErrorValueReturn:
|
||||
def __init__(self): pass
|
||||
def step(self, x): pass
|
||||
def value(self): return 1 << 65
|
||||
|
||||
self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
|
||||
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
|
||||
self.cur.execute, self.query % "err_val_ret")
|
||||
|
||||
|
||||
class AggregateTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.con = sqlite.connect(":memory:")
|
||||
|
|
@ -527,10 +685,10 @@ class AggregateTests(unittest.TestCase):
|
|||
|
||||
def test_aggr_no_finalize(self):
|
||||
cur = self.con.cursor()
|
||||
with self.assertRaises(sqlite.OperationalError) as cm:
|
||||
msg = "user-defined aggregate's 'finalize' method not defined"
|
||||
with self.assertRaisesRegex(sqlite.OperationalError, msg):
|
||||
cur.execute("select nofinalize(t) from test")
|
||||
val = cur.fetchone()[0]
|
||||
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
|
||||
def test_aggr_exception_in_init(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue