limbo/bindings/python/tests/test_database_sync_aio.py
Nikita Sivukhin 9326b0c6c0 fix tests
2025-12-19 13:24:26 +04:00

181 lines
7.3 KiB
Python

import logging
import time
import pytest
import turso.aio.sync
from .utils import TursoServer
# Configure the root logger at import time so records emitted while these tests
# run are shown at INFO level with a timestamp. force=True makes basicConfig
# replace any handlers that were already attached to the root logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    force=True,
)
@pytest.mark.asyncio
async def test_bootstrap():
    """A newly connected local database already contains the server's rows."""
    expected_rows = [("hello",), ("turso",), ("sync",)]

    with TursoServer() as server:
        # Seed the remote database before any client connects to it.
        for statement in (
            "CREATE TABLE t(x)",
            "INSERT INTO t VALUES ('hello'), ('turso'), ('sync')",
            "SELECT * FROM t",
        ):
            server.db_sql(statement)

        replica = await turso.aio.sync.connect(":memory:", server.db_url())
        cursor = await replica.execute("SELECT * FROM t")
        assert await cursor.fetchall() == expected_rows
@pytest.mark.asyncio
async def test_pull():
    """Rows written on the server show up locally only after ``pull()``.

    The test also expects ``pull()`` to return a truthy value when the server
    had a new row and a falsy value when it is called again right afterwards.
    """
    seeded_rows = [("hello",), ("turso",), ("sync",)]

    with TursoServer() as server:
        server.db_sql("CREATE TABLE t(x)")
        server.db_sql("INSERT INTO t VALUES ('hello'), ('turso'), ('sync')")
        server.db_sql("SELECT * FROM t")

        replica = await turso.aio.sync.connect(":memory:", server.db_url())

        async def local_rows():
            # Read the whole table through the local connection.
            cursor = await replica.execute("SELECT * FROM t")
            return await cursor.fetchall()

        assert await local_rows() == seeded_rows

        # A write on the server must stay invisible locally until pull() runs.
        server.db_sql("INSERT INTO t VALUES ('pull works')")
        assert await local_rows() == seeded_rows

        first_pull = await replica.pull()
        assert first_pull
        assert await local_rows() == seeded_rows + [("pull works",)]

        second_pull = await replica.pull()
        assert not second_pull
@pytest.mark.asyncio
async def test_push():
    """A locally committed insert reaches the server only after ``push()``."""
    with TursoServer() as server:
        # Seed the remote database before the client connects.
        for statement in (
            "CREATE TABLE t(x)",
            "INSERT INTO t VALUES ('hello'), ('turso'), ('sync')",
            "SELECT * FROM t",
        ):
            server.db_sql(statement)

        replica = await turso.aio.sync.connect(":memory:", server.db_url())
        cursor = await replica.execute("SELECT * FROM t")
        assert await cursor.fetchall() == [("hello",), ("turso",), ("sync",)]

        await replica.execute("INSERT INTO t VALUES ('push works')")
        await replica.commit()

        # The server reports rows as lists; the local commit alone must not change them.
        rows_before_push = [["hello"], ["turso"], ["sync"]]
        assert server.db_sql("SELECT * FROM t") == rows_before_push

        await replica.push()
        assert server.db_sql("SELECT * FROM t") == rows_before_push + [["push works"]]
# Exercises checkpoint() on a synced connection: after 1024 local inserts the
# main WAL is expected to be larger than 1 MiB with an empty revert WAL;
# after checkpoint() the main WAL must be empty and the revert WAL must stay
# below 8 KiB. A final push() must still deliver every row to the server.
# NOTE(review): indentation is not visible in this view, so it cannot be told
# whether the commit() following the INSERT runs once per loop iteration or
# once after the loop. The "> 1 MiB" WAL expectation presumably relies on one
# commit per insert -- confirm against the original file before reformatting.
@pytest.mark.asyncio
async def test_checkpoint():
# turso.setup_logging(level=logging.DEBUG)
with TursoServer() as server:
conn = await turso.aio.sync.connect(":memory:", remote_url=server.db_url())
await conn.execute("CREATE TABLE t(x)")
await conn.commit()
for i in range(1024):
await conn.execute(f"INSERT INTO t VALUES ({i})")
await conn.commit()
# Snapshot the stats on both sides of the checkpoint so they can be compared.
stats1 = await conn.stats()
await conn.checkpoint()
stats2 = await conn.stats()
assert stats1.main_wal_size > 1024 * 1024
assert stats1.revert_wal_size == 0
assert stats2.main_wal_size == 0
assert stats2.revert_wal_size < 8 * 1024
await conn.push()
# SUM(0..1023) == 1024 * 1023 // 2; the test compares against its string form.
assert server.db_sql("SELECT SUM(x) FROM t") == [[f"{1024 * 1023 // 2}"]]
@pytest.mark.asyncio
async def test_partial_sync():
    """Check that a prefix-bootstrapped connection downloads table data on demand.

    The server is seeded with 2000 rows of 1024-byte random blobs. A regular
    connection is expected to report more than ``2000 * 1024`` received bytes
    right after its first query. A connection opened with
    ``PartialSyncPrefixBootstrap(length=128 * 1024)`` is expected to stay below
    ``256 * (1024 + 10)`` received bytes after reading a single row, and to
    exceed ``2000 * 1024`` only after a query that touches every row.
    """
    # Debug aid: uncomment to configure turso logging at DEBUG level.
    # turso.setup_logging(level=logging.DEBUG)
    with TursoServer() as server:
        server.db_sql("CREATE TABLE t(x)")
        server.db_sql("INSERT INTO t SELECT randomblob(1024) FROM generate_series(1, 2000)")

        # Baseline: a connection without partial-sync options.
        conn_full = await turso.aio.sync.connect(":memory:", remote_url=server.db_url())
        assert await (await conn_full.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_full.stats()).network_received_bytes > 2000 * 1024
        assert await (await conn_full.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(2000 * 1024,)]

        conn_partial = await turso.aio.sync.connect(
            ":memory:",
            remote_url=server.db_url(),
            partial_sync_opts=turso.aio.sync.PartialSyncOpts(
                bootstrap_strategy=turso.aio.sync.PartialSyncPrefixBootstrap(length=128 * 1024),
            ),
        )
        # Reading one row must not pull the whole table over the network.
        assert await (await conn_partial.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_partial.stats()).network_received_bytes < 256 * (1024 + 10)

        # perf_counter() is monotonic, so the printed duration cannot be skewed
        # by wall-clock adjustments the way a time.time() difference can.
        start = time.perf_counter()
        assert await (await conn_partial.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(2000 * 1024,)]
        print(time.perf_counter() - start)

        # The full scan is expected to have fetched the rest of the table.
        assert (await conn_partial.stats()).network_received_bytes > 2000 * 1024
@pytest.mark.asyncio
async def test_partial_sync_segment_size():
    """Check on-demand download when partial sync uses an explicit ``segment_size``.

    The server is seeded with 256 rows of 1024-byte random blobs. A regular
    connection is expected to report more than ``256 * 1024`` received bytes
    right after its first query. A connection opened with a 128 KiB prefix
    bootstrap and ``segment_size=4 * 1024`` is expected to stay below
    ``128 * 1024 * 1.5`` received bytes after reading a single row, and to
    exceed ``256 * 1024`` only after a query that touches every row.
    """
    # Debug aid: uncomment to configure turso logging at DEBUG level.
    # turso.setup_logging(level=logging.DEBUG)
    with TursoServer() as server:
        server.db_sql("CREATE TABLE t(x)")
        server.db_sql("INSERT INTO t SELECT randomblob(1024) FROM generate_series(1, 256)")

        # Baseline: a connection without partial-sync options.
        conn_full = await turso.aio.sync.connect(":memory:", remote_url=server.db_url())
        assert await (await conn_full.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_full.stats()).network_received_bytes > 256 * 1024
        assert await (await conn_full.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(256 * 1024,)]

        conn_partial = await turso.aio.sync.connect(
            ":memory:",
            remote_url=server.db_url(),
            partial_sync_opts=turso.aio.sync.PartialSyncOpts(
                bootstrap_strategy=turso.aio.sync.PartialSyncPrefixBootstrap(length=128 * 1024),
                segment_size=4 * 1024,
            ),
        )
        # Reading one row must not pull the whole table over the network.
        assert await (await conn_partial.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_partial.stats()).network_received_bytes < 128 * 1024 * 1.5

        # perf_counter() is monotonic, so the printed duration cannot be skewed
        # by wall-clock adjustments the way a time.time() difference can.
        start = time.perf_counter()
        assert await (await conn_partial.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(256 * 1024,)]
        print(time.perf_counter() - start)

        # The full scan is expected to have fetched the rest of the table.
        assert (await conn_partial.stats()).network_received_bytes > 256 * 1024
@pytest.mark.asyncio
async def test_partial_sync_prefetch():
    """Check on-demand download when partial sync is configured with ``prefetch=True``.

    The server is seeded with 2000 rows of 1024-byte random blobs. A regular
    connection is expected to report more than ``2000 * 1024`` received bytes
    right after its first query. A connection opened with a 128 KiB prefix
    bootstrap, ``segment_size=4 * 1024`` and ``prefetch=True`` is allowed up to
    ``1200 * 1024`` received bytes after reading a single row, and is expected
    to exceed ``2000 * 1024`` after a query that touches every row.
    """
    # Debug aid: uncomment to configure turso logging at DEBUG level.
    # turso.setup_logging(level=logging.DEBUG)
    with TursoServer() as server:
        server.db_sql("CREATE TABLE t(x)")
        server.db_sql("INSERT INTO t SELECT randomblob(1024) FROM generate_series(1, 2000)")

        # Baseline: a connection without partial-sync options.
        conn_full = await turso.aio.sync.connect(":memory:", remote_url=server.db_url())
        assert await (await conn_full.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_full.stats()).network_received_bytes > 2000 * 1024
        assert await (await conn_full.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(2000 * 1024,)]

        # Debug aid: uncomment to enable DEBUG logging for the partial connection only.
        # turso.setup_logging(logging.DEBUG)
        conn_partial = await turso.aio.sync.connect(
            ":memory:",
            remote_url=server.db_url(),
            partial_sync_opts=turso.aio.sync.PartialSyncOpts(
                bootstrap_strategy=turso.aio.sync.PartialSyncPrefixBootstrap(length=128 * 1024),
                segment_size=4 * 1024,
                prefetch=True,
            ),
        )
        # The bound after a single-row read is looser here than in the tests without prefetch.
        assert await (await conn_partial.execute("SELECT LENGTH(x) FROM t LIMIT 1")).fetchall() == [(1024,)]
        assert (await conn_partial.stats()).network_received_bytes < 1200 * 1024

        # perf_counter() is monotonic, so the printed duration cannot be skewed
        # by wall-clock adjustments the way a time.time() difference can.
        start = time.perf_counter()
        assert await (await conn_partial.execute("SELECT SUM(LENGTH(x)) FROM t")).fetchall() == [(2000 * 1024,)]
        print(time.perf_counter() - start)

        # The full scan is expected to have fetched the rest of the table.
        assert (await conn_partial.stats()).network_received_bytes > 2000 * 1024