This commit is contained in:
Flavio Curella 2025-11-17 17:42:53 +04:00 committed by GitHub
commit f667229336
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1319 additions and 44 deletions

View file

@ -2,6 +2,7 @@ from django.core import signals
from django.db.utils import (
DEFAULT_DB_ALIAS,
DJANGO_VERSION_PICKLE_KEY,
AsyncConnectionHandler,
ConnectionHandler,
ConnectionRouter,
DatabaseError,
@ -36,6 +37,50 @@ __all__ = [
]
connections = ConnectionHandler()
async_connections = AsyncConnectionHandler()
class new_connection:
"""
Asynchronous context manager to instantiate new async connections.
"""
def __init__(self, using=DEFAULT_DB_ALIAS):
self.using = using
async def __aenter__(self):
conn = connections.create_connection(self.using)
if conn.features.supports_async is False:
raise NotSupportedError(
"The database backend does not support asynchronous execution."
)
self.force_rollback = False
if async_connections.empty is True:
if async_connections._from_testcase is True:
self.force_rollback = True
self.conn = conn
async_connections.add_connection(self.using, self.conn)
await self.conn.aensure_connection()
if self.force_rollback is True:
await self.conn.aset_autocommit(False)
return self.conn
async def __aexit__(self, exc_type, exc_value, traceback):
autocommit = await self.conn.aget_autocommit()
if autocommit is False:
if exc_type is None and self.force_rollback is False:
await self.conn.acommit()
else:
await self.conn.arollback()
await self.conn.aclose()
await async_connections.pop_connection(self.using)
router = ConnectionRouter()

View file

@ -7,7 +7,7 @@ import time
import warnings
import zoneinfo
from collections import deque
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@ -39,6 +39,7 @@ class BaseDatabaseWrapper:
ops = None
vendor = "unknown"
display_name = "unknown"
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
@ -47,6 +48,7 @@ class BaseDatabaseWrapper:
introspection_class = None
ops_class = None
validation_class = BaseDatabaseValidation
_aconnection_pools = {}
queries_limit = 9000
@ -54,6 +56,7 @@ class BaseDatabaseWrapper:
# Connection related attributes.
# The underlying database connection.
self.connection = None
self.aconnection = None
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
@ -187,25 +190,44 @@ class BaseDatabaseWrapper:
"method."
)
async def aget_database_version(self):
"""Return a tuple of the database's version."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an aget_database_version() "
"method."
)
def _validate_database_version_supported(self, db_version):
if (
self.features.minimum_database_version is not None
and db_version < self.features.minimum_database_version
):
str_db_version = ".".join(map(str, db_version))
min_db_version = ".".join(map(str, self.features.minimum_database_version))
raise NotSupportedError(
f"{self.display_name} {min_db_version} or later is required "
f"(found {str_db_version})."
)
def check_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
if (
self.features.minimum_database_version is not None
and self.get_database_version() < self.features.minimum_database_version
):
db_version = ".".join(map(str, self.get_database_version()))
min_db_version = ".".join(map(str, self.features.minimum_database_version))
raise NotSupportedError(
f"{self.display_name} {min_db_version} or later is required "
f"(found {db_version})."
)
db_version = self.get_database_version()
self._validate_database_version_supported(db_version)
async def acheck_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
db_version = await self.aget_database_version()
self._validate_database_version_supported(db_version)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
def get_connection_params(self, for_async=False):
"""Return a dict of parameters suitable for get_new_connection."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
@ -219,23 +241,42 @@ class BaseDatabaseWrapper:
"method"
)
async def aget_new_connection(self, conn_params):
"""Open a connection to the database."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an aget_new_connection() "
"method"
)
def init_connection_state(self):
"""Initialize the database connection settings."""
if self.alias not in RAN_DB_VERSION_CHECK:
self.check_database_version_supported()
RAN_DB_VERSION_CHECK.add(self.alias)
async def ainit_connection_state(self):
"""Initialize the database connection settings."""
if self.alias not in RAN_DB_VERSION_CHECK:
await self.acheck_database_version_supported()
RAN_DB_VERSION_CHECK.add(self.alias)
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
)
def create_async_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require a "
"create_async_cursor() method"
)
# ##### Backend-specific methods for creating connections #####
@async_unsafe
def connect(self):
"""Connect to the database. Assume that the connection is closed."""
@contextmanager
def connect_manager(self):
# Check for invalid configurations.
self.check_settings()
# In case the previous connection was closed while in an atomic block
@ -251,14 +292,30 @@ class BaseDatabaseWrapper:
self.errors_occurred = False
# New connections are healthy.
self.health_check_done = True
# Establish the connection
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
try:
yield
finally:
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
@async_unsafe
def connect(self):
"""Connect to the database. Assume that the connection is closed."""
with self.connect_manager():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
self.init_connection_state()
async def aconnect(self):
"""Connect to the database. Assume that the connection is closed."""
with self.connect_manager():
# Establish the connection
conn_params = self.get_connection_params(for_async=True)
self.aconnection = await self.aget_new_connection(conn_params)
await self.aset_autocommit(self.settings_dict["AUTOCOMMIT"])
await self.ainit_connection_state()
def check_settings(self):
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
@ -278,6 +335,16 @@ class BaseDatabaseWrapper:
with self.wrap_database_errors:
self.connect()
async def aensure_connection(self):
"""Guarantee that a connection to the database is established."""
if self.aconnection is None:
if self.in_atomic_block and self.closed_in_transaction:
raise ProgrammingError(
"Cannot open a new connection in an atomic block."
)
with self.wrap_database_errors:
await self.aconnect()
# ##### Backend-specific wrappers for PEP-249 connection methods #####
def _prepare_cursor(self, cursor):
@ -291,27 +358,57 @@ class BaseDatabaseWrapper:
wrapped_cursor = self.make_cursor(cursor)
return wrapped_cursor
def _aprepare_cursor(self, cursor):
"""
Validate the connection is usable and perform database cursor wrapping.
"""
self.validate_thread_sharing()
if self.queries_logged:
wrapped_cursor = self.make_debug_async_cursor(cursor)
else:
wrapped_cursor = self.make_async_cursor(cursor)
return wrapped_cursor
def _cursor(self, name=None):
self.close_if_health_check_failed()
self.ensure_connection()
with self.wrap_database_errors:
return self._prepare_cursor(self.create_cursor(name))
def _acursor(self, name=None):
return utils.AsyncCursorCtx(self, name)
def _commit(self):
if self.connection is not None:
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
return self.connection.commit()
async def _acommit(self):
if self.aconnection is not None:
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
return await self.aconnection.commit()
def _rollback(self):
if self.connection is not None:
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
return self.connection.rollback()
async def _arollback(self):
if self.aconnection is not None:
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
return await self.aconnection.rollback()
def _close(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.close()
async def _aclose(self):
if self.aconnection is not None:
with self.wrap_database_errors:
return await self.aconnection.close()
# ##### Generic wrappers for PEP-249 connection methods #####
@async_unsafe
@ -319,6 +416,10 @@ class BaseDatabaseWrapper:
"""Create a cursor, opening a connection if necessary."""
return self._cursor()
def acursor(self):
"""Create an async cursor, opening a connection if necessary."""
return self._acursor()
@async_unsafe
def commit(self):
"""Commit a transaction and reset the dirty flag."""
@ -329,6 +430,15 @@ class BaseDatabaseWrapper:
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
async def acommit(self):
"""Commit a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
await self._acommit()
# A successful commit means that the database connection works.
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
@async_unsafe
def rollback(self):
"""Roll back a transaction and reset the dirty flag."""
@ -340,6 +450,16 @@ class BaseDatabaseWrapper:
self.needs_rollback = False
self.run_on_commit = []
async def arollback(self):
"""Roll back a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
await self._arollback()
# A successful rollback means that the database connection works.
self.errors_occurred = False
self.needs_rollback = False
self.run_on_commit = []
@async_unsafe
def close(self):
"""Close the connection to the database."""
@ -360,24 +480,59 @@ class BaseDatabaseWrapper:
else:
self.connection = None
async def aclose(self):
"""Close the connection to the database."""
self.validate_thread_sharing()
self.run_on_commit = []
# Don't call validate_no_atomic_block() to avoid making it difficult
# to get rid of a connection in an invalid state. The next connect()
# will reset the transaction state anyway.
if self.closed_in_transaction or self.aconnection is None:
return
try:
await self._aclose()
finally:
if self.in_atomic_block:
self.closed_in_transaction = True
self.needs_rollback = True
else:
self.aconnection = None
# ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
async def _asavepoint(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
async def _asavepoint_rollback(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
async def _asavepoint_commit(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not self.get_autocommit()
async def _asavepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not (await self.aget_autocommit())
# ##### Generic savepoint management methods #####
@async_unsafe
@ -401,6 +556,26 @@ class BaseDatabaseWrapper:
return sid
async def asavepoint(self):
"""
Create a savepoint inside the current transaction. Return an
identifier for the savepoint that will be used for the subsequent
rollback or commit. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
thread_ident = _thread.get_ident()
tid = str(thread_ident).replace("-", "")
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
self.validate_thread_sharing()
await self._asavepoint(sid)
return sid
@async_unsafe
def savepoint_rollback(self, sid):
"""
@ -419,6 +594,23 @@ class BaseDatabaseWrapper:
if sid not in sids
]
async def asavepoint_rollback(self, sid):
"""
Roll back to a savepoint. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
self.validate_thread_sharing()
await self._asavepoint_rollback(sid)
# Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [
(sids, func, robust)
for (sids, func, robust) in self.run_on_commit
if sid not in sids
]
@async_unsafe
def savepoint_commit(self, sid):
"""
@ -430,6 +622,16 @@ class BaseDatabaseWrapper:
self.validate_thread_sharing()
self._savepoint_commit(sid)
async def asavepoint_commit(self, sid):
"""
Release a savepoint. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
self.validate_thread_sharing()
await self._asavepoint_commit(sid)
@async_unsafe
def clean_savepoints(self):
"""
@ -447,6 +649,14 @@ class BaseDatabaseWrapper:
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
)
async def _aset_autocommit(self, autocommit):
"""
Backend-specific implementation to enable or disable autocommit.
"""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an _aset_autocommit() method"
)
# ##### Generic transaction management methods #####
def get_autocommit(self):
@ -454,6 +664,11 @@ class BaseDatabaseWrapper:
self.ensure_connection()
return self.autocommit
async def aget_autocommit(self):
"""Get the autocommit state."""
await self.aensure_connection()
return self.autocommit
def set_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
@ -491,6 +706,43 @@ class BaseDatabaseWrapper:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
async def aset_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
"""
Enable or disable autocommit.
The usual way to start a transaction is to turn autocommit off.
SQLite does not properly start a transaction when disabling
autocommit. To avoid this buggy behavior and to actually enter a new
transaction, an explicit BEGIN is required. Using
force_begin_transaction_with_broken_autocommit=True will issue an
explicit BEGIN with SQLite. This option will be ignored for other
backends.
"""
self.validate_no_atomic_block()
await self.aclose_if_health_check_failed()
await self.aensure_connection()
start_transaction_under_autocommit = (
force_begin_transaction_with_broken_autocommit
and not autocommit
and hasattr(self, "_astart_transaction_under_autocommit")
)
if start_transaction_under_autocommit:
await self._astart_transaction_under_autocommit()
elif autocommit:
await self._aset_autocommit(autocommit)
else:
with debug_transaction(self, "BEGIN"):
await self._aset_autocommit(autocommit)
self.autocommit = autocommit
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
def get_rollback(self):
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
@ -575,6 +827,19 @@ class BaseDatabaseWrapper:
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
)
async def ais_usable(self):
"""
Test if the database connection is usable.
This method may assume that self.connection is not None.
Actual implementations should take care not to raise exceptions
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an ais_usable() method"
)
def close_if_health_check_failed(self):
"""Close existing connection if it fails a health check."""
if (
@ -588,6 +853,20 @@ class BaseDatabaseWrapper:
self.close()
self.health_check_done = True
async def aclose_if_health_check_failed(self):
"""Close existing connection if it fails a health check."""
if (
self.aconnection is None
or not self.health_check_enabled
or self.health_check_done
):
return
is_usable = await self.ais_usable()
if not is_usable:
await self.aclose()
self.health_check_done = True
def close_if_unusable_or_obsolete(self):
"""
Close the current connection if unrecoverable errors have occurred
@ -677,10 +956,18 @@ class BaseDatabaseWrapper:
"""Create a cursor that logs all queries in self.queries_log."""
return utils.CursorDebugWrapper(cursor, self)
def make_debug_async_cursor(self, cursor):
"""Create a cursor that logs all queries in self.queries_log."""
return utils.AsyncCursorDebugWrapper(cursor, self)
def make_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.CursorWrapper(cursor, self)
def make_async_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.AsyncCursorWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
"""
@ -698,6 +985,27 @@ class BaseDatabaseWrapper:
if must_close:
self.close()
@asynccontextmanager
async def atemporary_connection(self):
"""
Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
Provide a cursor::
async with self.atemporary_connection() as cursor:
...
"""
# unused
must_close = self.aconnection is None
try:
async with self.acursor() as cursor:
yield cursor
finally:
if must_close:
await self.aclose()
@contextmanager
def _nodb_cursor(self):
"""

View file

@ -358,6 +358,9 @@ class BaseDatabaseFeatures:
# Does the backend support negative JSON array indexing?
supports_json_negative_indexing = True
# Asynchronous database operations
supports_async = False
# Does the backend support column collations?
supports_collation_on_charfield = True
supports_collation_on_textfield = True

View file

@ -15,6 +15,9 @@ from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.db.backends.utils import (
AsyncCursorDebugWrapper as AsyncBaseCursorDebugWrapper,
)
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
@ -98,6 +101,8 @@ def _get_decimal_column(data):
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
_pg_version = None
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings;
# they'll be interpolated against the values of Field.__dict__ before being
@ -231,11 +236,57 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self._connection_pools[self.alias]
@property
def apool(self):
pool_options = self.settings_dict["OPTIONS"].get("pool")
if self.alias == NO_DB_ALIAS or not pool_options:
return None
if self.alias not in self._aconnection_pools:
if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
raise ImproperlyConfigured(
"Pooling doesn't support persistent connections."
)
# Set the default options.
if pool_options is True:
pool_options = {}
try:
from psycopg_pool import AsyncConnectionPool
except ImportError as err:
raise ImproperlyConfigured(
"Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
) from err
connect_kwargs = self.get_connection_params(for_async=True)
# Ensure we run in autocommit, Django properly sets it later on.
connect_kwargs["autocommit"] = True
enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
pool = AsyncConnectionPool(
kwargs=connect_kwargs,
open=False, # Do not open the pool during startup.
configure=self._aconfigure_connection,
check=AsyncConnectionPool.check_connection if enable_checks else None,
**pool_options,
)
# setdefault() ensures that multiple threads don't set this in
# parallel. Since we do not open the pool during it's init above,
# this means that at worst during startup multiple threads generate
# pool objects and the first to set it wins.
self._aconnection_pools.setdefault(self.alias, pool)
return self._aconnection_pools[self.alias]
def close_pool(self):
if self.pool:
self.pool.close()
del self._connection_pools[self.alias]
async def aclose_pool(self):
if self.apool:
await self.apool.close()
del self._aconnection_pools[self.alias]
def get_database_version(self):
"""
Return a tuple of the database's version.
@ -243,7 +294,38 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"""
return divmod(self.pg_version, 10000)
def get_connection_params(self):
async def aget_database_version(self):
"""
Return a tuple of the database's version.
E.g. for pg_version 120004, return (12, 4).
"""
pg_version = await self.apg_version
return divmod(pg_version, 10000)
def _get_sync_cursor_factory(self, server_side_binding=None):
if is_psycopg3 and server_side_binding is True:
return ServerBindingCursor
else:
return Cursor
def _get_async_cursor_factory(self, server_side_binding=None):
if is_psycopg3 and server_side_binding is True:
return AsyncServerBindingCursor
else:
return AsyncCursor
def _get_cursor_factory(self, server_side_binding=None, for_async=False):
if for_async and not is_psycopg3:
raise ImproperlyConfigured(
"Django requires psycopg >= 3 for ORM async support."
)
if for_async:
return self._get_async_cursor_factory(server_side_binding)
else:
return self._get_sync_cursor_factory(server_side_binding)
def get_connection_params(self, for_async=False):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
@ -283,14 +365,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
server_side_binding = conn_params.pop("server_side_binding", None)
conn_params.setdefault(
"cursor_factory",
(
ServerBindingCursor
if is_psycopg3 and server_side_binding is True
else Cursor
),
cursor_factory = self._get_cursor_factory(
server_side_binding, for_async=for_async
)
conn_params.setdefault("cursor_factory", cursor_factory)
if settings_dict["USER"]:
conn_params["user"] = settings_dict["USER"]
if settings_dict["PASSWORD"]:
@ -310,8 +388,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
def _get_isolation_level(self):
# self.isolation_level must be set:
# - after connecting to the database in order to obtain the database's
# default when no value is explicitly specified in options.
@ -322,17 +399,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
try:
isolation_level_value = options["isolation_level"]
except KeyError:
self.isolation_level = IsolationLevel.READ_COMMITTED
isolation_level = IsolationLevel.READ_COMMITTED
else:
# Set the isolation level to the value from OPTIONS.
try:
self.isolation_level = IsolationLevel(isolation_level_value)
isolation_level = IsolationLevel(isolation_level_value)
set_isolation_level = True
except ValueError:
raise ImproperlyConfigured(
f"Invalid transaction isolation level {isolation_level_value} "
f"specified. Use one of the psycopg.IsolationLevel values."
)
return isolation_level, set_isolation_level
@async_unsafe
def get_new_connection(self, conn_params):
isolation_level, set_isolation_level = self._get_isolation_level()
self.isolation_level = isolation_level
if self.pool:
# If nothing else has opened the pool, open it now.
self.pool.open()
@ -340,7 +422,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
connection.isolation_level = isolation_level
if not is_psycopg3:
# Register dummy loads() to avoid a round trip from psycopg2's
# decode to json.dumps() to json.loads(), when using a custom
@ -350,6 +432,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
return connection
async def aget_new_connection(self, conn_params):
isolation_level, set_isolation_level = self._get_isolation_level()
self.isolation_level = isolation_level
if self.apool:
# If nothing else has opened the pool, open it now.
await self.apool.open()
connection = await self.apool.getconn()
else:
connection = await self.Database.AsyncConnection.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = isolation_level
return connection
def ensure_timezone(self):
# Close the pool so new connections pick up the correct timezone.
self.close_pool()
@ -357,6 +452,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
return self._configure_timezone(self.connection)
async def aensure_timezone(self):
# Close the pool so new connections pick up the correct timezone.
await self.aclose_pool()
if self.connection is None:
return False
return await self._aconfigure_timezone(self.connection)
def _configure_timezone(self, connection):
conn_timezone_name = connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
@ -366,6 +468,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return True
return False
async def _aconfigure_timezone(self, connection):
conn_timezone_name = connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
async with connection.cursor() as cursor:
await cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False
def _configure_role(self, connection):
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
with connection.cursor() as cursor:
@ -374,6 +485,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return True
return False
async def _aconfigure_role(self, connection):
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
async with connection.acursor() as cursor:
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
await cursor.aaexecute(sql)
return True
return False
def _configure_connection(self, connection):
# This function is called from init_connection_state and from the
# psycopg pool itself after a connection is opened.
@ -387,6 +506,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return commit_role or commit_tz
async def _aconfigure_connection(self, connection):
# This function is called from init_connection_state and from the
# psycopg pool itself after a connection is opened.
# Commit after setting the time zone.
commit_tz = await self._aconfigure_timezone(connection)
# Set the role on the connection. This is useful if the credential used
# to login is not the same as the role that owns database resources. As
# can be the case when using temporary or ephemeral credentials.
commit_role = await self._aconfigure_role(connection)
return commit_role or commit_tz
def _close(self):
if self.connection is not None:
# `wrap_database_errors` only works for `putconn` as long as there
@ -403,6 +535,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
return self.connection.close()
async def _aclose(self):
if self.aconnection is not None:
# `wrap_database_errors` only works for `putconn` as long as there
# is no `reset` function set in the pool because it is deferred
# into a thread and not directly executed.
with self.wrap_database_errors:
if self.apool:
# Ensure the correct pool is returned. This is a workaround
# for tests so a pool can be changed on setting changes
# (e.g. USE_TZ, TIME_ZONE).
await self.aconnection._pool.putconn(self.aconnection)
# Connection can no longer be used.
self.aconnection = None
else:
return await self.aconnection.close()
def init_connection_state(self):
super().init_connection_state()
@ -412,6 +560,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if commit and not self.get_autocommit():
self.connection.commit()
async def ainit_connection_state(self):
await super().ainit_connection_state()
if self.aconnection is not None and not self.apool:
commit = await self._aconfigure_connection(self.aconnection)
if commit:
autocommit = await self.aget_autocommit()
if not autocommit:
await self.aconnection.commit()
@async_unsafe
def create_cursor(self, name=None):
if name:
@ -447,6 +606,35 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def create_async_cursor(self, name=None):
if name:
if self.settings_dict["OPTIONS"].get("server_side_binding") is not True:
# psycopg >= 3 forces the usage of server-side bindings for
# named cursors so a specialized class that implements
# server-side cursors while performing client-side bindings
# must be used if `server_side_binding` is disabled (default).
cursor = AsyncServerSideCursor(
self.aconnection,
name=name,
scrollable=False,
withhold=self.aconnection.autocommit,
)
else:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.aconnection.cursor(
name, scrollable=False, withhold=self.aconnection.autocommit
)
else:
cursor = self.aconnection.cursor()
# Register the cursor timezone only if the connection disagrees, to
# avoid copying the adapter map.
tzloader = self.aconnection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
if self.timezone != tzloader.timezone:
register_tzloader(self.timezone, cursor)
return cursor
def tzinfo_factory(self, offset):
return self.timezone
@ -478,10 +666,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
)
async def achunked_cursor(self):
self._named_cursor_idx += 1
# Get the current async task
try:
current_task = asyncio.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = "sync"
# Use that and the thread ident to get a unique name
return self._acursor(
name="_django_curs_%d_%s_%d"
% (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
self._named_cursor_idx,
)
)
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
async def _aset_autocommit(self, autocommit):
with self.wrap_database_errors:
await self.aconnection.set_autocommit(autocommit)
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
@ -503,12 +718,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
return True
async def ais_usable(self):
if self.aconnection is None:
return False
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
async with self.aconnection.cursor() as cursor:
await cursor.execute("SELECT 1")
except Database.Error:
return False
else:
return True
def close_if_health_check_failed(self):
if self.pool:
# The pool only returns healthy connections.
return
return super().close_if_health_check_failed()
async def aclose_if_health_check_failed(self):
if self.apool:
# The pool only returns healthy connections.
return
return await super().aclose_if_health_check_failed()
@contextmanager
def _nodb_cursor(self):
cursor = None
@ -549,8 +782,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def pg_version(self):
with self.temporary_connection():
return self.connection.info.server_version
if self._pg_version is None:
with self.temporary_connection():
self._pg_version = self.connection.info.server_version
return self._pg_version
@cached_property
async def apg_version(self):
if self._pg_version is None:
async with self.atemporary_connection():
self._pg_version = self.aconnection.info.server_version
return self._pg_version
def make_debug_cursor(self, cursor):
return CursorDebugWrapper(cursor, self)
@ -607,6 +849,36 @@ if is_psycopg3:
with self.debug_sql(statement):
return self.cursor.copy(statement)
class AsyncServerBindingCursor(CursorMixin, Database.AsyncClientCursor):
pass
class AsyncCursor(CursorMixin, Database.AsyncClientCursor):
pass
class AsyncServerSideCursor(
CursorMixin,
Database.client_cursor.ClientCursorMixin,
Database.AsyncServerCursor,
):
"""
psycopg >= 3 forces the usage of server-side bindings when using named
cursors but the ORM doesn't yet support the systematic generation of
prepareable SQL (#20516).
ClientCursorMixin forces the usage of client-side bindings while
AsyncServerCursor implements the logic required to declare and scroll
through named cursors.
Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
specify how parameters should be bound instead, which AsyncServerCursor
would inherit, but that's not the case.
"""
class AsyncCursorDebugWrapper(AsyncBaseCursorDebugWrapper):
def copy(self, statement):
with self.debug_sql(statement):
return self.cursor.copy(statement)
else:
Cursor = psycopg2.extensions.cursor

View file

@ -54,6 +54,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
END;
$$ LANGUAGE plpgsql;"""
requires_casted_case_in_updates = True
supports_async = is_psycopg3
supports_over_clause = True
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True

View file

@ -114,6 +114,98 @@ class CursorWrapper:
return self.cursor.executemany(sql, param_list)
class AsyncCursorCtx:
"""
Asynchronous context manager to hold an async cursor.
"""
def __init__(self, db, name=None):
self.db = db
self.name = name
self.wrap_database_errors = self.db.wrap_database_errors
async def __aenter__(self):
await self.db.aclose_if_health_check_failed()
await self.db.aensure_connection()
self.wrap_database_errors.__enter__()
return self.db._aprepare_cursor(self.db.create_async_cursor(self.name))
async def __aexit__(self, type, value, traceback):
self.wrap_database_errors.__exit__(type, value, traceback)
class AsyncCursorWrapper(CursorWrapper):
async def _aexecute(self, sql, params, *ignored_wrapper_args):
# Raise a warning during app initialization (stored_app_configs is only
# ever set during testing).
if not apps.ready and not apps.stored_app_configs:
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
if params is None:
# params default might be backend specific.
return await self.cursor.execute(sql)
else:
return await self.cursor.execute(sql, params)
async def _aexecute_with_wrappers(self, sql, params, many, executor):
context = {"connection": self.db, "cursor": self}
for wrapper in reversed(self.db.execute_wrappers):
executor = functools.partial(wrapper, executor)
return await executor(sql, params, many, context)
async def aexecute(self, sql, params=None):
return await self._aexecute_with_wrappers(
sql, params, many=False, executor=self._aexecute
)
async def _aexecutemany(self, sql, param_list, *ignored_wrapper_args):
# Raise a warning during app initialization (stored_app_configs is only
# ever set during testing).
if not apps.ready and not apps.stored_app_configs:
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
return await self.cursor.executemany(sql, param_list)
async def aexecutemany(self, sql, param_list):
return await self._aexecute_with_wrappers(
sql, param_list, many=True, executor=self._aexecutemany
)
async def afetchone(self, *args, **kwargs):
return await self.cursor.fetchone(*args, **kwargs)
async def afetchmany(self, *args, **kwargs):
return await self.cursor.fetchmany(*args, **kwargs)
async def afetchall(self, *args, **kwargs):
return await self.cursor.fetchall(*args, **kwargs)
def acopy(self, *args, **kwargs):
return self.cursor.copy(*args, **kwargs)
def astream(self, *args, **kwargs):
return self.cursor.stream(*args, **kwargs)
async def ascroll(self, *args, **kwargs):
return await self.cursor.scroll(*args, **kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, type, value, traceback):
try:
await self.close()
except self.db.Database.Error:
pass
async def __aiter__(self):
with self.db.wrap_database_errors:
async for item in self.cursor:
yield item
class CursorDebugWrapper(CursorWrapper):
# XXX callproc isn't instrumented at this time.
@ -163,6 +255,57 @@ class CursorDebugWrapper(CursorWrapper):
)
class AsyncCursorDebugWrapper(AsyncCursorWrapper):
# XXX callproc isn't instrumented at this time.
async def aexecute(self, sql, params=None):
with self.debug_sql(sql, params, use_last_executed_query=True):
return await super().aexecute(sql, params)
async def aexecutemany(self, sql, param_list):
with self.debug_sql(sql, param_list, many=True):
return await super().aexecutemany(sql, param_list)
@contextmanager
def debug_sql(
self, sql=None, params=None, use_last_executed_query=False, many=False
):
start = time.monotonic()
try:
yield
finally:
stop = time.monotonic()
duration = stop - start
if use_last_executed_query:
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
try:
times = len(params) if many else ""
except TypeError:
# params could be an iterator.
times = "?"
self.db.queries_log.append(
{
"sql": "%s times: %s" % (times, sql) if many else sql,
"time": "%.3f" % duration,
"async": True,
}
)
logger.debug(
"(%.3f) %s; args=%s; alias=%s; async=True",
duration,
sql,
params,
self.db.alias,
extra={
"duration": duration,
"sql": sql,
"params": params,
"alias": self.db.alias,
"async": True,
},
)
@contextmanager
def debug_transaction(connection, sql):
start = time.monotonic()
@ -176,18 +319,21 @@ def debug_transaction(connection, sql):
{
"sql": "%s" % sql,
"time": "%.3f" % duration,
"async": connection.features.supports_async,
}
)
logger.debug(
"(%.3f) %s; args=%s; alias=%s",
"(%.3f) %s; args=%s; alias=%s; async=%s",
duration,
sql,
None,
connection.alias,
connection.features.supports_async,
extra={
"duration": duration,
"sql": sql,
"alias": connection.alias,
"async": connection.features.supports_async,
},
)

View file

@ -1,6 +1,8 @@
import pkgutil
from importlib import import_module
from asgiref.local import Local
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@ -197,6 +199,89 @@ class ConnectionHandler(BaseConnectionHandler):
return backend.DatabaseWrapper(db, alias)
class AsyncAlias:
"""
A Context-aware list of connections.
"""
def __init__(self) -> None:
self._connections = Local()
setattr(self._connections, "_stack", [])
@property
def connections(self):
return getattr(self._connections, "_stack", [])
def __len__(self):
return len(self.connections)
def __iter__(self):
return iter(self.connections)
def __str__(self):
return ", ".join([str(id(conn)) for conn in self.connections])
def __repr__(self):
return f"<{self.__class__.__name__}: {len(self.connections)} connections>"
def add_connection(self, connection):
setattr(self._connections, "_stack", self.connections + [connection])
def pop(self):
conns = self.connections
conns.pop()
setattr(self._connections, "_stack", conns)
class AsyncConnectionHandler:
"""
Context-aware class to store async connections, mapped by alias name.
"""
_from_testcase = False
def __init__(self) -> None:
self._aliases = Local()
self._connection_count = Local()
setattr(self._connection_count, "value", 0)
def __getitem__(self, alias):
try:
async_alias = getattr(self._aliases, alias)
except AttributeError:
async_alias = AsyncAlias()
setattr(self._aliases, alias, async_alias)
return async_alias
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.count} connections>"
@property
def count(self):
return getattr(self._connection_count, "value", 0)
@property
def empty(self):
return self.count == 0
def add_connection(self, using, connection):
self[using].add_connection(connection)
setattr(self._connection_count, "value", self.count + 1)
async def pop_connection(self, using):
await self[using].connections[-1].aclose_pool()
self[using].connections.pop()
setattr(self._connection_count, "value", self.count - 1)
def get_connection(self, using):
alias = self[using]
if len(alias.connections) == 0:
raise ConnectionDoesNotExist(
f"There are no async connections using the '{using}' alias."
)
return alias.connections[-1]
class ConnectionRouter:
def __init__(self, routers=None):
"""

View file

@ -38,7 +38,13 @@ from django.core.management.color import no_style
from django.core.management.sql import emit_post_migrate_signal
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
from django.core.signals import setting_changed
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.db import (
DEFAULT_DB_ALIAS,
async_connections,
connection,
connections,
transaction,
)
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.forms.fields import CharField
from django.http import QueryDict
@ -1415,6 +1421,7 @@ class TestCase(TransactionTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
async_connections._from_testcase = True
if not (
cls._databases_support_transactions()
and cls._databases_support_savepoints()

View file

@ -211,7 +211,6 @@ Database backends
* MySQL connections now default to using the ``utf8mb4`` character set,
instead of ``utf8``, which is an alias for the deprecated character set
``utf8mb3``.
* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
``"pool"`` in the :setting:`OPTIONS` part of your database configuration.

View file

@ -151,6 +151,19 @@ instance of those now-deprecated classes.
Minor features
--------------
Database backends
~~~~~~~~~~~~~~~~~
* It is now possible to perform asynchronous raw SQL queries using an async
cursor.
This is only possible on backends that support async-native connections.
Currently only supported in PostreSQL with the
``django.db.backends.postgresql`` backend.
* It is now possible to perform asynchronous raw SQL queries using an async
cursor, if the backend supports async-native connections. This is only
supported on PostgreSQL with ``psycopg`` 3.1.8+. See
:ref:`async-connection-cursor` for more details.
:mod:`django.contrib.admin`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@ -404,6 +404,37 @@ is equivalent to::
finally:
c.close()
.. _async-connection-cursor:
Async Connections and cursors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 6.0
On backends that support async-native connections, you can request an async
cursor::
from django.db import new_connection
async with new_connection() as connection:
async with connection.acursor() as c:
await c.aexecute(...)
Async cursors provide the following methods:
* ``.aexecute()``
* ``.aexecutemany()``
* ``.afetchone()``
* ``.afetchmany()``
* ``.afetchall()``
* ``.acopy()``
* ``.astream()``
* ``.ascroll()``
Currently, Django ships with the following async-enabled backend:
* ``django.db.backends.postgresql`` with ``psycopg3``.
Calling stored procedures
~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@ -0,0 +1,26 @@
from asgiref.sync import sync_to_async
from django.db import new_connection
from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import SimpleModel
@skipUnlessDBFeature("supports_async")
class AsyncSyncCominglingTest(TransactionTestCase):
available_apps = ["async"]
async def change_model_with_async(self, obj):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""UPDATE "async_simplemodel" SET "field" = 10,"""
""" "created" = '2024-11-19 14:12:50.606384'::timestamp"""
""" WHERE "async_simplemodel"."id" = 1"""
)
async def test_transaction_async_comingling(self):
s1 = await sync_to_async(SimpleModel.objects.create)(field=0)
# with transaction.atomic():
await self.change_model_with_async(s1)

View file

@ -0,0 +1,133 @@
from django.db import new_connection
from django.test import SimpleTestCase, skipUnlessDBFeature
@skipUnlessDBFeature("supports_async")
class AsyncCursorTests(SimpleTestCase):
databases = {"default", "other"}
async def test_aexecute(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
async def test_aexecutemany(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("CREATE TABLE numbers (number SMALLINT)")
await cursor.aexecutemany(
"INSERT INTO numbers VALUES (%s)", [(1,), (2,), (3,)]
)
await cursor.aexecute("SELECT * FROM numbers")
result = await cursor.afetchall()
self.assertEqual(result, [(1,), (2,), (3,)])
async def test_afetchone(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
result = await cursor.afetchone()
self.assertEqual(result, (1,))
async def test_afetchmany(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
result = await cursor.afetchmany(size=2)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",)])
async def test_afetchall(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
result = await cursor.afetchall()
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_aiter(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
async for record in cursor:
result.append(record)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_acopy(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
async with cursor.acopy(
"""
COPY (
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)
) TO STDOUT"""
) as copy:
async for row in copy.rows():
result.append(row)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_astream(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
async for record in cursor.astream(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
):
result.append(record)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_ascroll(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
await cursor.ascroll(1, "absolute")
result = await cursor.afetchall()
self.assertEqual(result, [("STRAWBERRY",), ("MELON",)])
await cursor.ascroll(0, "absolute")
result = await cursor.afetchall()
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])

View file

@ -0,0 +1,21 @@
from django.db import new_connection
from django.test import SimpleTestCase, skipUnlessDBFeature
@skipUnlessDBFeature("supports_async")
class AsyncDatabaseWrapperTests(SimpleTestCase):
databases = {"default", "other"}
async def test_async_cursor(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.execute("SELECT 1")
result = (await cursor.fetchone())[0]
self.assertEqual(result, 1)
async def test_async_cursor_alias(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
result = (await cursor.afetchone())[0]
self.assertEqual(result, 1)

View file

@ -1,11 +1,26 @@
"""Tests for django.db.utils."""
import asyncio
import concurrent.futures
import unittest
from unittest import mock
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection
from django.db.utils import ConnectionHandler, load_backend
from django.test import SimpleTestCase, TestCase
from django.db import (
DEFAULT_DB_ALIAS,
NotSupportedError,
ProgrammingError,
async_connections,
connection,
new_connection,
)
from django.db.utils import (
AsyncAlias,
AsyncConnectionHandler,
ConnectionHandler,
load_backend,
)
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.utils.connection import ConnectionDoesNotExist
@ -90,3 +105,82 @@ class LoadBackendTests(SimpleTestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm:
load_backend("foo")
self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'")
class AsyncConnectionTests(SimpleTestCase):
databases = {"default", "other"}
def run_pool(self, coro, count=2):
def fn():
asyncio.run(coro())
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = []
for _ in range(count):
futures.append(executor.submit(fn))
for future in concurrent.futures.as_completed(futures):
exc = future.exception()
if exc is not None:
raise exc
def test_async_alias(self):
alias = AsyncAlias()
assert len(alias) == 0
assert alias.connections == []
async def coro():
assert len(alias) == 0
alias.add_connection(mock.Mock())
alias.pop()
self.run_pool(coro)
def test_async_connection_handler(self):
aconns = AsyncConnectionHandler()
assert aconns.empty is True
assert aconns["default"].connections == []
async def coro():
assert aconns["default"].connections == []
aconns.add_connection("default", mock.Mock())
aconns.pop_connection("default")
self.run_pool(coro)
@skipUnlessDBFeature("supports_async")
def test_new_connection_threading(self):
async def coro():
assert async_connections.empty is True
async with new_connection() as connection:
async with connection.acursor() as c:
await c.execute("SELECT 1")
self.run_pool(coro)
@skipUnlessDBFeature("supports_async")
async def test_new_connection(self):
with self.assertRaises(ConnectionDoesNotExist):
async_connections.get_connection(DEFAULT_DB_ALIAS)
async with new_connection():
conn1 = async_connections.get_connection(DEFAULT_DB_ALIAS)
self.assertIsNotNone(conn1.aconnection)
async with new_connection():
conn2 = async_connections.get_connection(DEFAULT_DB_ALIAS)
self.assertIsNotNone(conn1.aconnection)
self.assertIsNotNone(conn2.aconnection)
self.assertNotEqual(conn1.aconnection, conn2.aconnection)
self.assertIsNotNone(conn1.aconnection)
self.assertIsNone(conn2.aconnection)
self.assertIsNone(conn1.aconnection)
with self.assertRaises(ConnectionDoesNotExist):
async_connections.get_connection(DEFAULT_DB_ALIAS)
@skipUnlessDBFeature("supports_async")
async def test_new_connection_on_sync(self):
with self.assertRaises(NotSupportedError):
async with new_connection():
async_connections.get_connection(DEFAULT_DB_ALIAS)

View file

@ -9,6 +9,7 @@ from django.db import (
IntegrityError,
OperationalError,
connection,
new_connection,
transaction,
)
from django.test import (
@ -586,3 +587,93 @@ class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
class DurableTests(DurableTestsBase, TestCase):
pass
@skipUnlessDBFeature("uses_savepoints", "supports_async")
class AsyncTransactionTestCase(TransactionTestCase):
available_apps = ["transactions"]
async def test_new_connection_nested(self):
async with new_connection() as connection:
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
async with connection2.acursor() as cursor2:
await cursor2.aexecute(
"INSERT INTO transactions_reporter "
"(first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
assert len(result) == 1
async with connection.acursor() as cursor:
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 1
async def test_new_connection_nested2(self):
async with new_connection() as connection:
async with connection.acursor() as cursor:
await cursor.aexecute(
"INSERT INTO transactions_reporter (first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 1
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
async with connection2.acursor() as cursor2:
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
# This connection won't see any rows, because the outer one
# hasn't committed yet.
assert len(result) == 0
async def test_new_connection_nested3(self):
async with new_connection() as connection:
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
assert id(connection) != id(connection2)
async with connection2.acursor() as cursor2:
await cursor2.aexecute(
"INSERT INTO transactions_reporter "
"(first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
assert len(result) == 1
# Outermost connection doesn't see what the innermost did,
# because the innermost connection hasn't exited yet.
async with connection.acursor() as cursor:
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 0
async def test_asavepoint(self):
async with new_connection() as connection:
async with connection.acursor() as cursor:
sid = await connection.asavepoint()
assert sid is not None
await cursor.aexecute(
"INSERT INTO transactions_reporter (first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Archibald", "Haddock", ""),
)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany(size=5)
assert len(result) == 1
assert result[0][1:] == ("Archibald", "Haddock", "")
await connection.asavepoint_rollback(sid)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.fetchmany(size=5)
assert len(result) == 0