Flavio Curella 2025-11-17 17:42:53 +04:00 committed by GitHub
commit f667229336
16 changed files with 1319 additions and 44 deletions

View file

@@ -2,6 +2,7 @@ from django.core import signals
from django.db.utils import (
DEFAULT_DB_ALIAS,
DJANGO_VERSION_PICKLE_KEY,
AsyncConnectionHandler,
ConnectionHandler,
ConnectionRouter,
DatabaseError,
@@ -36,6 +37,50 @@ __all__ = [
]
connections = ConnectionHandler()
async_connections = AsyncConnectionHandler()
class new_connection:
"""
Asynchronous context manager to instantiate new async connections.
"""
def __init__(self, using=DEFAULT_DB_ALIAS):
self.using = using
async def __aenter__(self):
conn = connections.create_connection(self.using)
if conn.features.supports_async is False:
raise NotSupportedError(
"The database backend does not support asynchronous execution."
)
self.force_rollback = False
if async_connections.empty is True:
if async_connections._from_testcase is True:
self.force_rollback = True
self.conn = conn
async_connections.add_connection(self.using, self.conn)
await self.conn.aensure_connection()
if self.force_rollback is True:
await self.conn.aset_autocommit(False)
return self.conn
async def __aexit__(self, exc_type, exc_value, traceback):
autocommit = await self.conn.aget_autocommit()
if autocommit is False:
if exc_type is None and self.force_rollback is False:
await self.conn.acommit()
else:
await self.conn.arollback()
await self.conn.aclose()
await async_connections.pop_connection(self.using)
router = ConnectionRouter()
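
A minimal usage sketch of the ``new_connection`` context manager added above
(the pattern follows the documentation and tests introduced later in this
commit; the wrapping coroutine is only illustrative)::

    from django.db import new_connection

    async def run_raw_query():
        # Opens a dedicated async connection and closes it (committing or
        # rolling back as appropriate) when the block exits.
        async with new_connection() as connection:
            async with connection.acursor() as cursor:
                await cursor.aexecute("SELECT 1")
                row = await cursor.afetchone()
        return row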

View file

@@ -7,7 +7,7 @@ import time
import warnings
import zoneinfo
from collections import deque
from contextlib import asynccontextmanager, contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@@ -39,6 +39,7 @@ class BaseDatabaseWrapper:
ops = None
vendor = "unknown"
display_name = "unknown"
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
@@ -47,6 +48,7 @@ class BaseDatabaseWrapper:
introspection_class = None
ops_class = None
validation_class = BaseDatabaseValidation
_aconnection_pools = {}
queries_limit = 9000
@@ -54,6 +56,7 @@
# Connection related attributes.
# The underlying database connection.
self.connection = None
self.aconnection = None
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
@@ -187,25 +190,44 @@ class BaseDatabaseWrapper:
"method."
)
async def aget_database_version(self):
"""Return a tuple of the database's version."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an aget_database_version() "
"method."
)
def _validate_database_version_supported(self, db_version):
if (
self.features.minimum_database_version is not None
and db_version < self.features.minimum_database_version
):
str_db_version = ".".join(map(str, db_version))
min_db_version = ".".join(map(str, self.features.minimum_database_version))
raise NotSupportedError(
f"{self.display_name} {min_db_version} or later is required "
f"(found {str_db_version})."
)
def check_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
db_version = self.get_database_version()
self._validate_database_version_supported(db_version)
async def acheck_database_version_supported(self):
"""
Raise an error if the database version isn't supported by this
version of Django.
"""
db_version = await self.aget_database_version()
self._validate_database_version_supported(db_version)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self, for_async=False):
"""Return a dict of parameters suitable for get_new_connection."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
@@ -219,23 +241,42 @@ class BaseDatabaseWrapper:
"method"
)
async def aget_new_connection(self, conn_params):
"""Open a connection to the database."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an aget_new_connection() "
"method"
)
def init_connection_state(self):
"""Initialize the database connection settings."""
if self.alias not in RAN_DB_VERSION_CHECK:
self.check_database_version_supported()
RAN_DB_VERSION_CHECK.add(self.alias)
async def ainit_connection_state(self):
"""Initialize the database connection settings."""
if self.alias not in RAN_DB_VERSION_CHECK:
await self.acheck_database_version_supported()
RAN_DB_VERSION_CHECK.add(self.alias)
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
)
def create_async_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require a "
"create_async_cursor() method"
)
# ##### Backend-specific methods for creating connections #####
@contextmanager
def connect_manager(self):
# Check for invalid configurations.
self.check_settings()
# In case the previous connection was closed while in an atomic block
@@ -251,14 +292,30 @@ class BaseDatabaseWrapper:
self.errors_occurred = False
# New connections are healthy.
self.health_check_done = True
try:
yield
finally:
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
@async_unsafe
def connect(self):
"""Connect to the database. Assume that the connection is closed."""
with self.connect_manager():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
self.init_connection_state()
async def aconnect(self):
"""Connect to the database. Assume that the connection is closed."""
with self.connect_manager():
# Establish the connection
conn_params = self.get_connection_params(for_async=True)
self.aconnection = await self.aget_new_connection(conn_params)
await self.aset_autocommit(self.settings_dict["AUTOCOMMIT"])
await self.ainit_connection_state()
def check_settings(self):
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
@@ -278,6 +335,16 @@ class BaseDatabaseWrapper:
with self.wrap_database_errors:
self.connect()
async def aensure_connection(self):
"""Guarantee that a connection to the database is established."""
if self.aconnection is None:
if self.in_atomic_block and self.closed_in_transaction:
raise ProgrammingError(
"Cannot open a new connection in an atomic block."
)
with self.wrap_database_errors:
await self.aconnect()
# ##### Backend-specific wrappers for PEP-249 connection methods #####
def _prepare_cursor(self, cursor):
@@ -291,27 +358,57 @@ class BaseDatabaseWrapper:
wrapped_cursor = self.make_cursor(cursor)
return wrapped_cursor
def _aprepare_cursor(self, cursor):
"""
Validate the connection is usable and perform database cursor wrapping.
"""
self.validate_thread_sharing()
if self.queries_logged:
wrapped_cursor = self.make_debug_async_cursor(cursor)
else:
wrapped_cursor = self.make_async_cursor(cursor)
return wrapped_cursor
def _cursor(self, name=None):
self.close_if_health_check_failed()
self.ensure_connection()
with self.wrap_database_errors:
return self._prepare_cursor(self.create_cursor(name))
def _acursor(self, name=None):
return utils.AsyncCursorCtx(self, name)
def _commit(self):
if self.connection is not None:
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
return self.connection.commit()
async def _acommit(self):
if self.aconnection is not None:
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
return await self.aconnection.commit()
def _rollback(self):
if self.connection is not None:
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
return self.connection.rollback()
async def _arollback(self):
if self.aconnection is not None:
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
return await self.aconnection.rollback()
def _close(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.close()
async def _aclose(self):
if self.aconnection is not None:
with self.wrap_database_errors:
return await self.aconnection.close()
# ##### Generic wrappers for PEP-249 connection methods #####
@async_unsafe
@@ -319,6 +416,10 @@ class BaseDatabaseWrapper:
"""Create a cursor, opening a connection if necessary."""
return self._cursor()
def acursor(self):
"""Create an async cursor, opening a connection if necessary."""
return self._acursor()
@async_unsafe
def commit(self):
"""Commit a transaction and reset the dirty flag."""
@@ -329,6 +430,15 @@ class BaseDatabaseWrapper:
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
async def acommit(self):
"""Commit a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
await self._acommit()
# A successful commit means that the database connection works.
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
@async_unsafe
def rollback(self):
"""Roll back a transaction and reset the dirty flag."""
@@ -340,6 +450,16 @@ class BaseDatabaseWrapper:
self.needs_rollback = False
self.run_on_commit = []
async def arollback(self):
"""Roll back a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
await self._arollback()
# A successful rollback means that the database connection works.
self.errors_occurred = False
self.needs_rollback = False
self.run_on_commit = []
@async_unsafe
def close(self):
"""Close the connection to the database."""
@@ -360,24 +480,59 @@ class BaseDatabaseWrapper:
else:
self.connection = None
async def aclose(self):
"""Close the connection to the database."""
self.validate_thread_sharing()
self.run_on_commit = []
# Don't call validate_no_atomic_block() to avoid making it difficult
# to get rid of a connection in an invalid state. The next connect()
# will reset the transaction state anyway.
if self.closed_in_transaction or self.aconnection is None:
return
try:
await self._aclose()
finally:
if self.in_atomic_block:
self.closed_in_transaction = True
self.needs_rollback = True
else:
self.aconnection = None
# ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
async def _asavepoint(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
async def _asavepoint_rollback(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
async def _asavepoint_commit(self, sid):
async with self.acursor() as cursor:
await cursor.aexecute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not self.get_autocommit()
async def _asavepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not (await self.aget_autocommit())
# ##### Generic savepoint management methods #####
@async_unsafe
@@ -401,6 +556,26 @@ class BaseDatabaseWrapper:
return sid
async def asavepoint(self):
"""
Create a savepoint inside the current transaction. Return an
identifier for the savepoint that will be used for the subsequent
rollback or commit. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
thread_ident = _thread.get_ident()
tid = str(thread_ident).replace("-", "")
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
self.validate_thread_sharing()
await self._asavepoint(sid)
return sid
@async_unsafe
def savepoint_rollback(self, sid):
"""
@@ -419,6 +594,23 @@ class BaseDatabaseWrapper:
if sid not in sids
]
async def asavepoint_rollback(self, sid):
"""
Roll back to a savepoint. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
self.validate_thread_sharing()
await self._asavepoint_rollback(sid)
# Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [
(sids, func, robust)
for (sids, func, robust) in self.run_on_commit
if sid not in sids
]
@async_unsafe
def savepoint_commit(self, sid):
"""
@@ -430,6 +622,16 @@ class BaseDatabaseWrapper:
self.validate_thread_sharing()
self._savepoint_commit(sid)
async def asavepoint_commit(self, sid):
"""
Release a savepoint. Do nothing if savepoints are not supported.
"""
if not (await self._asavepoint_allowed()):
return
self.validate_thread_sharing()
await self._asavepoint_commit(sid)
@async_unsafe
def clean_savepoints(self):
"""
@@ -447,6 +649,14 @@ class BaseDatabaseWrapper:
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
)
async def _aset_autocommit(self, autocommit):
"""
Backend-specific implementation to enable or disable autocommit.
"""
raise NotSupportedError(
"subclasses of BaseDatabaseWrapper may require an _aset_autocommit() method"
)
# ##### Generic transaction management methods #####
def get_autocommit(self):
@@ -454,6 +664,11 @@ class BaseDatabaseWrapper:
self.ensure_connection()
return self.autocommit
async def aget_autocommit(self):
"""Get the autocommit state."""
await self.aensure_connection()
return self.autocommit
def set_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
@@ -491,6 +706,43 @@ class BaseDatabaseWrapper:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
async def aset_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
"""
Enable or disable autocommit.
The usual way to start a transaction is to turn autocommit off.
SQLite does not properly start a transaction when disabling
autocommit. To avoid this buggy behavior and to actually enter a new
transaction, an explicit BEGIN is required. Using
force_begin_transaction_with_broken_autocommit=True will issue an
explicit BEGIN with SQLite. This option will be ignored for other
backends.
"""
self.validate_no_atomic_block()
await self.aclose_if_health_check_failed()
await self.aensure_connection()
start_transaction_under_autocommit = (
force_begin_transaction_with_broken_autocommit
and not autocommit
and hasattr(self, "_astart_transaction_under_autocommit")
)
if start_transaction_under_autocommit:
await self._astart_transaction_under_autocommit()
elif autocommit:
await self._aset_autocommit(autocommit)
else:
with debug_transaction(self, "BEGIN"):
await self._aset_autocommit(autocommit)
self.autocommit = autocommit
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
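
A hedged sketch of manual transaction control using the asynchronous wrappers
above (``aset_autocommit()``, ``acommit()``, ``arollback()``); the coroutine,
table name, and SQL are placeholders, and the connection is assumed to come
from ``new_connection``::

    from django.db import new_connection

    async def transfer(account_id):
        async with new_connection() as conn:
            # Turn autocommit off to start an explicit transaction.
            await conn.aset_autocommit(False)
            try:
                async with conn.acursor() as cursor:
                    # "app_account" is a hypothetical table used for illustration.
                    await cursor.aexecute(
                        "UPDATE app_account SET balance = balance - 1 WHERE id = %s",
                        [account_id],
                    )
            except Exception:
                await conn.arollback()
                raise
            else:
                await conn.acommit()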
def get_rollback(self):
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
@@ -575,6 +827,19 @@ class BaseDatabaseWrapper:
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
)
async def ais_usable(self):
"""
Test if the database connection is usable.
This method may assume that self.aconnection is not None.
Actual implementations should take care not to raise exceptions
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an ais_usable() method"
)
def close_if_health_check_failed(self):
"""Close existing connection if it fails a health check."""
if (
@@ -588,6 +853,20 @@ class BaseDatabaseWrapper:
self.close()
self.health_check_done = True
async def aclose_if_health_check_failed(self):
"""Close existing connection if it fails a health check."""
if (
self.aconnection is None
or not self.health_check_enabled
or self.health_check_done
):
return
is_usable = await self.ais_usable()
if not is_usable:
await self.aclose()
self.health_check_done = True
def close_if_unusable_or_obsolete(self):
"""
Close the current connection if unrecoverable errors have occurred
@@ -677,10 +956,18 @@ class BaseDatabaseWrapper:
"""Create a cursor that logs all queries in self.queries_log."""
return utils.CursorDebugWrapper(cursor, self)
def make_debug_async_cursor(self, cursor):
"""Create a cursor that logs all queries in self.queries_log."""
return utils.AsyncCursorDebugWrapper(cursor, self)
def make_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.CursorWrapper(cursor, self)
def make_async_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.AsyncCursorWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
"""
@@ -698,6 +985,27 @@ class BaseDatabaseWrapper:
if must_close:
self.close()
@asynccontextmanager
async def atemporary_connection(self):
"""
Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
Provide a cursor::
async with self.atemporary_connection() as cursor:
...
"""
must_close = self.aconnection is None
try:
async with self.acursor() as cursor:
yield cursor
finally:
if must_close:
await self.aclose()
@contextmanager
def _nodb_cursor(self):
"""

View file

@@ -358,6 +358,9 @@ class BaseDatabaseFeatures:
# Does the backend support negative JSON array indexing?
supports_json_negative_indexing = True
# Asynchronous database operations
supports_async = False
# Does the backend support column collations?
supports_collation_on_charfield = True
supports_collation_on_textfield = True

View file

@@ -15,6 +15,9 @@ from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.db.backends.utils import (
AsyncCursorDebugWrapper as AsyncBaseCursorDebugWrapper,
)
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
@@ -98,6 +101,8 @@ def _get_decimal_column(data):
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
_pg_version = None
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings;
# they'll be interpolated against the values of Field.__dict__ before being
@@ -231,11 +236,57 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self._connection_pools[self.alias]
@property
def apool(self):
pool_options = self.settings_dict["OPTIONS"].get("pool")
if self.alias == NO_DB_ALIAS or not pool_options:
return None
if self.alias not in self._aconnection_pools:
if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
raise ImproperlyConfigured(
"Pooling doesn't support persistent connections."
)
# Set the default options.
if pool_options is True:
pool_options = {}
try:
from psycopg_pool import AsyncConnectionPool
except ImportError as err:
raise ImproperlyConfigured(
"Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
) from err
connect_kwargs = self.get_connection_params(for_async=True)
# Ensure we run in autocommit, Django properly sets it later on.
connect_kwargs["autocommit"] = True
enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
pool = AsyncConnectionPool(
kwargs=connect_kwargs,
open=False, # Do not open the pool during startup.
configure=self._aconfigure_connection,
check=AsyncConnectionPool.check_connection if enable_checks else None,
**pool_options,
)
# setdefault() ensures that multiple threads don't set this in
# parallel. Since we do not open the pool during its init above,
# this means that at worst during startup multiple threads generate
# pool objects and the first to set it wins.
self._aconnection_pools.setdefault(self.alias, pool)
return self._aconnection_pools[self.alias]
def close_pool(self):
if self.pool:
self.pool.close()
del self._connection_pools[self.alias]
async def aclose_pool(self):
if self.apool:
await self.apool.close()
del self._aconnection_pools[self.alias]
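
The async pool mirrors the synchronous one: it is only created when ``"pool"``
is set in ``OPTIONS`` and ``CONN_MAX_AGE`` is 0, and a dict value is forwarded
to ``psycopg_pool.AsyncConnectionPool``. A settings sketch (database name and
user are placeholders)::

    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.postgresql",
            "NAME": "exampledb",  # placeholder
            "USER": "exampleuser",  # placeholder
            "CONN_MAX_AGE": 0,  # pooling does not support persistent connections
            "OPTIONS": {
                # True uses the psycopg_pool defaults; a dict is passed through
                # to (Async)ConnectionPool as keyword arguments.
                "pool": {"min_size": 2, "max_size": 4},
            },
        }
    }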
def get_database_version(self):
"""
Return a tuple of the database's version.
@@ -243,7 +294,38 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"""
return divmod(self.pg_version, 10000)
async def aget_database_version(self):
"""
Return a tuple of the database's version.
E.g. for pg_version 120004, return (12, 4).
"""
pg_version = await self.apg_version
return divmod(pg_version, 10000)
def _get_sync_cursor_factory(self, server_side_binding=None):
if is_psycopg3 and server_side_binding is True:
return ServerBindingCursor
else:
return Cursor
def _get_async_cursor_factory(self, server_side_binding=None):
if is_psycopg3 and server_side_binding is True:
return AsyncServerBindingCursor
else:
return AsyncCursor
def _get_cursor_factory(self, server_side_binding=None, for_async=False):
if for_async and not is_psycopg3:
raise ImproperlyConfigured(
"Django requires psycopg >= 3 for ORM async support."
)
if for_async:
return self._get_async_cursor_factory(server_side_binding)
else:
return self._get_sync_cursor_factory(server_side_binding)
def get_connection_params(self, for_async=False):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
@@ -283,14 +365,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
server_side_binding = conn_params.pop("server_side_binding", None)
cursor_factory = self._get_cursor_factory(
server_side_binding, for_async=for_async
)
conn_params.setdefault("cursor_factory", cursor_factory)
if settings_dict["USER"]:
conn_params["user"] = settings_dict["USER"]
if settings_dict["PASSWORD"]:
@@ -310,8 +388,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
return conn_params
def _get_isolation_level(self):
# self.isolation_level must be set:
# - after connecting to the database in order to obtain the database's
# default when no value is explicitly specified in options.
@@ -322,17 +399,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
try:
isolation_level_value = options["isolation_level"]
except KeyError:
isolation_level = IsolationLevel.READ_COMMITTED
else:
# Set the isolation level to the value from OPTIONS.
try:
isolation_level = IsolationLevel(isolation_level_value)
set_isolation_level = True
except ValueError:
raise ImproperlyConfigured(
f"Invalid transaction isolation level {isolation_level_value} "
f"specified. Use one of the psycopg.IsolationLevel values."
)
return isolation_level, set_isolation_level
@async_unsafe
def get_new_connection(self, conn_params):
isolation_level, set_isolation_level = self._get_isolation_level()
self.isolation_level = isolation_level
if self.pool:
# If nothing else has opened the pool, open it now.
self.pool.open()
@@ -340,7 +422,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = isolation_level
if not is_psycopg3:
# Register dummy loads() to avoid a round trip from psycopg2's
# decode to json.dumps() to json.loads(), when using a custom
@@ -350,6 +432,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
return connection
async def aget_new_connection(self, conn_params):
isolation_level, set_isolation_level = self._get_isolation_level()
self.isolation_level = isolation_level
if self.apool:
# If nothing else has opened the pool, open it now.
await self.apool.open()
connection = await self.apool.getconn()
else:
connection = await self.Database.AsyncConnection.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = isolation_level
return connection
def ensure_timezone(self):
# Close the pool so new connections pick up the correct timezone.
self.close_pool()
@@ -357,6 +452,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
return self._configure_timezone(self.connection)
async def aensure_timezone(self):
# Close the pool so new connections pick up the correct timezone.
await self.aclose_pool()
if self.aconnection is None:
return False
return await self._aconfigure_timezone(self.aconnection)
def _configure_timezone(self, connection):
conn_timezone_name = connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
@@ -366,6 +468,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return True
return False
async def _aconfigure_timezone(self, connection):
conn_timezone_name = connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
async with connection.cursor() as cursor:
await cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False
def _configure_role(self, connection):
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
with connection.cursor() as cursor:
@@ -374,6 +485,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return True
return False
async def _aconfigure_role(self, connection):
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
async with connection.cursor() as cursor:
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
await cursor.execute(sql)
return True
return False
def _configure_connection(self, connection):
# This function is called from init_connection_state and from the
# psycopg pool itself after a connection is opened.
@@ -387,6 +506,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return commit_role or commit_tz
async def _aconfigure_connection(self, connection):
# This function is called from init_connection_state and from the
# psycopg pool itself after a connection is opened.
# Commit after setting the time zone.
commit_tz = await self._aconfigure_timezone(connection)
# Set the role on the connection. This is useful if the credential used
# to login is not the same as the role that owns database resources. As
# can be the case when using temporary or ephemeral credentials.
commit_role = await self._aconfigure_role(connection)
return commit_role or commit_tz
def _close(self):
if self.connection is not None:
# `wrap_database_errors` only works for `putconn` as long as there
@@ -403,6 +535,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
return self.connection.close()
async def _aclose(self):
if self.aconnection is not None:
# `wrap_database_errors` only works for `putconn` as long as there
# is no `reset` function set in the pool because it is deferred
# into a thread and not directly executed.
with self.wrap_database_errors:
if self.apool:
# Ensure the correct pool is returned. This is a workaround
# for tests so a pool can be changed on setting changes
# (e.g. USE_TZ, TIME_ZONE).
await self.aconnection._pool.putconn(self.aconnection)
# Connection can no longer be used.
self.aconnection = None
else:
return await self.aconnection.close()
def init_connection_state(self):
super().init_connection_state()
@@ -412,6 +560,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if commit and not self.get_autocommit():
self.connection.commit()
async def ainit_connection_state(self):
await super().ainit_connection_state()
if self.aconnection is not None and not self.apool:
commit = await self._aconfigure_connection(self.aconnection)
if commit:
autocommit = await self.aget_autocommit()
if not autocommit:
await self.aconnection.commit()
@async_unsafe
def create_cursor(self, name=None):
if name:
@@ -447,6 +606,35 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def create_async_cursor(self, name=None):
if name:
if self.settings_dict["OPTIONS"].get("server_side_binding") is not True:
# psycopg >= 3 forces the usage of server-side bindings for
# named cursors so a specialized class that implements
# server-side cursors while performing client-side bindings
# must be used if `server_side_binding` is disabled (default).
cursor = AsyncServerSideCursor(
self.aconnection,
name=name,
scrollable=False,
withhold=self.aconnection.autocommit,
)
else:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.aconnection.cursor(
name, scrollable=False, withhold=self.aconnection.autocommit
)
else:
cursor = self.aconnection.cursor()
# Register the cursor timezone only if the connection disagrees, to
# avoid copying the adapter map.
tzloader = self.aconnection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
if self.timezone != tzloader.timezone:
register_tzloader(self.timezone, cursor)
return cursor
def tzinfo_factory(self, offset):
return self.timezone
@@ -478,10 +666,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
)
)
async def achunked_cursor(self):
self._named_cursor_idx += 1
# Get the current async task
try:
current_task = asyncio.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = "sync"
# Use that and the thread ident to get a unique name
return self._acursor(
name="_django_curs_%d_%s_%d"
% (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
self._named_cursor_idx,
)
)
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
async def _aset_autocommit(self, autocommit):
with self.wrap_database_errors:
await self.aconnection.set_autocommit(autocommit)
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
@@ -503,12 +718,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
return True
async def ais_usable(self):
if self.aconnection is None:
return False
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
async with self.aconnection.cursor() as cursor:
await cursor.execute("SELECT 1")
except Database.Error:
return False
else:
return True
def close_if_health_check_failed(self):
if self.pool:
# The pool only returns healthy connections.
return
return super().close_if_health_check_failed()
async def aclose_if_health_check_failed(self):
if self.apool:
# The pool only returns healthy connections.
return
return await super().aclose_if_health_check_failed()
@contextmanager
def _nodb_cursor(self):
cursor = None
@@ -549,8 +782,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def pg_version(self):
if self._pg_version is None:
with self.temporary_connection():
self._pg_version = self.connection.info.server_version
return self._pg_version
@cached_property
async def apg_version(self):
if self._pg_version is None:
async with self.atemporary_connection():
self._pg_version = self.aconnection.info.server_version
return self._pg_version
def make_debug_cursor(self, cursor):
return CursorDebugWrapper(cursor, self)
@@ -607,6 +849,36 @@ if is_psycopg3:
with self.debug_sql(statement):
return self.cursor.copy(statement)
class AsyncServerBindingCursor(CursorMixin, Database.AsyncCursor):
pass
class AsyncCursor(CursorMixin, Database.AsyncClientCursor):
pass
class AsyncServerSideCursor(
CursorMixin,
Database.client_cursor.ClientCursorMixin,
Database.AsyncServerCursor,
):
"""
psycopg >= 3 forces the usage of server-side bindings when using named
cursors but the ORM doesn't yet support the systematic generation of
prepareable SQL (#20516).
ClientCursorMixin forces the usage of client-side bindings while
AsyncServerCursor implements the logic required to declare and scroll
through named cursors.
Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
specify how parameters should be bound instead, which AsyncServerCursor
would inherit, but that's not the case.
"""
class AsyncCursorDebugWrapper(AsyncBaseCursorDebugWrapper):
def copy(self, statement):
with self.debug_sql(statement):
return self.cursor.copy(statement)
else:
Cursor = psycopg2.extensions.cursor

View file

@@ -54,6 +54,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
END;
$$ LANGUAGE plpgsql;"""
requires_casted_case_in_updates = True
supports_async = is_psycopg3
supports_over_clause = True
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True

View file

@@ -114,6 +114,98 @@ class CursorWrapper:
return self.cursor.executemany(sql, param_list)
class AsyncCursorCtx:
"""
Asynchronous context manager to hold an async cursor.
"""
def __init__(self, db, name=None):
self.db = db
self.name = name
self.wrap_database_errors = self.db.wrap_database_errors
async def __aenter__(self):
await self.db.aclose_if_health_check_failed()
await self.db.aensure_connection()
self.wrap_database_errors.__enter__()
return self.db._aprepare_cursor(self.db.create_async_cursor(self.name))
async def __aexit__(self, type, value, traceback):
self.wrap_database_errors.__exit__(type, value, traceback)
class AsyncCursorWrapper(CursorWrapper):
async def _aexecute(self, sql, params, *ignored_wrapper_args):
# Raise a warning during app initialization (stored_app_configs is only
# ever set during testing).
if not apps.ready and not apps.stored_app_configs:
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
if params is None:
# params default might be backend specific.
return await self.cursor.execute(sql)
else:
return await self.cursor.execute(sql, params)
async def _aexecute_with_wrappers(self, sql, params, many, executor):
context = {"connection": self.db, "cursor": self}
for wrapper in reversed(self.db.execute_wrappers):
executor = functools.partial(wrapper, executor)
return await executor(sql, params, many, context)
async def aexecute(self, sql, params=None):
return await self._aexecute_with_wrappers(
sql, params, many=False, executor=self._aexecute
)
async def _aexecutemany(self, sql, param_list, *ignored_wrapper_args):
# Raise a warning during app initialization (stored_app_configs is only
# ever set during testing).
if not apps.ready and not apps.stored_app_configs:
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
return await self.cursor.executemany(sql, param_list)
async def aexecutemany(self, sql, param_list):
return await self._aexecute_with_wrappers(
sql, param_list, many=True, executor=self._aexecutemany
)
async def afetchone(self, *args, **kwargs):
return await self.cursor.fetchone(*args, **kwargs)
async def afetchmany(self, *args, **kwargs):
return await self.cursor.fetchmany(*args, **kwargs)
async def afetchall(self, *args, **kwargs):
return await self.cursor.fetchall(*args, **kwargs)
def acopy(self, *args, **kwargs):
return self.cursor.copy(*args, **kwargs)
def astream(self, *args, **kwargs):
return self.cursor.stream(*args, **kwargs)
async def ascroll(self, *args, **kwargs):
return await self.cursor.scroll(*args, **kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, type, value, traceback):
try:
await self.close()
except self.db.Database.Error:
pass
async def __aiter__(self):
with self.db.wrap_database_errors:
async for item in self.cursor:
yield item
class CursorDebugWrapper(CursorWrapper):
# XXX callproc isn't instrumented at this time.
@@ -163,6 +255,57 @@ class CursorDebugWrapper(CursorWrapper):
)
class AsyncCursorDebugWrapper(AsyncCursorWrapper):
# XXX callproc isn't instrumented at this time.
async def aexecute(self, sql, params=None):
with self.debug_sql(sql, params, use_last_executed_query=True):
return await super().aexecute(sql, params)
async def aexecutemany(self, sql, param_list):
with self.debug_sql(sql, param_list, many=True):
return await super().aexecutemany(sql, param_list)
@contextmanager
def debug_sql(
self, sql=None, params=None, use_last_executed_query=False, many=False
):
start = time.monotonic()
try:
yield
finally:
stop = time.monotonic()
duration = stop - start
if use_last_executed_query:
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
try:
times = len(params) if many else ""
except TypeError:
# params could be an iterator.
times = "?"
self.db.queries_log.append(
{
"sql": "%s times: %s" % (times, sql) if many else sql,
"time": "%.3f" % duration,
"async": True,
}
)
logger.debug(
"(%.3f) %s; args=%s; alias=%s; async=True",
duration,
sql,
params,
self.db.alias,
extra={
"duration": duration,
"sql": sql,
"params": params,
"alias": self.db.alias,
"async": True,
},
)
@contextmanager
def debug_transaction(connection, sql):
start = time.monotonic()
@@ -176,18 +319,21 @@ def debug_transaction(connection, sql):
{
"sql": "%s" % sql,
"time": "%.3f" % duration,
"async": connection.features.supports_async,
}
)
logger.debug(
"(%.3f) %s; args=%s; alias=%s; async=%s",
duration,
sql,
None,
connection.alias,
connection.features.supports_async,
extra={
"duration": duration,
"sql": sql,
"alias": connection.alias,
"async": connection.features.supports_async,
},
)

View file

@@ -1,6 +1,8 @@
import pkgutil
from importlib import import_module
from asgiref.local import Local
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@@ -197,6 +199,89 @@ class ConnectionHandler(BaseConnectionHandler):
return backend.DatabaseWrapper(db, alias)
class AsyncAlias:
"""
A Context-aware list of connections.
"""
def __init__(self) -> None:
self._connections = Local()
setattr(self._connections, "_stack", [])
@property
def connections(self):
return getattr(self._connections, "_stack", [])
def __len__(self):
return len(self.connections)
def __iter__(self):
return iter(self.connections)
def __str__(self):
return ", ".join([str(id(conn)) for conn in self.connections])
def __repr__(self):
return f"<{self.__class__.__name__}: {len(self.connections)} connections>"
def add_connection(self, connection):
setattr(self._connections, "_stack", self.connections + [connection])
def pop(self):
conns = self.connections
conns.pop()
setattr(self._connections, "_stack", conns)
class AsyncConnectionHandler:
"""
Context-aware class to store async connections, mapped by alias name.
"""
_from_testcase = False
def __init__(self) -> None:
self._aliases = Local()
self._connection_count = Local()
setattr(self._connection_count, "value", 0)
def __getitem__(self, alias):
try:
async_alias = getattr(self._aliases, alias)
except AttributeError:
async_alias = AsyncAlias()
setattr(self._aliases, alias, async_alias)
return async_alias
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.count} connections>"
@property
def count(self):
return getattr(self._connection_count, "value", 0)
@property
def empty(self):
return self.count == 0
def add_connection(self, using, connection):
self[using].add_connection(connection)
setattr(self._connection_count, "value", self.count + 1)
async def pop_connection(self, using):
await self[using].connections[-1].aclose_pool()
self[using].connections.pop()
setattr(self._connection_count, "value", self.count - 1)
def get_connection(self, using):
alias = self[using]
if len(alias.connections) == 0:
raise ConnectionDoesNotExist(
f"There are no async connections using the '{using}' alias."
)
return alias.connections[-1]
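
A small sketch of how these containers cooperate with ``new_connection``
(defined in ``django/db/__init__.py`` in this commit): each task-local alias
keeps a stack of connections, and ``get_connection()`` resolves the most
recent one. The coroutine below is illustrative only::

    from django.db import DEFAULT_DB_ALIAS, async_connections, new_connection

    async def current_async_connection():
        async with new_connection() as conn:
            # __aenter__() pushed `conn` onto the task-local stack for the
            # default alias, so the handler resolves the same object.
            assert async_connections.get_connection(DEFAULT_DB_ALIAS) is conn
            return async_connections.count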
class ConnectionRouter:
def __init__(self, routers=None):
"""

View file

@@ -38,7 +38,13 @@ from django.core.management.color import no_style
from django.core.management.sql import emit_post_migrate_signal
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
from django.core.signals import setting_changed
from django.db import (
DEFAULT_DB_ALIAS,
async_connections,
connection,
connections,
transaction,
)
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.forms.fields import CharField
from django.http import QueryDict
@@ -1415,6 +1421,7 @@ class TestCase(TransactionTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
async_connections._from_testcase = True
if not (
cls._databases_support_transactions()
and cls._databases_support_savepoints()

View file

@@ -211,7 +211,6 @@ Database backends
* MySQL connections now default to using the ``utf8mb4`` character set,
instead of ``utf8``, which is an alias for the deprecated character set
``utf8mb3``.
* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
``"pool"`` in the :setting:`OPTIONS` part of your database configuration.

View file

@@ -151,6 +151,19 @@ instance of those now-deprecated classes.
Minor features
--------------
Database backends
~~~~~~~~~~~~~~~~~
* It is now possible to perform asynchronous raw SQL queries using an async
cursor, if the backend supports async-native connections. This is only
supported on PostgreSQL with ``psycopg`` 3.1.8+. See
:ref:`async-connection-cursor` for more details.
:mod:`django.contrib.admin`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@@ -404,6 +404,37 @@ is equivalent to::
finally:
c.close()
.. _async-connection-cursor:
Async connections and cursors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 6.0
On backends that support async-native connections, you can request an async
cursor::
from django.db import new_connection
async with new_connection() as connection:
async with connection.acursor() as c:
await c.aexecute(...)
Async cursors provide the following methods:
* ``.aexecute()``
* ``.aexecutemany()``
* ``.afetchone()``
* ``.afetchmany()``
* ``.afetchall()``
* ``.acopy()``
* ``.astream()``
* ``.ascroll()``
Currently, Django ships with the following async-enabled backend:
* ``django.db.backends.postgresql`` with ``psycopg3``.
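
For example, rows can be fetched with the asynchronous fetch methods listed
above (a short sketch; the selected literals are only illustrative)::

    from django.db import new_connection

    async with new_connection() as connection:
        async with connection.acursor() as cursor:
            await cursor.aexecute("SELECT %s, %s", [1, 2])
            rows = await cursor.afetchall()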
Calling stored procedures
~~~~~~~~~~~~~~~~~~~~~~~~~

View file

@@ -0,0 +1,26 @@
from asgiref.sync import sync_to_async
from django.db import new_connection
from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import SimpleModel
@skipUnlessDBFeature("supports_async")
class AsyncSyncCominglingTest(TransactionTestCase):
available_apps = ["async"]
async def change_model_with_async(self, obj):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""UPDATE "async_simplemodel" SET "field" = 10,"""
""" "created" = '2024-11-19 14:12:50.606384'::timestamp"""
""" WHERE "async_simplemodel"."id" = 1"""
)
async def test_transaction_async_comingling(self):
s1 = await sync_to_async(SimpleModel.objects.create)(field=0)
# with transaction.atomic():
await self.change_model_with_async(s1)

View file

@@ -0,0 +1,133 @@
from django.db import new_connection
from django.test import SimpleTestCase, skipUnlessDBFeature
@skipUnlessDBFeature("supports_async")
class AsyncCursorTests(SimpleTestCase):
databases = {"default", "other"}
async def test_aexecute(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
async def test_aexecutemany(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("CREATE TABLE numbers (number SMALLINT)")
await cursor.aexecutemany(
"INSERT INTO numbers VALUES (%s)", [(1,), (2,), (3,)]
)
await cursor.aexecute("SELECT * FROM numbers")
result = await cursor.afetchall()
self.assertEqual(result, [(1,), (2,), (3,)])
async def test_afetchone(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
result = await cursor.afetchone()
self.assertEqual(result, (1,))
async def test_afetchmany(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
result = await cursor.afetchmany(size=2)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",)])
async def test_afetchall(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
result = await cursor.afetchall()
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_aiter(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
async for record in cursor:
result.append(record)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_acopy(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
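# acopy() wraps PostgreSQL's COPY protocol; COPY ... TO STDOUT streams the
# selected rows back through copy.rows().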
async with cursor.acopy(
"""
COPY (
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)
) TO STDOUT"""
) as copy:
async for row in copy.rows():
result.append(row)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_astream(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
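# astream() executes the query and yields records asynchronously as they
# arrive, without a separate fetch call.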
async for record in cursor.astream(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
):
result.append(record)
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
async def test_ascroll(self):
result = []
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute(
"""
SELECT *
FROM (VALUES
('BANANA'),
('STRAWBERRY'),
('MELON')
) AS v (NAME)"""
)
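# Scroll to absolute position 1 (the second row, zero-indexed), so only the
# last two rows are fetched.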
await cursor.ascroll(1, "absolute")
result = await cursor.afetchall()
self.assertEqual(result, [("STRAWBERRY",), ("MELON",)])
await cursor.ascroll(0, "absolute")
result = await cursor.afetchall()
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])

View file

@ -0,0 +1,21 @@
from django.db import new_connection
from django.test import SimpleTestCase, skipUnlessDBFeature
@skipUnlessDBFeature("supports_async")
class AsyncDatabaseWrapperTests(SimpleTestCase):
databases = {"default", "other"}
async def test_async_cursor(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
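# The driver's native execute()/fetchone() are awaitable on an async cursor;
# the a-prefixed variants are exercised in the next test.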
await cursor.execute("SELECT 1")
result = (await cursor.fetchone())[0]
self.assertEqual(result, 1)
async def test_async_cursor_alias(self):
async with new_connection() as conn:
async with conn.acursor() as cursor:
await cursor.aexecute("SELECT 1")
result = (await cursor.afetchone())[0]
self.assertEqual(result, 1)

View file

@ -1,11 +1,26 @@
"""Tests for django.db.utils.""" """Tests for django.db.utils."""
import asyncio
import concurrent.futures
import unittest import unittest
from unittest import mock
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection from django.db import (
from django.db.utils import ConnectionHandler, load_backend DEFAULT_DB_ALIAS,
from django.test import SimpleTestCase, TestCase NotSupportedError,
ProgrammingError,
async_connections,
connection,
new_connection,
)
from django.db.utils import (
AsyncAlias,
AsyncConnectionHandler,
ConnectionHandler,
load_backend,
)
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.utils.connection import ConnectionDoesNotExist from django.utils.connection import ConnectionDoesNotExist
@ -90,3 +105,82 @@ class LoadBackendTests(SimpleTestCase):
with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm: with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm:
load_backend("foo") load_backend("foo")
self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'") self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'")
class AsyncConnectionTests(SimpleTestCase):
databases = {"default", "other"}
def run_pool(self, coro, count=2):
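# Run the coroutine in worker threads, each with its own event loop, so the
# tests below can check that per-alias async connection state stays isolated.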
def fn():
asyncio.run(coro())
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = []
for _ in range(count):
futures.append(executor.submit(fn))
for future in concurrent.futures.as_completed(futures):
exc = future.exception()
if exc is not None:
raise exc
def test_async_alias(self):
alias = AsyncAlias()
assert len(alias) == 0
assert alias.connections == []
async def coro():
assert len(alias) == 0
alias.add_connection(mock.Mock())
alias.pop()
self.run_pool(coro)
def test_async_connection_handler(self):
aconns = AsyncConnectionHandler()
assert aconns.empty is True
assert aconns["default"].connections == []
async def coro():
assert aconns["default"].connections == []
aconns.add_connection("default", mock.Mock())
aconns.pop_connection("default")
self.run_pool(coro)
@skipUnlessDBFeature("supports_async")
def test_new_connection_threading(self):
async def coro():
assert async_connections.empty is True
async with new_connection() as connection:
async with connection.acursor() as c:
await c.execute("SELECT 1")
self.run_pool(coro)
@skipUnlessDBFeature("supports_async")
async def test_new_connection(self):
with self.assertRaises(ConnectionDoesNotExist):
async_connections.get_connection(DEFAULT_DB_ALIAS)
async with new_connection():
conn1 = async_connections.get_connection(DEFAULT_DB_ALIAS)
self.assertIsNotNone(conn1.aconnection)
async with new_connection():
conn2 = async_connections.get_connection(DEFAULT_DB_ALIAS)
self.assertIsNotNone(conn1.aconnection)
self.assertIsNotNone(conn2.aconnection)
self.assertNotEqual(conn1.aconnection, conn2.aconnection)
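# Leaving the inner block closes only the inner connection; the outer one
# stays open until its own block exits.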
self.assertIsNotNone(conn1.aconnection)
self.assertIsNone(conn2.aconnection)
self.assertIsNone(conn1.aconnection)
with self.assertRaises(ConnectionDoesNotExist):
async_connections.get_connection(DEFAULT_DB_ALIAS)
@skipUnlessDBFeature("supports_async")
async def test_new_connection_on_sync(self):
with self.assertRaises(NotSupportedError):
async with new_connection():
async_connections.get_connection(DEFAULT_DB_ALIAS)

View file

@ -9,6 +9,7 @@ from django.db import (
IntegrityError, IntegrityError,
OperationalError, OperationalError,
connection, connection,
new_connection,
transaction, transaction,
) )
from django.test import ( from django.test import (
@ -586,3 +587,93 @@ class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
class DurableTests(DurableTestsBase, TestCase): class DurableTests(DurableTestsBase, TestCase):
pass pass
@skipUnlessDBFeature("uses_savepoints", "supports_async")
class AsyncTransactionTestCase(TransactionTestCase):
available_apps = ["transactions"]
async def test_new_connection_nested(self):
async with new_connection() as connection:
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
async with connection2.acursor() as cursor2:
await cursor2.aexecute(
"INSERT INTO transactions_reporter "
"(first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
assert len(result) == 1
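# Once the inner block exits, its transaction is committed, so the outer
# connection sees the inserted row.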
async with connection.acursor() as cursor:
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 1
async def test_new_connection_nested2(self):
async with new_connection() as connection:
async with connection.acursor() as cursor:
await cursor.aexecute(
"INSERT INTO transactions_reporter (first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 1
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
async with connection2.acursor() as cursor2:
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
# This connection won't see any rows, because the outer one
# hasn't committed yet.
assert len(result) == 0
async def test_new_connection_nested3(self):
async with new_connection() as connection:
async with new_connection() as connection2:
await connection2.aset_autocommit(False)
assert id(connection) != id(connection2)
async with connection2.acursor() as cursor2:
await cursor2.aexecute(
"INSERT INTO transactions_reporter "
"(first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Sarah", "Hatoff", ""),
)
await cursor2.aexecute("SELECT * FROM transactions_reporter")
result = await cursor2.afetchmany()
assert len(result) == 1
# Outermost connection doesn't see what the innermost did,
# because the innermost connection hasn't exited yet.
async with connection.acursor() as cursor:
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany()
assert len(result) == 0
async def test_asavepoint(self):
async with new_connection() as connection:
async with connection.acursor() as cursor:
sid = await connection.asavepoint()
assert sid is not None
await cursor.aexecute(
"INSERT INTO transactions_reporter (first_name, last_name, email) "
"VALUES (%s, %s, %s)",
("Archibald", "Haddock", ""),
)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany(size=5)
assert len(result) == 1
assert result[0][1:] == ("Archibald", "Haddock", "")
await connection.asavepoint_rollback(sid)
await cursor.aexecute("SELECT * FROM transactions_reporter")
result = await cursor.afetchmany(size=5)
assert len(result) == 0