mirror of
https://github.com/django/django.git
synced 2025-11-18 02:56:45 +00:00
Merge da20298982 into 1ce6e78dd4
This commit is contained in:
commit
f667229336
16 changed files with 1319 additions and 44 deletions
|
|
@ -2,6 +2,7 @@ from django.core import signals
|
|||
from django.db.utils import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
DJANGO_VERSION_PICKLE_KEY,
|
||||
AsyncConnectionHandler,
|
||||
ConnectionHandler,
|
||||
ConnectionRouter,
|
||||
DatabaseError,
|
||||
|
|
@ -36,6 +37,50 @@ __all__ = [
|
|||
]
|
||||
|
||||
connections = ConnectionHandler()
|
||||
async_connections = AsyncConnectionHandler()
|
||||
|
||||
|
||||
class new_connection:
|
||||
"""
|
||||
Asynchronous context manager to instantiate new async connections.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, using=DEFAULT_DB_ALIAS):
|
||||
self.using = using
|
||||
|
||||
async def __aenter__(self):
|
||||
conn = connections.create_connection(self.using)
|
||||
if conn.features.supports_async is False:
|
||||
raise NotSupportedError(
|
||||
"The database backend does not support asynchronous execution."
|
||||
)
|
||||
|
||||
self.force_rollback = False
|
||||
if async_connections.empty is True:
|
||||
if async_connections._from_testcase is True:
|
||||
self.force_rollback = True
|
||||
self.conn = conn
|
||||
|
||||
async_connections.add_connection(self.using, self.conn)
|
||||
|
||||
await self.conn.aensure_connection()
|
||||
if self.force_rollback is True:
|
||||
await self.conn.aset_autocommit(False)
|
||||
|
||||
return self.conn
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
autocommit = await self.conn.aget_autocommit()
|
||||
if autocommit is False:
|
||||
if exc_type is None and self.force_rollback is False:
|
||||
await self.conn.acommit()
|
||||
else:
|
||||
await self.conn.arollback()
|
||||
await self.conn.aclose()
|
||||
|
||||
await async_connections.pop_connection(self.using)
|
||||
|
||||
|
||||
router = ConnectionRouter()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import time
|
|||
import warnings
|
||||
import zoneinfo
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
|
@ -39,6 +39,7 @@ class BaseDatabaseWrapper:
|
|||
ops = None
|
||||
vendor = "unknown"
|
||||
display_name = "unknown"
|
||||
|
||||
SchemaEditorClass = None
|
||||
# Classes instantiated in __init__().
|
||||
client_class = None
|
||||
|
|
@ -47,6 +48,7 @@ class BaseDatabaseWrapper:
|
|||
introspection_class = None
|
||||
ops_class = None
|
||||
validation_class = BaseDatabaseValidation
|
||||
_aconnection_pools = {}
|
||||
|
||||
queries_limit = 9000
|
||||
|
||||
|
|
@ -54,6 +56,7 @@ class BaseDatabaseWrapper:
|
|||
# Connection related attributes.
|
||||
# The underlying database connection.
|
||||
self.connection = None
|
||||
self.aconnection = None
|
||||
# `settings_dict` should be a dictionary containing keys such as
|
||||
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
|
||||
# to disambiguate it from Django settings modules.
|
||||
|
|
@ -187,25 +190,44 @@ class BaseDatabaseWrapper:
|
|||
"method."
|
||||
)
|
||||
|
||||
async def aget_database_version(self):
|
||||
"""Return a tuple of the database's version."""
|
||||
raise NotSupportedError(
|
||||
"subclasses of BaseDatabaseWrapper may require an aget_database_version() "
|
||||
"method."
|
||||
)
|
||||
|
||||
def _validate_database_version_supported(self, db_version):
|
||||
if (
|
||||
self.features.minimum_database_version is not None
|
||||
and db_version < self.features.minimum_database_version
|
||||
):
|
||||
str_db_version = ".".join(map(str, db_version))
|
||||
min_db_version = ".".join(map(str, self.features.minimum_database_version))
|
||||
raise NotSupportedError(
|
||||
f"{self.display_name} {min_db_version} or later is required "
|
||||
f"(found {str_db_version})."
|
||||
)
|
||||
|
||||
def check_database_version_supported(self):
|
||||
"""
|
||||
Raise an error if the database version isn't supported by this
|
||||
version of Django.
|
||||
"""
|
||||
if (
|
||||
self.features.minimum_database_version is not None
|
||||
and self.get_database_version() < self.features.minimum_database_version
|
||||
):
|
||||
db_version = ".".join(map(str, self.get_database_version()))
|
||||
min_db_version = ".".join(map(str, self.features.minimum_database_version))
|
||||
raise NotSupportedError(
|
||||
f"{self.display_name} {min_db_version} or later is required "
|
||||
f"(found {db_version})."
|
||||
)
|
||||
db_version = self.get_database_version()
|
||||
self._validate_database_version_supported(db_version)
|
||||
|
||||
async def acheck_database_version_supported(self):
|
||||
"""
|
||||
Raise an error if the database version isn't supported by this
|
||||
version of Django.
|
||||
"""
|
||||
db_version = await self.aget_database_version()
|
||||
self._validate_database_version_supported(db_version)
|
||||
|
||||
# ##### Backend-specific methods for creating connections and cursors #####
|
||||
|
||||
def get_connection_params(self):
|
||||
def get_connection_params(self, for_async=False):
|
||||
"""Return a dict of parameters suitable for get_new_connection."""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
|
||||
|
|
@ -219,23 +241,42 @@ class BaseDatabaseWrapper:
|
|||
"method"
|
||||
)
|
||||
|
||||
async def aget_new_connection(self, conn_params):
|
||||
"""Open a connection to the database."""
|
||||
raise NotSupportedError(
|
||||
"subclasses of BaseDatabaseWrapper may require an aget_new_connection() "
|
||||
"method"
|
||||
)
|
||||
|
||||
def init_connection_state(self):
|
||||
"""Initialize the database connection settings."""
|
||||
if self.alias not in RAN_DB_VERSION_CHECK:
|
||||
self.check_database_version_supported()
|
||||
RAN_DB_VERSION_CHECK.add(self.alias)
|
||||
|
||||
async def ainit_connection_state(self):
|
||||
"""Initialize the database connection settings."""
|
||||
if self.alias not in RAN_DB_VERSION_CHECK:
|
||||
await self.acheck_database_version_supported()
|
||||
RAN_DB_VERSION_CHECK.add(self.alias)
|
||||
|
||||
def create_cursor(self, name=None):
|
||||
"""Create a cursor. Assume that a connection is established."""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
|
||||
)
|
||||
|
||||
def create_async_cursor(self, name=None):
|
||||
"""Create a cursor. Assume that a connection is established."""
|
||||
raise NotSupportedError(
|
||||
"subclasses of BaseDatabaseWrapper may require a "
|
||||
"create_async_cursor() method"
|
||||
)
|
||||
|
||||
# ##### Backend-specific methods for creating connections #####
|
||||
|
||||
@async_unsafe
|
||||
def connect(self):
|
||||
"""Connect to the database. Assume that the connection is closed."""
|
||||
@contextmanager
|
||||
def connect_manager(self):
|
||||
# Check for invalid configurations.
|
||||
self.check_settings()
|
||||
# In case the previous connection was closed while in an atomic block
|
||||
|
|
@ -251,14 +292,30 @@ class BaseDatabaseWrapper:
|
|||
self.errors_occurred = False
|
||||
# New connections are healthy.
|
||||
self.health_check_done = True
|
||||
# Establish the connection
|
||||
conn_params = self.get_connection_params()
|
||||
self.connection = self.get_new_connection(conn_params)
|
||||
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
|
||||
self.init_connection_state()
|
||||
connection_created.send(sender=self.__class__, connection=self)
|
||||
|
||||
self.run_on_commit = []
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
connection_created.send(sender=self.__class__, connection=self)
|
||||
self.run_on_commit = []
|
||||
|
||||
@async_unsafe
|
||||
def connect(self):
|
||||
"""Connect to the database. Assume that the connection is closed."""
|
||||
with self.connect_manager():
|
||||
conn_params = self.get_connection_params()
|
||||
self.connection = self.get_new_connection(conn_params)
|
||||
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
|
||||
self.init_connection_state()
|
||||
|
||||
async def aconnect(self):
|
||||
"""Connect to the database. Assume that the connection is closed."""
|
||||
with self.connect_manager():
|
||||
# Establish the connection
|
||||
conn_params = self.get_connection_params(for_async=True)
|
||||
self.aconnection = await self.aget_new_connection(conn_params)
|
||||
await self.aset_autocommit(self.settings_dict["AUTOCOMMIT"])
|
||||
await self.ainit_connection_state()
|
||||
|
||||
def check_settings(self):
|
||||
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
|
||||
|
|
@ -278,6 +335,16 @@ class BaseDatabaseWrapper:
|
|||
with self.wrap_database_errors:
|
||||
self.connect()
|
||||
|
||||
async def aensure_connection(self):
|
||||
"""Guarantee that a connection to the database is established."""
|
||||
if self.aconnection is None:
|
||||
if self.in_atomic_block and self.closed_in_transaction:
|
||||
raise ProgrammingError(
|
||||
"Cannot open a new connection in an atomic block."
|
||||
)
|
||||
with self.wrap_database_errors:
|
||||
await self.aconnect()
|
||||
|
||||
# ##### Backend-specific wrappers for PEP-249 connection methods #####
|
||||
|
||||
def _prepare_cursor(self, cursor):
|
||||
|
|
@ -291,27 +358,57 @@ class BaseDatabaseWrapper:
|
|||
wrapped_cursor = self.make_cursor(cursor)
|
||||
return wrapped_cursor
|
||||
|
||||
def _aprepare_cursor(self, cursor):
|
||||
"""
|
||||
Validate the connection is usable and perform database cursor wrapping.
|
||||
"""
|
||||
|
||||
self.validate_thread_sharing()
|
||||
if self.queries_logged:
|
||||
wrapped_cursor = self.make_debug_async_cursor(cursor)
|
||||
else:
|
||||
wrapped_cursor = self.make_async_cursor(cursor)
|
||||
return wrapped_cursor
|
||||
|
||||
def _cursor(self, name=None):
|
||||
self.close_if_health_check_failed()
|
||||
self.ensure_connection()
|
||||
with self.wrap_database_errors:
|
||||
return self._prepare_cursor(self.create_cursor(name))
|
||||
|
||||
def _acursor(self, name=None):
|
||||
return utils.AsyncCursorCtx(self, name)
|
||||
|
||||
def _commit(self):
|
||||
if self.connection is not None:
|
||||
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
||||
return self.connection.commit()
|
||||
|
||||
async def _acommit(self):
|
||||
if self.aconnection is not None:
|
||||
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
||||
return await self.aconnection.commit()
|
||||
|
||||
def _rollback(self):
|
||||
if self.connection is not None:
|
||||
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
||||
return self.connection.rollback()
|
||||
|
||||
async def _arollback(self):
|
||||
if self.aconnection is not None:
|
||||
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
||||
return await self.aconnection.rollback()
|
||||
|
||||
def _close(self):
|
||||
if self.connection is not None:
|
||||
with self.wrap_database_errors:
|
||||
return self.connection.close()
|
||||
|
||||
async def _aclose(self):
|
||||
if self.aconnection is not None:
|
||||
with self.wrap_database_errors:
|
||||
return await self.aconnection.close()
|
||||
|
||||
# ##### Generic wrappers for PEP-249 connection methods #####
|
||||
|
||||
@async_unsafe
|
||||
|
|
@ -319,6 +416,10 @@ class BaseDatabaseWrapper:
|
|||
"""Create a cursor, opening a connection if necessary."""
|
||||
return self._cursor()
|
||||
|
||||
def acursor(self):
|
||||
"""Create an async cursor, opening a connection if necessary."""
|
||||
return self._acursor()
|
||||
|
||||
@async_unsafe
|
||||
def commit(self):
|
||||
"""Commit a transaction and reset the dirty flag."""
|
||||
|
|
@ -329,6 +430,15 @@ class BaseDatabaseWrapper:
|
|||
self.errors_occurred = False
|
||||
self.run_commit_hooks_on_set_autocommit_on = True
|
||||
|
||||
async def acommit(self):
|
||||
"""Commit a transaction and reset the dirty flag."""
|
||||
self.validate_thread_sharing()
|
||||
self.validate_no_atomic_block()
|
||||
await self._acommit()
|
||||
# A successful commit means that the database connection works.
|
||||
self.errors_occurred = False
|
||||
self.run_commit_hooks_on_set_autocommit_on = True
|
||||
|
||||
@async_unsafe
|
||||
def rollback(self):
|
||||
"""Roll back a transaction and reset the dirty flag."""
|
||||
|
|
@ -340,6 +450,16 @@ class BaseDatabaseWrapper:
|
|||
self.needs_rollback = False
|
||||
self.run_on_commit = []
|
||||
|
||||
async def arollback(self):
|
||||
"""Roll back a transaction and reset the dirty flag."""
|
||||
self.validate_thread_sharing()
|
||||
self.validate_no_atomic_block()
|
||||
await self._arollback()
|
||||
# A successful rollback means that the database connection works.
|
||||
self.errors_occurred = False
|
||||
self.needs_rollback = False
|
||||
self.run_on_commit = []
|
||||
|
||||
@async_unsafe
|
||||
def close(self):
|
||||
"""Close the connection to the database."""
|
||||
|
|
@ -360,24 +480,59 @@ class BaseDatabaseWrapper:
|
|||
else:
|
||||
self.connection = None
|
||||
|
||||
async def aclose(self):
|
||||
"""Close the connection to the database."""
|
||||
self.validate_thread_sharing()
|
||||
self.run_on_commit = []
|
||||
|
||||
# Don't call validate_no_atomic_block() to avoid making it difficult
|
||||
# to get rid of a connection in an invalid state. The next connect()
|
||||
# will reset the transaction state anyway.
|
||||
if self.closed_in_transaction or self.aconnection is None:
|
||||
return
|
||||
try:
|
||||
await self._aclose()
|
||||
finally:
|
||||
if self.in_atomic_block:
|
||||
self.closed_in_transaction = True
|
||||
self.needs_rollback = True
|
||||
else:
|
||||
self.aconnection = None
|
||||
|
||||
# ##### Backend-specific savepoint management methods #####
|
||||
|
||||
def _savepoint(self, sid):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute(self.ops.savepoint_create_sql(sid))
|
||||
|
||||
async def _asavepoint(self, sid):
|
||||
async with self.acursor() as cursor:
|
||||
await cursor.aexecute(self.ops.savepoint_create_sql(sid))
|
||||
|
||||
def _savepoint_rollback(self, sid):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute(self.ops.savepoint_rollback_sql(sid))
|
||||
|
||||
async def _asavepoint_rollback(self, sid):
|
||||
async with self.acursor() as cursor:
|
||||
await cursor.aexecute(self.ops.savepoint_rollback_sql(sid))
|
||||
|
||||
def _savepoint_commit(self, sid):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute(self.ops.savepoint_commit_sql(sid))
|
||||
|
||||
async def _asavepoint_commit(self, sid):
|
||||
async with self.acursor() as cursor:
|
||||
await cursor.aexecute(self.ops.savepoint_commit_sql(sid))
|
||||
|
||||
def _savepoint_allowed(self):
|
||||
# Savepoints cannot be created outside a transaction
|
||||
return self.features.uses_savepoints and not self.get_autocommit()
|
||||
|
||||
async def _asavepoint_allowed(self):
|
||||
# Savepoints cannot be created outside a transaction
|
||||
return self.features.uses_savepoints and not (await self.aget_autocommit())
|
||||
|
||||
# ##### Generic savepoint management methods #####
|
||||
|
||||
@async_unsafe
|
||||
|
|
@ -401,6 +556,26 @@ class BaseDatabaseWrapper:
|
|||
|
||||
return sid
|
||||
|
||||
async def asavepoint(self):
|
||||
"""
|
||||
Create a savepoint inside the current transaction. Return an
|
||||
identifier for the savepoint that will be used for the subsequent
|
||||
rollback or commit. Do nothing if savepoints are not supported.
|
||||
"""
|
||||
if not (await self._asavepoint_allowed()):
|
||||
return
|
||||
|
||||
thread_ident = _thread.get_ident()
|
||||
tid = str(thread_ident).replace("-", "")
|
||||
|
||||
self.savepoint_state += 1
|
||||
sid = "s%s_x%d" % (tid, self.savepoint_state)
|
||||
|
||||
self.validate_thread_sharing()
|
||||
await self._asavepoint(sid)
|
||||
|
||||
return sid
|
||||
|
||||
@async_unsafe
|
||||
def savepoint_rollback(self, sid):
|
||||
"""
|
||||
|
|
@ -419,6 +594,23 @@ class BaseDatabaseWrapper:
|
|||
if sid not in sids
|
||||
]
|
||||
|
||||
async def asavepoint_rollback(self, sid):
|
||||
"""
|
||||
Roll back to a savepoint. Do nothing if savepoints are not supported.
|
||||
"""
|
||||
if not (await self._asavepoint_allowed()):
|
||||
return
|
||||
|
||||
self.validate_thread_sharing()
|
||||
await self._asavepoint_rollback(sid)
|
||||
|
||||
# Remove any callbacks registered while this savepoint was active.
|
||||
self.run_on_commit = [
|
||||
(sids, func, robust)
|
||||
for (sids, func, robust) in self.run_on_commit
|
||||
if sid not in sids
|
||||
]
|
||||
|
||||
@async_unsafe
|
||||
def savepoint_commit(self, sid):
|
||||
"""
|
||||
|
|
@ -430,6 +622,16 @@ class BaseDatabaseWrapper:
|
|||
self.validate_thread_sharing()
|
||||
self._savepoint_commit(sid)
|
||||
|
||||
async def asavepoint_commit(self, sid):
|
||||
"""
|
||||
Release a savepoint. Do nothing if savepoints are not supported.
|
||||
"""
|
||||
if not (await self._asavepoint_allowed()):
|
||||
return
|
||||
|
||||
self.validate_thread_sharing()
|
||||
await self._asavepoint_commit(sid)
|
||||
|
||||
@async_unsafe
|
||||
def clean_savepoints(self):
|
||||
"""
|
||||
|
|
@ -447,6 +649,14 @@ class BaseDatabaseWrapper:
|
|||
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
|
||||
)
|
||||
|
||||
async def _aset_autocommit(self, autocommit):
|
||||
"""
|
||||
Backend-specific implementation to enable or disable autocommit.
|
||||
"""
|
||||
raise NotSupportedError(
|
||||
"subclasses of BaseDatabaseWrapper may require an _aset_autocommit() method"
|
||||
)
|
||||
|
||||
# ##### Generic transaction management methods #####
|
||||
|
||||
def get_autocommit(self):
|
||||
|
|
@ -454,6 +664,11 @@ class BaseDatabaseWrapper:
|
|||
self.ensure_connection()
|
||||
return self.autocommit
|
||||
|
||||
async def aget_autocommit(self):
|
||||
"""Get the autocommit state."""
|
||||
await self.aensure_connection()
|
||||
return self.autocommit
|
||||
|
||||
def set_autocommit(
|
||||
self, autocommit, force_begin_transaction_with_broken_autocommit=False
|
||||
):
|
||||
|
|
@ -491,6 +706,43 @@ class BaseDatabaseWrapper:
|
|||
self.run_and_clear_commit_hooks()
|
||||
self.run_commit_hooks_on_set_autocommit_on = False
|
||||
|
||||
async def aset_autocommit(
|
||||
self, autocommit, force_begin_transaction_with_broken_autocommit=False
|
||||
):
|
||||
"""
|
||||
Enable or disable autocommit.
|
||||
|
||||
The usual way to start a transaction is to turn autocommit off.
|
||||
SQLite does not properly start a transaction when disabling
|
||||
autocommit. To avoid this buggy behavior and to actually enter a new
|
||||
transaction, an explicit BEGIN is required. Using
|
||||
force_begin_transaction_with_broken_autocommit=True will issue an
|
||||
explicit BEGIN with SQLite. This option will be ignored for other
|
||||
backends.
|
||||
"""
|
||||
self.validate_no_atomic_block()
|
||||
await self.aclose_if_health_check_failed()
|
||||
await self.aensure_connection()
|
||||
|
||||
start_transaction_under_autocommit = (
|
||||
force_begin_transaction_with_broken_autocommit
|
||||
and not autocommit
|
||||
and hasattr(self, "_astart_transaction_under_autocommit")
|
||||
)
|
||||
|
||||
if start_transaction_under_autocommit:
|
||||
await self._astart_transaction_under_autocommit()
|
||||
elif autocommit:
|
||||
await self._aset_autocommit(autocommit)
|
||||
else:
|
||||
with debug_transaction(self, "BEGIN"):
|
||||
await self._aset_autocommit(autocommit)
|
||||
self.autocommit = autocommit
|
||||
|
||||
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
|
||||
self.run_and_clear_commit_hooks()
|
||||
self.run_commit_hooks_on_set_autocommit_on = False
|
||||
|
||||
def get_rollback(self):
|
||||
"""Get the "needs rollback" flag -- for *advanced use* only."""
|
||||
if not self.in_atomic_block:
|
||||
|
|
@ -575,6 +827,19 @@ class BaseDatabaseWrapper:
|
|||
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
|
||||
)
|
||||
|
||||
async def ais_usable(self):
|
||||
"""
|
||||
Test if the database connection is usable.
|
||||
|
||||
This method may assume that self.connection is not None.
|
||||
|
||||
Actual implementations should take care not to raise exceptions
|
||||
as that may prevent Django from recycling unusable connections.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"subclasses of BaseDatabaseWrapper may require an ais_usable() method"
|
||||
)
|
||||
|
||||
def close_if_health_check_failed(self):
|
||||
"""Close existing connection if it fails a health check."""
|
||||
if (
|
||||
|
|
@ -588,6 +853,20 @@ class BaseDatabaseWrapper:
|
|||
self.close()
|
||||
self.health_check_done = True
|
||||
|
||||
async def aclose_if_health_check_failed(self):
|
||||
"""Close existing connection if it fails a health check."""
|
||||
if (
|
||||
self.aconnection is None
|
||||
or not self.health_check_enabled
|
||||
or self.health_check_done
|
||||
):
|
||||
return
|
||||
|
||||
is_usable = await self.ais_usable()
|
||||
if not is_usable:
|
||||
await self.aclose()
|
||||
self.health_check_done = True
|
||||
|
||||
def close_if_unusable_or_obsolete(self):
|
||||
"""
|
||||
Close the current connection if unrecoverable errors have occurred
|
||||
|
|
@ -677,10 +956,18 @@ class BaseDatabaseWrapper:
|
|||
"""Create a cursor that logs all queries in self.queries_log."""
|
||||
return utils.CursorDebugWrapper(cursor, self)
|
||||
|
||||
def make_debug_async_cursor(self, cursor):
|
||||
"""Create a cursor that logs all queries in self.queries_log."""
|
||||
return utils.AsyncCursorDebugWrapper(cursor, self)
|
||||
|
||||
def make_cursor(self, cursor):
|
||||
"""Create a cursor without debug logging."""
|
||||
return utils.CursorWrapper(cursor, self)
|
||||
|
||||
def make_async_cursor(self, cursor):
|
||||
"""Create a cursor without debug logging."""
|
||||
return utils.AsyncCursorWrapper(cursor, self)
|
||||
|
||||
@contextmanager
|
||||
def temporary_connection(self):
|
||||
"""
|
||||
|
|
@ -698,6 +985,27 @@ class BaseDatabaseWrapper:
|
|||
if must_close:
|
||||
self.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def atemporary_connection(self):
|
||||
"""
|
||||
Context manager that ensures that a connection is established, and
|
||||
if it opened one, closes it to avoid leaving a dangling connection.
|
||||
This is useful for operations outside of the request-response cycle.
|
||||
|
||||
Provide a cursor::
|
||||
async with self.atemporary_connection() as cursor:
|
||||
...
|
||||
"""
|
||||
# unused
|
||||
|
||||
must_close = self.aconnection is None
|
||||
try:
|
||||
async with self.acursor() as cursor:
|
||||
yield cursor
|
||||
finally:
|
||||
if must_close:
|
||||
await self.aclose()
|
||||
|
||||
@contextmanager
|
||||
def _nodb_cursor(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -358,6 +358,9 @@ class BaseDatabaseFeatures:
|
|||
# Does the backend support negative JSON array indexing?
|
||||
supports_json_negative_indexing = True
|
||||
|
||||
# Asynchronous database operations
|
||||
supports_async = False
|
||||
|
||||
# Does the backend support column collations?
|
||||
supports_collation_on_charfield = True
|
||||
supports_collation_on_textfield = True
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ from django.core.exceptions import ImproperlyConfigured
|
|||
from django.db import DatabaseError as WrappedDatabaseError
|
||||
from django.db import connections
|
||||
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
||||
from django.db.backends.utils import (
|
||||
AsyncCursorDebugWrapper as AsyncBaseCursorDebugWrapper,
|
||||
)
|
||||
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.functional import cached_property
|
||||
|
|
@ -98,6 +101,8 @@ def _get_decimal_column(data):
|
|||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "postgresql"
|
||||
display_name = "PostgreSQL"
|
||||
_pg_version = None
|
||||
|
||||
# This dictionary maps Field objects to their associated PostgreSQL column
|
||||
# types, as strings. Column-type strings can contain format strings;
|
||||
# they'll be interpolated against the values of Field.__dict__ before being
|
||||
|
|
@ -231,11 +236,57 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
|
||||
return self._connection_pools[self.alias]
|
||||
|
||||
@property
|
||||
def apool(self):
|
||||
pool_options = self.settings_dict["OPTIONS"].get("pool")
|
||||
if self.alias == NO_DB_ALIAS or not pool_options:
|
||||
return None
|
||||
|
||||
if self.alias not in self._aconnection_pools:
|
||||
if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
|
||||
raise ImproperlyConfigured(
|
||||
"Pooling doesn't support persistent connections."
|
||||
)
|
||||
# Set the default options.
|
||||
if pool_options is True:
|
||||
pool_options = {}
|
||||
|
||||
try:
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
except ImportError as err:
|
||||
raise ImproperlyConfigured(
|
||||
"Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
|
||||
) from err
|
||||
|
||||
connect_kwargs = self.get_connection_params(for_async=True)
|
||||
# Ensure we run in autocommit, Django properly sets it later on.
|
||||
connect_kwargs["autocommit"] = True
|
||||
enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
|
||||
pool = AsyncConnectionPool(
|
||||
kwargs=connect_kwargs,
|
||||
open=False, # Do not open the pool during startup.
|
||||
configure=self._aconfigure_connection,
|
||||
check=AsyncConnectionPool.check_connection if enable_checks else None,
|
||||
**pool_options,
|
||||
)
|
||||
# setdefault() ensures that multiple threads don't set this in
|
||||
# parallel. Since we do not open the pool during it's init above,
|
||||
# this means that at worst during startup multiple threads generate
|
||||
# pool objects and the first to set it wins.
|
||||
self._aconnection_pools.setdefault(self.alias, pool)
|
||||
|
||||
return self._aconnection_pools[self.alias]
|
||||
|
||||
def close_pool(self):
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
del self._connection_pools[self.alias]
|
||||
|
||||
async def aclose_pool(self):
|
||||
if self.apool:
|
||||
await self.apool.close()
|
||||
del self._aconnection_pools[self.alias]
|
||||
|
||||
def get_database_version(self):
|
||||
"""
|
||||
Return a tuple of the database's version.
|
||||
|
|
@ -243,7 +294,38 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
"""
|
||||
return divmod(self.pg_version, 10000)
|
||||
|
||||
def get_connection_params(self):
|
||||
async def aget_database_version(self):
|
||||
"""
|
||||
Return a tuple of the database's version.
|
||||
E.g. for pg_version 120004, return (12, 4).
|
||||
"""
|
||||
pg_version = await self.apg_version
|
||||
return divmod(pg_version, 10000)
|
||||
|
||||
def _get_sync_cursor_factory(self, server_side_binding=None):
|
||||
if is_psycopg3 and server_side_binding is True:
|
||||
return ServerBindingCursor
|
||||
else:
|
||||
return Cursor
|
||||
|
||||
def _get_async_cursor_factory(self, server_side_binding=None):
|
||||
if is_psycopg3 and server_side_binding is True:
|
||||
return AsyncServerBindingCursor
|
||||
else:
|
||||
return AsyncCursor
|
||||
|
||||
def _get_cursor_factory(self, server_side_binding=None, for_async=False):
|
||||
if for_async and not is_psycopg3:
|
||||
raise ImproperlyConfigured(
|
||||
"Django requires psycopg >= 3 for ORM async support."
|
||||
)
|
||||
|
||||
if for_async:
|
||||
return self._get_async_cursor_factory(server_side_binding)
|
||||
else:
|
||||
return self._get_sync_cursor_factory(server_side_binding)
|
||||
|
||||
def get_connection_params(self, for_async=False):
|
||||
settings_dict = self.settings_dict
|
||||
# None may be used to connect to the default 'postgres' db
|
||||
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
|
||||
|
|
@ -283,14 +365,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
|
||||
|
||||
server_side_binding = conn_params.pop("server_side_binding", None)
|
||||
conn_params.setdefault(
|
||||
"cursor_factory",
|
||||
(
|
||||
ServerBindingCursor
|
||||
if is_psycopg3 and server_side_binding is True
|
||||
else Cursor
|
||||
),
|
||||
cursor_factory = self._get_cursor_factory(
|
||||
server_side_binding, for_async=for_async
|
||||
)
|
||||
conn_params.setdefault("cursor_factory", cursor_factory)
|
||||
if settings_dict["USER"]:
|
||||
conn_params["user"] = settings_dict["USER"]
|
||||
if settings_dict["PASSWORD"]:
|
||||
|
|
@ -310,8 +388,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
)
|
||||
return conn_params
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
def _get_isolation_level(self):
|
||||
# self.isolation_level must be set:
|
||||
# - after connecting to the database in order to obtain the database's
|
||||
# default when no value is explicitly specified in options.
|
||||
|
|
@ -322,17 +399,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
try:
|
||||
isolation_level_value = options["isolation_level"]
|
||||
except KeyError:
|
||||
self.isolation_level = IsolationLevel.READ_COMMITTED
|
||||
isolation_level = IsolationLevel.READ_COMMITTED
|
||||
else:
|
||||
# Set the isolation level to the value from OPTIONS.
|
||||
try:
|
||||
self.isolation_level = IsolationLevel(isolation_level_value)
|
||||
isolation_level = IsolationLevel(isolation_level_value)
|
||||
set_isolation_level = True
|
||||
except ValueError:
|
||||
raise ImproperlyConfigured(
|
||||
f"Invalid transaction isolation level {isolation_level_value} "
|
||||
f"specified. Use one of the psycopg.IsolationLevel values."
|
||||
)
|
||||
return isolation_level, set_isolation_level
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
isolation_level, set_isolation_level = self._get_isolation_level()
|
||||
self.isolation_level = isolation_level
|
||||
if self.pool:
|
||||
# If nothing else has opened the pool, open it now.
|
||||
self.pool.open()
|
||||
|
|
@ -340,7 +422,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
else:
|
||||
connection = self.Database.connect(**conn_params)
|
||||
if set_isolation_level:
|
||||
connection.isolation_level = self.isolation_level
|
||||
connection.isolation_level = isolation_level
|
||||
if not is_psycopg3:
|
||||
# Register dummy loads() to avoid a round trip from psycopg2's
|
||||
# decode to json.dumps() to json.loads(), when using a custom
|
||||
|
|
@ -350,6 +432,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
)
|
||||
return connection
|
||||
|
||||
async def aget_new_connection(self, conn_params):
|
||||
isolation_level, set_isolation_level = self._get_isolation_level()
|
||||
self.isolation_level = isolation_level
|
||||
if self.apool:
|
||||
# If nothing else has opened the pool, open it now.
|
||||
await self.apool.open()
|
||||
connection = await self.apool.getconn()
|
||||
else:
|
||||
connection = await self.Database.AsyncConnection.connect(**conn_params)
|
||||
if set_isolation_level:
|
||||
connection.isolation_level = isolation_level
|
||||
return connection
|
||||
|
||||
def ensure_timezone(self):
|
||||
# Close the pool so new connections pick up the correct timezone.
|
||||
self.close_pool()
|
||||
|
|
@ -357,6 +452,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
return False
|
||||
return self._configure_timezone(self.connection)
|
||||
|
||||
async def aensure_timezone(self):
|
||||
# Close the pool so new connections pick up the correct timezone.
|
||||
await self.aclose_pool()
|
||||
if self.connection is None:
|
||||
return False
|
||||
return await self._aconfigure_timezone(self.connection)
|
||||
|
||||
def _configure_timezone(self, connection):
|
||||
conn_timezone_name = connection.info.parameter_status("TimeZone")
|
||||
timezone_name = self.timezone_name
|
||||
|
|
@ -366,6 +468,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def _aconfigure_timezone(self, connection):
|
||||
conn_timezone_name = connection.info.parameter_status("TimeZone")
|
||||
timezone_name = self.timezone_name
|
||||
if timezone_name and conn_timezone_name != timezone_name:
|
||||
async with connection.cursor() as cursor:
|
||||
await cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
|
||||
return True
|
||||
return False
|
||||
|
||||
def _configure_role(self, connection):
|
||||
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
|
||||
with connection.cursor() as cursor:
|
||||
|
|
@ -374,6 +485,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def _aconfigure_role(self, connection):
|
||||
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
|
||||
async with connection.acursor() as cursor:
|
||||
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
|
||||
await cursor.aaexecute(sql)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _configure_connection(self, connection):
|
||||
# This function is called from init_connection_state and from the
|
||||
# psycopg pool itself after a connection is opened.
|
||||
|
|
@ -387,6 +506,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
|
||||
return commit_role or commit_tz
|
||||
|
||||
async def _aconfigure_connection(self, connection):
|
||||
# This function is called from init_connection_state and from the
|
||||
# psycopg pool itself after a connection is opened.
|
||||
|
||||
# Commit after setting the time zone.
|
||||
commit_tz = await self._aconfigure_timezone(connection)
|
||||
# Set the role on the connection. This is useful if the credential used
|
||||
# to login is not the same as the role that owns database resources. As
|
||||
# can be the case when using temporary or ephemeral credentials.
|
||||
commit_role = await self._aconfigure_role(connection)
|
||||
|
||||
return commit_role or commit_tz
|
||||
|
||||
def _close(self):
|
||||
if self.connection is not None:
|
||||
# `wrap_database_errors` only works for `putconn` as long as there
|
||||
|
|
@ -403,6 +535,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
else:
|
||||
return self.connection.close()
|
||||
|
||||
async def _aclose(self):
|
||||
if self.aconnection is not None:
|
||||
# `wrap_database_errors` only works for `putconn` as long as there
|
||||
# is no `reset` function set in the pool because it is deferred
|
||||
# into a thread and not directly executed.
|
||||
with self.wrap_database_errors:
|
||||
if self.apool:
|
||||
# Ensure the correct pool is returned. This is a workaround
|
||||
# for tests so a pool can be changed on setting changes
|
||||
# (e.g. USE_TZ, TIME_ZONE).
|
||||
await self.aconnection._pool.putconn(self.aconnection)
|
||||
# Connection can no longer be used.
|
||||
self.aconnection = None
|
||||
else:
|
||||
return await self.aconnection.close()
|
||||
|
||||
def init_connection_state(self):
|
||||
super().init_connection_state()
|
||||
|
||||
|
|
@ -412,6 +560,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
if commit and not self.get_autocommit():
|
||||
self.connection.commit()
|
||||
|
||||
async def ainit_connection_state(self):
|
||||
await super().ainit_connection_state()
|
||||
|
||||
if self.aconnection is not None and not self.apool:
|
||||
commit = await self._aconfigure_connection(self.aconnection)
|
||||
|
||||
if commit:
|
||||
autocommit = await self.aget_autocommit()
|
||||
if not autocommit:
|
||||
await self.aconnection.commit()
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
if name:
|
||||
|
|
@ -447,6 +606,35 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
||||
return cursor
|
||||
|
||||
def create_async_cursor(self, name=None):
|
||||
if name:
|
||||
if self.settings_dict["OPTIONS"].get("server_side_binding") is not True:
|
||||
# psycopg >= 3 forces the usage of server-side bindings for
|
||||
# named cursors so a specialized class that implements
|
||||
# server-side cursors while performing client-side bindings
|
||||
# must be used if `server_side_binding` is disabled (default).
|
||||
cursor = AsyncServerSideCursor(
|
||||
self.aconnection,
|
||||
name=name,
|
||||
scrollable=False,
|
||||
withhold=self.aconnection.autocommit,
|
||||
)
|
||||
else:
|
||||
# In autocommit mode, the cursor will be used outside of a
|
||||
# transaction, hence use a holdable cursor.
|
||||
cursor = self.aconnection.cursor(
|
||||
name, scrollable=False, withhold=self.aconnection.autocommit
|
||||
)
|
||||
else:
|
||||
cursor = self.aconnection.cursor()
|
||||
|
||||
# Register the cursor timezone only if the connection disagrees, to
|
||||
# avoid copying the adapter map.
|
||||
tzloader = self.aconnection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
|
||||
if self.timezone != tzloader.timezone:
|
||||
register_tzloader(self.timezone, cursor)
|
||||
return cursor
|
||||
|
||||
def tzinfo_factory(self, offset):
|
||||
return self.timezone
|
||||
|
||||
|
|
@ -478,10 +666,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
)
|
||||
)
|
||||
|
||||
async def achunked_cursor(self):
|
||||
self._named_cursor_idx += 1
|
||||
# Get the current async task
|
||||
try:
|
||||
current_task = asyncio.current_task()
|
||||
except RuntimeError:
|
||||
current_task = None
|
||||
# Current task can be none even if the current_task call didn't error
|
||||
if current_task:
|
||||
task_ident = str(id(current_task))
|
||||
else:
|
||||
task_ident = "sync"
|
||||
# Use that and the thread ident to get a unique name
|
||||
return self._acursor(
|
||||
name="_django_curs_%d_%s_%d"
|
||||
% (
|
||||
# Avoid reusing name in other threads / tasks
|
||||
threading.current_thread().ident,
|
||||
task_ident,
|
||||
self._named_cursor_idx,
|
||||
)
|
||||
)
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit = autocommit
|
||||
|
||||
async def _aset_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
await self.aconnection.set_autocommit(autocommit)
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check constraints by setting them to immediate. Return them to deferred
|
||||
|
|
@ -503,12 +718,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
else:
|
||||
return True
|
||||
|
||||
async def ais_usable(self):
|
||||
if self.aconnection is None:
|
||||
return False
|
||||
try:
|
||||
# Use a psycopg cursor directly, bypassing Django's utilities.
|
||||
async with self.aconnection.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def close_if_health_check_failed(self):
|
||||
if self.pool:
|
||||
# The pool only returns healthy connections.
|
||||
return
|
||||
return super().close_if_health_check_failed()
|
||||
|
||||
async def aclose_if_health_check_failed(self):
|
||||
if self.apool:
|
||||
# The pool only returns healthy connections.
|
||||
return
|
||||
return await super().aclose_if_health_check_failed()
|
||||
|
||||
@contextmanager
|
||||
def _nodb_cursor(self):
|
||||
cursor = None
|
||||
|
|
@ -549,8 +782,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
|
||||
@cached_property
|
||||
def pg_version(self):
|
||||
with self.temporary_connection():
|
||||
return self.connection.info.server_version
|
||||
if self._pg_version is None:
|
||||
with self.temporary_connection():
|
||||
self._pg_version = self.connection.info.server_version
|
||||
return self._pg_version
|
||||
|
||||
@cached_property
|
||||
async def apg_version(self):
|
||||
if self._pg_version is None:
|
||||
async with self.atemporary_connection():
|
||||
self._pg_version = self.aconnection.info.server_version
|
||||
return self._pg_version
|
||||
|
||||
def make_debug_cursor(self, cursor):
|
||||
return CursorDebugWrapper(cursor, self)
|
||||
|
|
@ -607,6 +849,36 @@ if is_psycopg3:
|
|||
with self.debug_sql(statement):
|
||||
return self.cursor.copy(statement)
|
||||
|
||||
class AsyncServerBindingCursor(CursorMixin, Database.AsyncClientCursor):
|
||||
pass
|
||||
|
||||
class AsyncCursor(CursorMixin, Database.AsyncClientCursor):
|
||||
pass
|
||||
|
||||
class AsyncServerSideCursor(
|
||||
CursorMixin,
|
||||
Database.client_cursor.ClientCursorMixin,
|
||||
Database.AsyncServerCursor,
|
||||
):
|
||||
"""
|
||||
psycopg >= 3 forces the usage of server-side bindings when using named
|
||||
cursors but the ORM doesn't yet support the systematic generation of
|
||||
prepareable SQL (#20516).
|
||||
|
||||
ClientCursorMixin forces the usage of client-side bindings while
|
||||
AsyncServerCursor implements the logic required to declare and scroll
|
||||
through named cursors.
|
||||
|
||||
Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
|
||||
specify how parameters should be bound instead, which AsyncServerCursor
|
||||
would inherit, but that's not the case.
|
||||
"""
|
||||
|
||||
class AsyncCursorDebugWrapper(AsyncBaseCursorDebugWrapper):
|
||||
def copy(self, statement):
|
||||
with self.debug_sql(statement):
|
||||
return self.cursor.copy(statement)
|
||||
|
||||
else:
|
||||
Cursor = psycopg2.extensions.cursor
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
END;
|
||||
$$ LANGUAGE plpgsql;"""
|
||||
requires_casted_case_in_updates = True
|
||||
supports_async = is_psycopg3
|
||||
supports_over_clause = True
|
||||
supports_frame_exclusion = True
|
||||
only_supports_unbounded_with_preceding_and_following = True
|
||||
|
|
|
|||
|
|
@ -114,6 +114,98 @@ class CursorWrapper:
|
|||
return self.cursor.executemany(sql, param_list)
|
||||
|
||||
|
||||
class AsyncCursorCtx:
|
||||
"""
|
||||
Asynchronous context manager to hold an async cursor.
|
||||
"""
|
||||
|
||||
def __init__(self, db, name=None):
|
||||
self.db = db
|
||||
self.name = name
|
||||
self.wrap_database_errors = self.db.wrap_database_errors
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.db.aclose_if_health_check_failed()
|
||||
await self.db.aensure_connection()
|
||||
self.wrap_database_errors.__enter__()
|
||||
return self.db._aprepare_cursor(self.db.create_async_cursor(self.name))
|
||||
|
||||
async def __aexit__(self, type, value, traceback):
|
||||
self.wrap_database_errors.__exit__(type, value, traceback)
|
||||
|
||||
|
||||
class AsyncCursorWrapper(CursorWrapper):
|
||||
async def _aexecute(self, sql, params, *ignored_wrapper_args):
|
||||
# Raise a warning during app initialization (stored_app_configs is only
|
||||
# ever set during testing).
|
||||
if not apps.ready and not apps.stored_app_configs:
|
||||
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
if params is None:
|
||||
# params default might be backend specific.
|
||||
return await self.cursor.execute(sql)
|
||||
else:
|
||||
return await self.cursor.execute(sql, params)
|
||||
|
||||
async def _aexecute_with_wrappers(self, sql, params, many, executor):
|
||||
context = {"connection": self.db, "cursor": self}
|
||||
for wrapper in reversed(self.db.execute_wrappers):
|
||||
executor = functools.partial(wrapper, executor)
|
||||
return await executor(sql, params, many, context)
|
||||
|
||||
async def aexecute(self, sql, params=None):
|
||||
return await self._aexecute_with_wrappers(
|
||||
sql, params, many=False, executor=self._aexecute
|
||||
)
|
||||
|
||||
async def _aexecutemany(self, sql, param_list, *ignored_wrapper_args):
|
||||
# Raise a warning during app initialization (stored_app_configs is only
|
||||
# ever set during testing).
|
||||
if not apps.ready and not apps.stored_app_configs:
|
||||
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
return await self.cursor.executemany(sql, param_list)
|
||||
|
||||
async def aexecutemany(self, sql, param_list):
|
||||
return await self._aexecute_with_wrappers(
|
||||
sql, param_list, many=True, executor=self._aexecutemany
|
||||
)
|
||||
|
||||
async def afetchone(self, *args, **kwargs):
|
||||
return await self.cursor.fetchone(*args, **kwargs)
|
||||
|
||||
async def afetchmany(self, *args, **kwargs):
|
||||
return await self.cursor.fetchmany(*args, **kwargs)
|
||||
|
||||
async def afetchall(self, *args, **kwargs):
|
||||
return await self.cursor.fetchall(*args, **kwargs)
|
||||
|
||||
def acopy(self, *args, **kwargs):
|
||||
return self.cursor.copy(*args, **kwargs)
|
||||
|
||||
def astream(self, *args, **kwargs):
|
||||
return self.cursor.stream(*args, **kwargs)
|
||||
|
||||
async def ascroll(self, *args, **kwargs):
|
||||
return await self.cursor.scroll(*args, **kwargs)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, type, value, traceback):
|
||||
try:
|
||||
await self.close()
|
||||
except self.db.Database.Error:
|
||||
pass
|
||||
|
||||
async def __aiter__(self):
|
||||
with self.db.wrap_database_errors:
|
||||
async for item in self.cursor:
|
||||
yield item
|
||||
|
||||
|
||||
class CursorDebugWrapper(CursorWrapper):
|
||||
# XXX callproc isn't instrumented at this time.
|
||||
|
||||
|
|
@ -163,6 +255,57 @@ class CursorDebugWrapper(CursorWrapper):
|
|||
)
|
||||
|
||||
|
||||
class AsyncCursorDebugWrapper(AsyncCursorWrapper):
|
||||
# XXX callproc isn't instrumented at this time.
|
||||
|
||||
async def aexecute(self, sql, params=None):
|
||||
with self.debug_sql(sql, params, use_last_executed_query=True):
|
||||
return await super().aexecute(sql, params)
|
||||
|
||||
async def aexecutemany(self, sql, param_list):
|
||||
with self.debug_sql(sql, param_list, many=True):
|
||||
return await super().aexecutemany(sql, param_list)
|
||||
|
||||
@contextmanager
|
||||
def debug_sql(
|
||||
self, sql=None, params=None, use_last_executed_query=False, many=False
|
||||
):
|
||||
start = time.monotonic()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
stop = time.monotonic()
|
||||
duration = stop - start
|
||||
if use_last_executed_query:
|
||||
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
|
||||
try:
|
||||
times = len(params) if many else ""
|
||||
except TypeError:
|
||||
# params could be an iterator.
|
||||
times = "?"
|
||||
self.db.queries_log.append(
|
||||
{
|
||||
"sql": "%s times: %s" % (times, sql) if many else sql,
|
||||
"time": "%.3f" % duration,
|
||||
"async": True,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"(%.3f) %s; args=%s; alias=%s; async=True",
|
||||
duration,
|
||||
sql,
|
||||
params,
|
||||
self.db.alias,
|
||||
extra={
|
||||
"duration": duration,
|
||||
"sql": sql,
|
||||
"params": params,
|
||||
"alias": self.db.alias,
|
||||
"async": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def debug_transaction(connection, sql):
|
||||
start = time.monotonic()
|
||||
|
|
@ -176,18 +319,21 @@ def debug_transaction(connection, sql):
|
|||
{
|
||||
"sql": "%s" % sql,
|
||||
"time": "%.3f" % duration,
|
||||
"async": connection.features.supports_async,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"(%.3f) %s; args=%s; alias=%s",
|
||||
"(%.3f) %s; args=%s; alias=%s; async=%s",
|
||||
duration,
|
||||
sql,
|
||||
None,
|
||||
connection.alias,
|
||||
connection.features.supports_async,
|
||||
extra={
|
||||
"duration": duration,
|
||||
"sql": sql,
|
||||
"alias": connection.alias,
|
||||
"async": connection.features.supports_async,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import pkgutil
|
||||
from importlib import import_module
|
||||
|
||||
from asgiref.local import Local
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
|
|
@ -197,6 +199,89 @@ class ConnectionHandler(BaseConnectionHandler):
|
|||
return backend.DatabaseWrapper(db, alias)
|
||||
|
||||
|
||||
class AsyncAlias:
|
||||
"""
|
||||
A Context-aware list of connections.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections = Local()
|
||||
setattr(self._connections, "_stack", [])
|
||||
|
||||
@property
|
||||
def connections(self):
|
||||
return getattr(self._connections, "_stack", [])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.connections)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.connections)
|
||||
|
||||
def __str__(self):
|
||||
return ", ".join([str(id(conn)) for conn in self.connections])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}: {len(self.connections)} connections>"
|
||||
|
||||
def add_connection(self, connection):
|
||||
setattr(self._connections, "_stack", self.connections + [connection])
|
||||
|
||||
def pop(self):
|
||||
conns = self.connections
|
||||
conns.pop()
|
||||
setattr(self._connections, "_stack", conns)
|
||||
|
||||
|
||||
class AsyncConnectionHandler:
|
||||
"""
|
||||
Context-aware class to store async connections, mapped by alias name.
|
||||
"""
|
||||
|
||||
_from_testcase = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._aliases = Local()
|
||||
self._connection_count = Local()
|
||||
setattr(self._connection_count, "value", 0)
|
||||
|
||||
def __getitem__(self, alias):
|
||||
try:
|
||||
async_alias = getattr(self._aliases, alias)
|
||||
except AttributeError:
|
||||
async_alias = AsyncAlias()
|
||||
setattr(self._aliases, alias, async_alias)
|
||||
return async_alias
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}: {self.count} connections>"
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
return getattr(self._connection_count, "value", 0)
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return self.count == 0
|
||||
|
||||
def add_connection(self, using, connection):
|
||||
self[using].add_connection(connection)
|
||||
setattr(self._connection_count, "value", self.count + 1)
|
||||
|
||||
async def pop_connection(self, using):
|
||||
await self[using].connections[-1].aclose_pool()
|
||||
self[using].connections.pop()
|
||||
setattr(self._connection_count, "value", self.count - 1)
|
||||
|
||||
def get_connection(self, using):
|
||||
alias = self[using]
|
||||
if len(alias.connections) == 0:
|
||||
raise ConnectionDoesNotExist(
|
||||
f"There are no async connections using the '{using}' alias."
|
||||
)
|
||||
return alias.connections[-1]
|
||||
|
||||
|
||||
class ConnectionRouter:
|
||||
def __init__(self, routers=None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -38,7 +38,13 @@ from django.core.management.color import no_style
|
|||
from django.core.management.sql import emit_post_migrate_signal
|
||||
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
|
||||
from django.core.signals import setting_changed
|
||||
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
async_connections,
|
||||
connection,
|
||||
connections,
|
||||
transaction,
|
||||
)
|
||||
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
||||
from django.forms.fields import CharField
|
||||
from django.http import QueryDict
|
||||
|
|
@ -1415,6 +1421,7 @@ class TestCase(TransactionTestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
async_connections._from_testcase = True
|
||||
if not (
|
||||
cls._databases_support_transactions()
|
||||
and cls._databases_support_savepoints()
|
||||
|
|
|
|||
|
|
@ -211,7 +211,6 @@ Database backends
|
|||
* MySQL connections now default to using the ``utf8mb4`` character set,
|
||||
instead of ``utf8``, which is an alias for the deprecated character set
|
||||
``utf8mb3``.
|
||||
|
||||
* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
|
||||
``"pool"`` in the :setting:`OPTIONS` part of your database configuration.
|
||||
|
||||
|
|
|
|||
|
|
@ -151,6 +151,19 @@ instance of those now-deprecated classes.
|
|||
Minor features
|
||||
--------------
|
||||
|
||||
Database backends
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
* It is now possible to perform asynchronous raw SQL queries using an async
|
||||
cursor.
|
||||
This is only possible on backends that support async-native connections.
|
||||
Currently only supported in PostreSQL with the
|
||||
``django.db.backends.postgresql`` backend.
|
||||
* It is now possible to perform asynchronous raw SQL queries using an async
|
||||
cursor, if the backend supports async-native connections. This is only
|
||||
supported on PostgreSQL with ``psycopg`` 3.1.8+. See
|
||||
:ref:`async-connection-cursor` for more details.
|
||||
|
||||
:mod:`django.contrib.admin`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
|||
|
|
@ -404,6 +404,37 @@ is equivalent to::
|
|||
finally:
|
||||
c.close()
|
||||
|
||||
.. _async-connection-cursor:
|
||||
|
||||
Async Connections and cursors
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. versionadded:: 6.0
|
||||
|
||||
On backends that support async-native connections, you can request an async
|
||||
cursor::
|
||||
|
||||
from django.db import new_connection
|
||||
|
||||
async with new_connection() as connection:
|
||||
async with connection.acursor() as c:
|
||||
await c.aexecute(...)
|
||||
|
||||
Async cursors provide the following methods:
|
||||
|
||||
* ``.aexecute()``
|
||||
* ``.aexecutemany()``
|
||||
* ``.afetchone()``
|
||||
* ``.afetchmany()``
|
||||
* ``.afetchall()``
|
||||
* ``.acopy()``
|
||||
* ``.astream()``
|
||||
* ``.ascroll()``
|
||||
|
||||
Currently, Django ships with the following async-enabled backend:
|
||||
|
||||
* ``django.db.backends.postgresql`` with ``psycopg3``.
|
||||
|
||||
Calling stored procedures
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
|||
26
tests/async/test_async_connections.py
Normal file
26
tests/async/test_async_connections.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.db import new_connection
|
||||
from django.test import TransactionTestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import SimpleModel
|
||||
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
class AsyncSyncCominglingTest(TransactionTestCase):
|
||||
|
||||
available_apps = ["async"]
|
||||
|
||||
async def change_model_with_async(self, obj):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"""UPDATE "async_simplemodel" SET "field" = 10,"""
|
||||
""" "created" = '2024-11-19 14:12:50.606384'::timestamp"""
|
||||
""" WHERE "async_simplemodel"."id" = 1"""
|
||||
)
|
||||
|
||||
async def test_transaction_async_comingling(self):
|
||||
s1 = await sync_to_async(SimpleModel.objects.create)(field=0)
|
||||
# with transaction.atomic():
|
||||
await self.change_model_with_async(s1)
|
||||
133
tests/async/test_async_cursor.py
Normal file
133
tests/async/test_async_cursor.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
from django.db import new_connection
|
||||
from django.test import SimpleTestCase, skipUnlessDBFeature
|
||||
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
class AsyncCursorTests(SimpleTestCase):
|
||||
databases = {"default", "other"}
|
||||
|
||||
async def test_aexecute(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute("SELECT 1")
|
||||
|
||||
async def test_aexecutemany(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute("CREATE TABLE numbers (number SMALLINT)")
|
||||
await cursor.aexecutemany(
|
||||
"INSERT INTO numbers VALUES (%s)", [(1,), (2,), (3,)]
|
||||
)
|
||||
await cursor.aexecute("SELECT * FROM numbers")
|
||||
result = await cursor.afetchall()
|
||||
self.assertEqual(result, [(1,), (2,), (3,)])
|
||||
|
||||
async def test_afetchone(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute("SELECT 1")
|
||||
result = await cursor.afetchone()
|
||||
self.assertEqual(result, (1,))
|
||||
|
||||
async def test_afetchmany(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)"""
|
||||
)
|
||||
result = await cursor.afetchmany(size=2)
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",)])
|
||||
|
||||
async def test_afetchall(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)"""
|
||||
)
|
||||
result = await cursor.afetchall()
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||
|
||||
async def test_aiter(self):
|
||||
result = []
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)"""
|
||||
)
|
||||
async for record in cursor:
|
||||
result.append(record)
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||
|
||||
async def test_acopy(self):
|
||||
result = []
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
async with cursor.acopy(
|
||||
"""
|
||||
COPY (
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)
|
||||
) TO STDOUT"""
|
||||
) as copy:
|
||||
async for row in copy.rows():
|
||||
result.append(row)
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||
|
||||
async def test_astream(self):
|
||||
result = []
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
async for record in cursor.astream(
|
||||
"""
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)"""
|
||||
):
|
||||
result.append(record)
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||
|
||||
async def test_ascroll(self):
|
||||
result = []
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM (VALUES
|
||||
('BANANA'),
|
||||
('STRAWBERRY'),
|
||||
('MELON')
|
||||
) AS v (NAME)"""
|
||||
)
|
||||
await cursor.ascroll(1, "absolute")
|
||||
|
||||
result = await cursor.afetchall()
|
||||
self.assertEqual(result, [("STRAWBERRY",), ("MELON",)])
|
||||
await cursor.ascroll(0, "absolute")
|
||||
result = await cursor.afetchall()
|
||||
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||
21
tests/backends/base/test_base_async.py
Normal file
21
tests/backends/base/test_base_async.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from django.db import new_connection
|
||||
from django.test import SimpleTestCase, skipUnlessDBFeature
|
||||
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
class AsyncDatabaseWrapperTests(SimpleTestCase):
|
||||
databases = {"default", "other"}
|
||||
|
||||
async def test_async_cursor(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
result = (await cursor.fetchone())[0]
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
async def test_async_cursor_alias(self):
|
||||
async with new_connection() as conn:
|
||||
async with conn.acursor() as cursor:
|
||||
await cursor.aexecute("SELECT 1")
|
||||
result = (await cursor.afetchone())[0]
|
||||
self.assertEqual(result, 1)
|
||||
|
|
@ -1,11 +1,26 @@
|
|||
"""Tests for django.db.utils."""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection
|
||||
from django.db.utils import ConnectionHandler, load_backend
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
NotSupportedError,
|
||||
ProgrammingError,
|
||||
async_connections,
|
||||
connection,
|
||||
new_connection,
|
||||
)
|
||||
from django.db.utils import (
|
||||
AsyncAlias,
|
||||
AsyncConnectionHandler,
|
||||
ConnectionHandler,
|
||||
load_backend,
|
||||
)
|
||||
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||
from django.utils.connection import ConnectionDoesNotExist
|
||||
|
||||
|
||||
|
|
@ -90,3 +105,82 @@ class LoadBackendTests(SimpleTestCase):
|
|||
with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm:
|
||||
load_backend("foo")
|
||||
self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'")
|
||||
|
||||
|
||||
class AsyncConnectionTests(SimpleTestCase):
|
||||
databases = {"default", "other"}
|
||||
|
||||
def run_pool(self, coro, count=2):
|
||||
def fn():
|
||||
asyncio.run(coro())
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
futures = []
|
||||
for _ in range(count):
|
||||
futures.append(executor.submit(fn))
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
exc = future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
def test_async_alias(self):
|
||||
alias = AsyncAlias()
|
||||
assert len(alias) == 0
|
||||
assert alias.connections == []
|
||||
|
||||
async def coro():
|
||||
assert len(alias) == 0
|
||||
alias.add_connection(mock.Mock())
|
||||
alias.pop()
|
||||
|
||||
self.run_pool(coro)
|
||||
|
||||
def test_async_connection_handler(self):
|
||||
aconns = AsyncConnectionHandler()
|
||||
assert aconns.empty is True
|
||||
assert aconns["default"].connections == []
|
||||
|
||||
async def coro():
|
||||
assert aconns["default"].connections == []
|
||||
aconns.add_connection("default", mock.Mock())
|
||||
aconns.pop_connection("default")
|
||||
|
||||
self.run_pool(coro)
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
def test_new_connection_threading(self):
|
||||
async def coro():
|
||||
assert async_connections.empty is True
|
||||
async with new_connection() as connection:
|
||||
async with connection.acursor() as c:
|
||||
await c.execute("SELECT 1")
|
||||
|
||||
self.run_pool(coro)
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
async def test_new_connection(self):
|
||||
with self.assertRaises(ConnectionDoesNotExist):
|
||||
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||
|
||||
async with new_connection():
|
||||
conn1 = async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||
self.assertIsNotNone(conn1.aconnection)
|
||||
async with new_connection():
|
||||
conn2 = async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||
self.assertIsNotNone(conn1.aconnection)
|
||||
self.assertIsNotNone(conn2.aconnection)
|
||||
self.assertNotEqual(conn1.aconnection, conn2.aconnection)
|
||||
|
||||
self.assertIsNotNone(conn1.aconnection)
|
||||
self.assertIsNone(conn2.aconnection)
|
||||
self.assertIsNone(conn1.aconnection)
|
||||
|
||||
with self.assertRaises(ConnectionDoesNotExist):
|
||||
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||
|
||||
@skipUnlessDBFeature("supports_async")
|
||||
async def test_new_connection_on_sync(self):
|
||||
with self.assertRaises(NotSupportedError):
|
||||
async with new_connection():
|
||||
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from django.db import (
|
|||
IntegrityError,
|
||||
OperationalError,
|
||||
connection,
|
||||
new_connection,
|
||||
transaction,
|
||||
)
|
||||
from django.test import (
|
||||
|
|
@ -586,3 +587,93 @@ class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
|
|||
|
||||
class DurableTests(DurableTestsBase, TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@skipUnlessDBFeature("uses_savepoints", "supports_async")
|
||||
class AsyncTransactionTestCase(TransactionTestCase):
|
||||
available_apps = ["transactions"]
|
||||
|
||||
async def test_new_connection_nested(self):
|
||||
async with new_connection() as connection:
|
||||
async with new_connection() as connection2:
|
||||
await connection2.aset_autocommit(False)
|
||||
async with connection2.acursor() as cursor2:
|
||||
await cursor2.aexecute(
|
||||
"INSERT INTO transactions_reporter "
|
||||
"(first_name, last_name, email) "
|
||||
"VALUES (%s, %s, %s)",
|
||||
("Sarah", "Hatoff", ""),
|
||||
)
|
||||
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor2.afetchmany()
|
||||
assert len(result) == 1
|
||||
|
||||
async with connection.acursor() as cursor:
|
||||
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor.afetchmany()
|
||||
assert len(result) == 1
|
||||
|
||||
async def test_new_connection_nested2(self):
|
||||
async with new_connection() as connection:
|
||||
async with connection.acursor() as cursor:
|
||||
await cursor.aexecute(
|
||||
"INSERT INTO transactions_reporter (first_name, last_name, email) "
|
||||
"VALUES (%s, %s, %s)",
|
||||
("Sarah", "Hatoff", ""),
|
||||
)
|
||||
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor.afetchmany()
|
||||
assert len(result) == 1
|
||||
|
||||
async with new_connection() as connection2:
|
||||
await connection2.aset_autocommit(False)
|
||||
async with connection2.acursor() as cursor2:
|
||||
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor2.afetchmany()
|
||||
# This connection won't see any rows, because the outer one
|
||||
# hasn't committed yet.
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_new_connection_nested3(self):
|
||||
async with new_connection() as connection:
|
||||
async with new_connection() as connection2:
|
||||
await connection2.aset_autocommit(False)
|
||||
assert id(connection) != id(connection2)
|
||||
async with connection2.acursor() as cursor2:
|
||||
await cursor2.aexecute(
|
||||
"INSERT INTO transactions_reporter "
|
||||
"(first_name, last_name, email) "
|
||||
"VALUES (%s, %s, %s)",
|
||||
("Sarah", "Hatoff", ""),
|
||||
)
|
||||
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor2.afetchmany()
|
||||
assert len(result) == 1
|
||||
|
||||
# Outermost connection doesn't see what the innermost did,
|
||||
# because the innermost connection hasn't exited yet.
|
||||
async with connection.acursor() as cursor:
|
||||
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor.afetchmany()
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_asavepoint(self):
|
||||
async with new_connection() as connection:
|
||||
async with connection.acursor() as cursor:
|
||||
sid = await connection.asavepoint()
|
||||
assert sid is not None
|
||||
|
||||
await cursor.aexecute(
|
||||
"INSERT INTO transactions_reporter (first_name, last_name, email) "
|
||||
"VALUES (%s, %s, %s)",
|
||||
("Archibald", "Haddock", ""),
|
||||
)
|
||||
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor.afetchmany(size=5)
|
||||
assert len(result) == 1
|
||||
assert result[0][1:] == ("Archibald", "Haddock", "")
|
||||
|
||||
await connection.asavepoint_rollback(sid)
|
||||
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||
result = await cursor.fetchmany(size=5)
|
||||
assert len(result) == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue