mirror of
https://github.com/django/django.git
synced 2025-11-18 19:01:40 +00:00
Merge da20298982 into 1ce6e78dd4
This commit is contained in:
commit
f667229336
16 changed files with 1319 additions and 44 deletions
|
|
@ -2,6 +2,7 @@ from django.core import signals
|
||||||
from django.db.utils import (
|
from django.db.utils import (
|
||||||
DEFAULT_DB_ALIAS,
|
DEFAULT_DB_ALIAS,
|
||||||
DJANGO_VERSION_PICKLE_KEY,
|
DJANGO_VERSION_PICKLE_KEY,
|
||||||
|
AsyncConnectionHandler,
|
||||||
ConnectionHandler,
|
ConnectionHandler,
|
||||||
ConnectionRouter,
|
ConnectionRouter,
|
||||||
DatabaseError,
|
DatabaseError,
|
||||||
|
|
@ -36,6 +37,50 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
connections = ConnectionHandler()
|
connections = ConnectionHandler()
|
||||||
|
async_connections = AsyncConnectionHandler()
|
||||||
|
|
||||||
|
|
||||||
|
class new_connection:
|
||||||
|
"""
|
||||||
|
Asynchronous context manager to instantiate new async connections.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, using=DEFAULT_DB_ALIAS):
|
||||||
|
self.using = using
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
conn = connections.create_connection(self.using)
|
||||||
|
if conn.features.supports_async is False:
|
||||||
|
raise NotSupportedError(
|
||||||
|
"The database backend does not support asynchronous execution."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.force_rollback = False
|
||||||
|
if async_connections.empty is True:
|
||||||
|
if async_connections._from_testcase is True:
|
||||||
|
self.force_rollback = True
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
|
async_connections.add_connection(self.using, self.conn)
|
||||||
|
|
||||||
|
await self.conn.aensure_connection()
|
||||||
|
if self.force_rollback is True:
|
||||||
|
await self.conn.aset_autocommit(False)
|
||||||
|
|
||||||
|
return self.conn
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
autocommit = await self.conn.aget_autocommit()
|
||||||
|
if autocommit is False:
|
||||||
|
if exc_type is None and self.force_rollback is False:
|
||||||
|
await self.conn.acommit()
|
||||||
|
else:
|
||||||
|
await self.conn.arollback()
|
||||||
|
await self.conn.aclose()
|
||||||
|
|
||||||
|
await async_connections.pop_connection(self.using)
|
||||||
|
|
||||||
|
|
||||||
router = ConnectionRouter()
|
router = ConnectionRouter()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import time
|
||||||
import warnings
|
import warnings
|
||||||
import zoneinfo
|
import zoneinfo
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from contextlib import contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
@ -39,6 +39,7 @@ class BaseDatabaseWrapper:
|
||||||
ops = None
|
ops = None
|
||||||
vendor = "unknown"
|
vendor = "unknown"
|
||||||
display_name = "unknown"
|
display_name = "unknown"
|
||||||
|
|
||||||
SchemaEditorClass = None
|
SchemaEditorClass = None
|
||||||
# Classes instantiated in __init__().
|
# Classes instantiated in __init__().
|
||||||
client_class = None
|
client_class = None
|
||||||
|
|
@ -47,6 +48,7 @@ class BaseDatabaseWrapper:
|
||||||
introspection_class = None
|
introspection_class = None
|
||||||
ops_class = None
|
ops_class = None
|
||||||
validation_class = BaseDatabaseValidation
|
validation_class = BaseDatabaseValidation
|
||||||
|
_aconnection_pools = {}
|
||||||
|
|
||||||
queries_limit = 9000
|
queries_limit = 9000
|
||||||
|
|
||||||
|
|
@ -54,6 +56,7 @@ class BaseDatabaseWrapper:
|
||||||
# Connection related attributes.
|
# Connection related attributes.
|
||||||
# The underlying database connection.
|
# The underlying database connection.
|
||||||
self.connection = None
|
self.connection = None
|
||||||
|
self.aconnection = None
|
||||||
# `settings_dict` should be a dictionary containing keys such as
|
# `settings_dict` should be a dictionary containing keys such as
|
||||||
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
|
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
|
||||||
# to disambiguate it from Django settings modules.
|
# to disambiguate it from Django settings modules.
|
||||||
|
|
@ -187,25 +190,44 @@ class BaseDatabaseWrapper:
|
||||||
"method."
|
"method."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def aget_database_version(self):
|
||||||
|
"""Return a tuple of the database's version."""
|
||||||
|
raise NotSupportedError(
|
||||||
|
"subclasses of BaseDatabaseWrapper may require an aget_database_version() "
|
||||||
|
"method."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_database_version_supported(self, db_version):
|
||||||
|
if (
|
||||||
|
self.features.minimum_database_version is not None
|
||||||
|
and db_version < self.features.minimum_database_version
|
||||||
|
):
|
||||||
|
str_db_version = ".".join(map(str, db_version))
|
||||||
|
min_db_version = ".".join(map(str, self.features.minimum_database_version))
|
||||||
|
raise NotSupportedError(
|
||||||
|
f"{self.display_name} {min_db_version} or later is required "
|
||||||
|
f"(found {str_db_version})."
|
||||||
|
)
|
||||||
|
|
||||||
def check_database_version_supported(self):
|
def check_database_version_supported(self):
|
||||||
"""
|
"""
|
||||||
Raise an error if the database version isn't supported by this
|
Raise an error if the database version isn't supported by this
|
||||||
version of Django.
|
version of Django.
|
||||||
"""
|
"""
|
||||||
if (
|
db_version = self.get_database_version()
|
||||||
self.features.minimum_database_version is not None
|
self._validate_database_version_supported(db_version)
|
||||||
and self.get_database_version() < self.features.minimum_database_version
|
|
||||||
):
|
async def acheck_database_version_supported(self):
|
||||||
db_version = ".".join(map(str, self.get_database_version()))
|
"""
|
||||||
min_db_version = ".".join(map(str, self.features.minimum_database_version))
|
Raise an error if the database version isn't supported by this
|
||||||
raise NotSupportedError(
|
version of Django.
|
||||||
f"{self.display_name} {min_db_version} or later is required "
|
"""
|
||||||
f"(found {db_version})."
|
db_version = await self.aget_database_version()
|
||||||
)
|
self._validate_database_version_supported(db_version)
|
||||||
|
|
||||||
# ##### Backend-specific methods for creating connections and cursors #####
|
# ##### Backend-specific methods for creating connections and cursors #####
|
||||||
|
|
||||||
def get_connection_params(self):
|
def get_connection_params(self, for_async=False):
|
||||||
"""Return a dict of parameters suitable for get_new_connection."""
|
"""Return a dict of parameters suitable for get_new_connection."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
|
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
|
||||||
|
|
@ -219,23 +241,42 @@ class BaseDatabaseWrapper:
|
||||||
"method"
|
"method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def aget_new_connection(self, conn_params):
|
||||||
|
"""Open a connection to the database."""
|
||||||
|
raise NotSupportedError(
|
||||||
|
"subclasses of BaseDatabaseWrapper may require an aget_new_connection() "
|
||||||
|
"method"
|
||||||
|
)
|
||||||
|
|
||||||
def init_connection_state(self):
|
def init_connection_state(self):
|
||||||
"""Initialize the database connection settings."""
|
"""Initialize the database connection settings."""
|
||||||
if self.alias not in RAN_DB_VERSION_CHECK:
|
if self.alias not in RAN_DB_VERSION_CHECK:
|
||||||
self.check_database_version_supported()
|
self.check_database_version_supported()
|
||||||
RAN_DB_VERSION_CHECK.add(self.alias)
|
RAN_DB_VERSION_CHECK.add(self.alias)
|
||||||
|
|
||||||
|
async def ainit_connection_state(self):
|
||||||
|
"""Initialize the database connection settings."""
|
||||||
|
if self.alias not in RAN_DB_VERSION_CHECK:
|
||||||
|
await self.acheck_database_version_supported()
|
||||||
|
RAN_DB_VERSION_CHECK.add(self.alias)
|
||||||
|
|
||||||
def create_cursor(self, name=None):
|
def create_cursor(self, name=None):
|
||||||
"""Create a cursor. Assume that a connection is established."""
|
"""Create a cursor. Assume that a connection is established."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
|
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_async_cursor(self, name=None):
|
||||||
|
"""Create a cursor. Assume that a connection is established."""
|
||||||
|
raise NotSupportedError(
|
||||||
|
"subclasses of BaseDatabaseWrapper may require a "
|
||||||
|
"create_async_cursor() method"
|
||||||
|
)
|
||||||
|
|
||||||
# ##### Backend-specific methods for creating connections #####
|
# ##### Backend-specific methods for creating connections #####
|
||||||
|
|
||||||
@async_unsafe
|
@contextmanager
|
||||||
def connect(self):
|
def connect_manager(self):
|
||||||
"""Connect to the database. Assume that the connection is closed."""
|
|
||||||
# Check for invalid configurations.
|
# Check for invalid configurations.
|
||||||
self.check_settings()
|
self.check_settings()
|
||||||
# In case the previous connection was closed while in an atomic block
|
# In case the previous connection was closed while in an atomic block
|
||||||
|
|
@ -251,14 +292,30 @@ class BaseDatabaseWrapper:
|
||||||
self.errors_occurred = False
|
self.errors_occurred = False
|
||||||
# New connections are healthy.
|
# New connections are healthy.
|
||||||
self.health_check_done = True
|
self.health_check_done = True
|
||||||
# Establish the connection
|
|
||||||
conn_params = self.get_connection_params()
|
|
||||||
self.connection = self.get_new_connection(conn_params)
|
|
||||||
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
|
|
||||||
self.init_connection_state()
|
|
||||||
connection_created.send(sender=self.__class__, connection=self)
|
|
||||||
|
|
||||||
self.run_on_commit = []
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
connection_created.send(sender=self.__class__, connection=self)
|
||||||
|
self.run_on_commit = []
|
||||||
|
|
||||||
|
@async_unsafe
|
||||||
|
def connect(self):
|
||||||
|
"""Connect to the database. Assume that the connection is closed."""
|
||||||
|
with self.connect_manager():
|
||||||
|
conn_params = self.get_connection_params()
|
||||||
|
self.connection = self.get_new_connection(conn_params)
|
||||||
|
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
|
||||||
|
self.init_connection_state()
|
||||||
|
|
||||||
|
async def aconnect(self):
|
||||||
|
"""Connect to the database. Assume that the connection is closed."""
|
||||||
|
with self.connect_manager():
|
||||||
|
# Establish the connection
|
||||||
|
conn_params = self.get_connection_params(for_async=True)
|
||||||
|
self.aconnection = await self.aget_new_connection(conn_params)
|
||||||
|
await self.aset_autocommit(self.settings_dict["AUTOCOMMIT"])
|
||||||
|
await self.ainit_connection_state()
|
||||||
|
|
||||||
def check_settings(self):
|
def check_settings(self):
|
||||||
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
|
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
|
||||||
|
|
@ -278,6 +335,16 @@ class BaseDatabaseWrapper:
|
||||||
with self.wrap_database_errors:
|
with self.wrap_database_errors:
|
||||||
self.connect()
|
self.connect()
|
||||||
|
|
||||||
|
async def aensure_connection(self):
|
||||||
|
"""Guarantee that a connection to the database is established."""
|
||||||
|
if self.aconnection is None:
|
||||||
|
if self.in_atomic_block and self.closed_in_transaction:
|
||||||
|
raise ProgrammingError(
|
||||||
|
"Cannot open a new connection in an atomic block."
|
||||||
|
)
|
||||||
|
with self.wrap_database_errors:
|
||||||
|
await self.aconnect()
|
||||||
|
|
||||||
# ##### Backend-specific wrappers for PEP-249 connection methods #####
|
# ##### Backend-specific wrappers for PEP-249 connection methods #####
|
||||||
|
|
||||||
def _prepare_cursor(self, cursor):
|
def _prepare_cursor(self, cursor):
|
||||||
|
|
@ -291,27 +358,57 @@ class BaseDatabaseWrapper:
|
||||||
wrapped_cursor = self.make_cursor(cursor)
|
wrapped_cursor = self.make_cursor(cursor)
|
||||||
return wrapped_cursor
|
return wrapped_cursor
|
||||||
|
|
||||||
|
def _aprepare_cursor(self, cursor):
|
||||||
|
"""
|
||||||
|
Validate the connection is usable and perform database cursor wrapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
if self.queries_logged:
|
||||||
|
wrapped_cursor = self.make_debug_async_cursor(cursor)
|
||||||
|
else:
|
||||||
|
wrapped_cursor = self.make_async_cursor(cursor)
|
||||||
|
return wrapped_cursor
|
||||||
|
|
||||||
def _cursor(self, name=None):
|
def _cursor(self, name=None):
|
||||||
self.close_if_health_check_failed()
|
self.close_if_health_check_failed()
|
||||||
self.ensure_connection()
|
self.ensure_connection()
|
||||||
with self.wrap_database_errors:
|
with self.wrap_database_errors:
|
||||||
return self._prepare_cursor(self.create_cursor(name))
|
return self._prepare_cursor(self.create_cursor(name))
|
||||||
|
|
||||||
|
def _acursor(self, name=None):
|
||||||
|
return utils.AsyncCursorCtx(self, name)
|
||||||
|
|
||||||
def _commit(self):
|
def _commit(self):
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
||||||
return self.connection.commit()
|
return self.connection.commit()
|
||||||
|
|
||||||
|
async def _acommit(self):
|
||||||
|
if self.aconnection is not None:
|
||||||
|
with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
|
||||||
|
return await self.aconnection.commit()
|
||||||
|
|
||||||
def _rollback(self):
|
def _rollback(self):
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
||||||
return self.connection.rollback()
|
return self.connection.rollback()
|
||||||
|
|
||||||
|
async def _arollback(self):
|
||||||
|
if self.aconnection is not None:
|
||||||
|
with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
|
||||||
|
return await self.aconnection.rollback()
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
with self.wrap_database_errors:
|
with self.wrap_database_errors:
|
||||||
return self.connection.close()
|
return self.connection.close()
|
||||||
|
|
||||||
|
async def _aclose(self):
|
||||||
|
if self.aconnection is not None:
|
||||||
|
with self.wrap_database_errors:
|
||||||
|
return await self.aconnection.close()
|
||||||
|
|
||||||
# ##### Generic wrappers for PEP-249 connection methods #####
|
# ##### Generic wrappers for PEP-249 connection methods #####
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
|
|
@ -319,6 +416,10 @@ class BaseDatabaseWrapper:
|
||||||
"""Create a cursor, opening a connection if necessary."""
|
"""Create a cursor, opening a connection if necessary."""
|
||||||
return self._cursor()
|
return self._cursor()
|
||||||
|
|
||||||
|
def acursor(self):
|
||||||
|
"""Create an async cursor, opening a connection if necessary."""
|
||||||
|
return self._acursor()
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def commit(self):
|
def commit(self):
|
||||||
"""Commit a transaction and reset the dirty flag."""
|
"""Commit a transaction and reset the dirty flag."""
|
||||||
|
|
@ -329,6 +430,15 @@ class BaseDatabaseWrapper:
|
||||||
self.errors_occurred = False
|
self.errors_occurred = False
|
||||||
self.run_commit_hooks_on_set_autocommit_on = True
|
self.run_commit_hooks_on_set_autocommit_on = True
|
||||||
|
|
||||||
|
async def acommit(self):
|
||||||
|
"""Commit a transaction and reset the dirty flag."""
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
self.validate_no_atomic_block()
|
||||||
|
await self._acommit()
|
||||||
|
# A successful commit means that the database connection works.
|
||||||
|
self.errors_occurred = False
|
||||||
|
self.run_commit_hooks_on_set_autocommit_on = True
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def rollback(self):
|
def rollback(self):
|
||||||
"""Roll back a transaction and reset the dirty flag."""
|
"""Roll back a transaction and reset the dirty flag."""
|
||||||
|
|
@ -340,6 +450,16 @@ class BaseDatabaseWrapper:
|
||||||
self.needs_rollback = False
|
self.needs_rollback = False
|
||||||
self.run_on_commit = []
|
self.run_on_commit = []
|
||||||
|
|
||||||
|
async def arollback(self):
|
||||||
|
"""Roll back a transaction and reset the dirty flag."""
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
self.validate_no_atomic_block()
|
||||||
|
await self._arollback()
|
||||||
|
# A successful rollback means that the database connection works.
|
||||||
|
self.errors_occurred = False
|
||||||
|
self.needs_rollback = False
|
||||||
|
self.run_on_commit = []
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the connection to the database."""
|
"""Close the connection to the database."""
|
||||||
|
|
@ -360,24 +480,59 @@ class BaseDatabaseWrapper:
|
||||||
else:
|
else:
|
||||||
self.connection = None
|
self.connection = None
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
"""Close the connection to the database."""
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
self.run_on_commit = []
|
||||||
|
|
||||||
|
# Don't call validate_no_atomic_block() to avoid making it difficult
|
||||||
|
# to get rid of a connection in an invalid state. The next connect()
|
||||||
|
# will reset the transaction state anyway.
|
||||||
|
if self.closed_in_transaction or self.aconnection is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self._aclose()
|
||||||
|
finally:
|
||||||
|
if self.in_atomic_block:
|
||||||
|
self.closed_in_transaction = True
|
||||||
|
self.needs_rollback = True
|
||||||
|
else:
|
||||||
|
self.aconnection = None
|
||||||
|
|
||||||
# ##### Backend-specific savepoint management methods #####
|
# ##### Backend-specific savepoint management methods #####
|
||||||
|
|
||||||
def _savepoint(self, sid):
|
def _savepoint(self, sid):
|
||||||
with self.cursor() as cursor:
|
with self.cursor() as cursor:
|
||||||
cursor.execute(self.ops.savepoint_create_sql(sid))
|
cursor.execute(self.ops.savepoint_create_sql(sid))
|
||||||
|
|
||||||
|
async def _asavepoint(self, sid):
|
||||||
|
async with self.acursor() as cursor:
|
||||||
|
await cursor.aexecute(self.ops.savepoint_create_sql(sid))
|
||||||
|
|
||||||
def _savepoint_rollback(self, sid):
|
def _savepoint_rollback(self, sid):
|
||||||
with self.cursor() as cursor:
|
with self.cursor() as cursor:
|
||||||
cursor.execute(self.ops.savepoint_rollback_sql(sid))
|
cursor.execute(self.ops.savepoint_rollback_sql(sid))
|
||||||
|
|
||||||
|
async def _asavepoint_rollback(self, sid):
|
||||||
|
async with self.acursor() as cursor:
|
||||||
|
await cursor.aexecute(self.ops.savepoint_rollback_sql(sid))
|
||||||
|
|
||||||
def _savepoint_commit(self, sid):
|
def _savepoint_commit(self, sid):
|
||||||
with self.cursor() as cursor:
|
with self.cursor() as cursor:
|
||||||
cursor.execute(self.ops.savepoint_commit_sql(sid))
|
cursor.execute(self.ops.savepoint_commit_sql(sid))
|
||||||
|
|
||||||
|
async def _asavepoint_commit(self, sid):
|
||||||
|
async with self.acursor() as cursor:
|
||||||
|
await cursor.aexecute(self.ops.savepoint_commit_sql(sid))
|
||||||
|
|
||||||
def _savepoint_allowed(self):
|
def _savepoint_allowed(self):
|
||||||
# Savepoints cannot be created outside a transaction
|
# Savepoints cannot be created outside a transaction
|
||||||
return self.features.uses_savepoints and not self.get_autocommit()
|
return self.features.uses_savepoints and not self.get_autocommit()
|
||||||
|
|
||||||
|
async def _asavepoint_allowed(self):
|
||||||
|
# Savepoints cannot be created outside a transaction
|
||||||
|
return self.features.uses_savepoints and not (await self.aget_autocommit())
|
||||||
|
|
||||||
# ##### Generic savepoint management methods #####
|
# ##### Generic savepoint management methods #####
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
|
|
@ -401,6 +556,26 @@ class BaseDatabaseWrapper:
|
||||||
|
|
||||||
return sid
|
return sid
|
||||||
|
|
||||||
|
async def asavepoint(self):
|
||||||
|
"""
|
||||||
|
Create a savepoint inside the current transaction. Return an
|
||||||
|
identifier for the savepoint that will be used for the subsequent
|
||||||
|
rollback or commit. Do nothing if savepoints are not supported.
|
||||||
|
"""
|
||||||
|
if not (await self._asavepoint_allowed()):
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_ident = _thread.get_ident()
|
||||||
|
tid = str(thread_ident).replace("-", "")
|
||||||
|
|
||||||
|
self.savepoint_state += 1
|
||||||
|
sid = "s%s_x%d" % (tid, self.savepoint_state)
|
||||||
|
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
await self._asavepoint(sid)
|
||||||
|
|
||||||
|
return sid
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def savepoint_rollback(self, sid):
|
def savepoint_rollback(self, sid):
|
||||||
"""
|
"""
|
||||||
|
|
@ -419,6 +594,23 @@ class BaseDatabaseWrapper:
|
||||||
if sid not in sids
|
if sid not in sids
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def asavepoint_rollback(self, sid):
|
||||||
|
"""
|
||||||
|
Roll back to a savepoint. Do nothing if savepoints are not supported.
|
||||||
|
"""
|
||||||
|
if not (await self._asavepoint_allowed()):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
await self._asavepoint_rollback(sid)
|
||||||
|
|
||||||
|
# Remove any callbacks registered while this savepoint was active.
|
||||||
|
self.run_on_commit = [
|
||||||
|
(sids, func, robust)
|
||||||
|
for (sids, func, robust) in self.run_on_commit
|
||||||
|
if sid not in sids
|
||||||
|
]
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def savepoint_commit(self, sid):
|
def savepoint_commit(self, sid):
|
||||||
"""
|
"""
|
||||||
|
|
@ -430,6 +622,16 @@ class BaseDatabaseWrapper:
|
||||||
self.validate_thread_sharing()
|
self.validate_thread_sharing()
|
||||||
self._savepoint_commit(sid)
|
self._savepoint_commit(sid)
|
||||||
|
|
||||||
|
async def asavepoint_commit(self, sid):
|
||||||
|
"""
|
||||||
|
Release a savepoint. Do nothing if savepoints are not supported.
|
||||||
|
"""
|
||||||
|
if not (await self._asavepoint_allowed()):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.validate_thread_sharing()
|
||||||
|
await self._asavepoint_commit(sid)
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def clean_savepoints(self):
|
def clean_savepoints(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -447,6 +649,14 @@ class BaseDatabaseWrapper:
|
||||||
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
|
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _aset_autocommit(self, autocommit):
|
||||||
|
"""
|
||||||
|
Backend-specific implementation to enable or disable autocommit.
|
||||||
|
"""
|
||||||
|
raise NotSupportedError(
|
||||||
|
"subclasses of BaseDatabaseWrapper may require an _aset_autocommit() method"
|
||||||
|
)
|
||||||
|
|
||||||
# ##### Generic transaction management methods #####
|
# ##### Generic transaction management methods #####
|
||||||
|
|
||||||
def get_autocommit(self):
|
def get_autocommit(self):
|
||||||
|
|
@ -454,6 +664,11 @@ class BaseDatabaseWrapper:
|
||||||
self.ensure_connection()
|
self.ensure_connection()
|
||||||
return self.autocommit
|
return self.autocommit
|
||||||
|
|
||||||
|
async def aget_autocommit(self):
|
||||||
|
"""Get the autocommit state."""
|
||||||
|
await self.aensure_connection()
|
||||||
|
return self.autocommit
|
||||||
|
|
||||||
def set_autocommit(
|
def set_autocommit(
|
||||||
self, autocommit, force_begin_transaction_with_broken_autocommit=False
|
self, autocommit, force_begin_transaction_with_broken_autocommit=False
|
||||||
):
|
):
|
||||||
|
|
@ -491,6 +706,43 @@ class BaseDatabaseWrapper:
|
||||||
self.run_and_clear_commit_hooks()
|
self.run_and_clear_commit_hooks()
|
||||||
self.run_commit_hooks_on_set_autocommit_on = False
|
self.run_commit_hooks_on_set_autocommit_on = False
|
||||||
|
|
||||||
|
async def aset_autocommit(
|
||||||
|
self, autocommit, force_begin_transaction_with_broken_autocommit=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Enable or disable autocommit.
|
||||||
|
|
||||||
|
The usual way to start a transaction is to turn autocommit off.
|
||||||
|
SQLite does not properly start a transaction when disabling
|
||||||
|
autocommit. To avoid this buggy behavior and to actually enter a new
|
||||||
|
transaction, an explicit BEGIN is required. Using
|
||||||
|
force_begin_transaction_with_broken_autocommit=True will issue an
|
||||||
|
explicit BEGIN with SQLite. This option will be ignored for other
|
||||||
|
backends.
|
||||||
|
"""
|
||||||
|
self.validate_no_atomic_block()
|
||||||
|
await self.aclose_if_health_check_failed()
|
||||||
|
await self.aensure_connection()
|
||||||
|
|
||||||
|
start_transaction_under_autocommit = (
|
||||||
|
force_begin_transaction_with_broken_autocommit
|
||||||
|
and not autocommit
|
||||||
|
and hasattr(self, "_astart_transaction_under_autocommit")
|
||||||
|
)
|
||||||
|
|
||||||
|
if start_transaction_under_autocommit:
|
||||||
|
await self._astart_transaction_under_autocommit()
|
||||||
|
elif autocommit:
|
||||||
|
await self._aset_autocommit(autocommit)
|
||||||
|
else:
|
||||||
|
with debug_transaction(self, "BEGIN"):
|
||||||
|
await self._aset_autocommit(autocommit)
|
||||||
|
self.autocommit = autocommit
|
||||||
|
|
||||||
|
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
|
||||||
|
self.run_and_clear_commit_hooks()
|
||||||
|
self.run_commit_hooks_on_set_autocommit_on = False
|
||||||
|
|
||||||
def get_rollback(self):
|
def get_rollback(self):
|
||||||
"""Get the "needs rollback" flag -- for *advanced use* only."""
|
"""Get the "needs rollback" flag -- for *advanced use* only."""
|
||||||
if not self.in_atomic_block:
|
if not self.in_atomic_block:
|
||||||
|
|
@ -575,6 +827,19 @@ class BaseDatabaseWrapper:
|
||||||
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
|
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def ais_usable(self):
|
||||||
|
"""
|
||||||
|
Test if the database connection is usable.
|
||||||
|
|
||||||
|
This method may assume that self.connection is not None.
|
||||||
|
|
||||||
|
Actual implementations should take care not to raise exceptions
|
||||||
|
as that may prevent Django from recycling unusable connections.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"subclasses of BaseDatabaseWrapper may require an ais_usable() method"
|
||||||
|
)
|
||||||
|
|
||||||
def close_if_health_check_failed(self):
|
def close_if_health_check_failed(self):
|
||||||
"""Close existing connection if it fails a health check."""
|
"""Close existing connection if it fails a health check."""
|
||||||
if (
|
if (
|
||||||
|
|
@ -588,6 +853,20 @@ class BaseDatabaseWrapper:
|
||||||
self.close()
|
self.close()
|
||||||
self.health_check_done = True
|
self.health_check_done = True
|
||||||
|
|
||||||
|
async def aclose_if_health_check_failed(self):
|
||||||
|
"""Close existing connection if it fails a health check."""
|
||||||
|
if (
|
||||||
|
self.aconnection is None
|
||||||
|
or not self.health_check_enabled
|
||||||
|
or self.health_check_done
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
is_usable = await self.ais_usable()
|
||||||
|
if not is_usable:
|
||||||
|
await self.aclose()
|
||||||
|
self.health_check_done = True
|
||||||
|
|
||||||
def close_if_unusable_or_obsolete(self):
|
def close_if_unusable_or_obsolete(self):
|
||||||
"""
|
"""
|
||||||
Close the current connection if unrecoverable errors have occurred
|
Close the current connection if unrecoverable errors have occurred
|
||||||
|
|
@ -677,10 +956,18 @@ class BaseDatabaseWrapper:
|
||||||
"""Create a cursor that logs all queries in self.queries_log."""
|
"""Create a cursor that logs all queries in self.queries_log."""
|
||||||
return utils.CursorDebugWrapper(cursor, self)
|
return utils.CursorDebugWrapper(cursor, self)
|
||||||
|
|
||||||
|
def make_debug_async_cursor(self, cursor):
|
||||||
|
"""Create a cursor that logs all queries in self.queries_log."""
|
||||||
|
return utils.AsyncCursorDebugWrapper(cursor, self)
|
||||||
|
|
||||||
def make_cursor(self, cursor):
|
def make_cursor(self, cursor):
|
||||||
"""Create a cursor without debug logging."""
|
"""Create a cursor without debug logging."""
|
||||||
return utils.CursorWrapper(cursor, self)
|
return utils.CursorWrapper(cursor, self)
|
||||||
|
|
||||||
|
def make_async_cursor(self, cursor):
|
||||||
|
"""Create a cursor without debug logging."""
|
||||||
|
return utils.AsyncCursorWrapper(cursor, self)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def temporary_connection(self):
|
def temporary_connection(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -698,6 +985,27 @@ class BaseDatabaseWrapper:
|
||||||
if must_close:
|
if must_close:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def atemporary_connection(self):
|
||||||
|
"""
|
||||||
|
Context manager that ensures that a connection is established, and
|
||||||
|
if it opened one, closes it to avoid leaving a dangling connection.
|
||||||
|
This is useful for operations outside of the request-response cycle.
|
||||||
|
|
||||||
|
Provide a cursor::
|
||||||
|
async with self.atemporary_connection() as cursor:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
# unused
|
||||||
|
|
||||||
|
must_close = self.aconnection is None
|
||||||
|
try:
|
||||||
|
async with self.acursor() as cursor:
|
||||||
|
yield cursor
|
||||||
|
finally:
|
||||||
|
if must_close:
|
||||||
|
await self.aclose()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _nodb_cursor(self):
|
def _nodb_cursor(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -358,6 +358,9 @@ class BaseDatabaseFeatures:
|
||||||
# Does the backend support negative JSON array indexing?
|
# Does the backend support negative JSON array indexing?
|
||||||
supports_json_negative_indexing = True
|
supports_json_negative_indexing = True
|
||||||
|
|
||||||
|
# Asynchronous database operations
|
||||||
|
supports_async = False
|
||||||
|
|
||||||
# Does the backend support column collations?
|
# Does the backend support column collations?
|
||||||
supports_collation_on_charfield = True
|
supports_collation_on_charfield = True
|
||||||
supports_collation_on_textfield = True
|
supports_collation_on_textfield = True
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,9 @@ from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db import DatabaseError as WrappedDatabaseError
|
from django.db import DatabaseError as WrappedDatabaseError
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
||||||
|
from django.db.backends.utils import (
|
||||||
|
AsyncCursorDebugWrapper as AsyncBaseCursorDebugWrapper,
|
||||||
|
)
|
||||||
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
|
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
|
||||||
from django.utils.asyncio import async_unsafe
|
from django.utils.asyncio import async_unsafe
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
@ -98,6 +101,8 @@ def _get_decimal_column(data):
|
||||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
vendor = "postgresql"
|
vendor = "postgresql"
|
||||||
display_name = "PostgreSQL"
|
display_name = "PostgreSQL"
|
||||||
|
_pg_version = None
|
||||||
|
|
||||||
# This dictionary maps Field objects to their associated PostgreSQL column
|
# This dictionary maps Field objects to their associated PostgreSQL column
|
||||||
# types, as strings. Column-type strings can contain format strings;
|
# types, as strings. Column-type strings can contain format strings;
|
||||||
# they'll be interpolated against the values of Field.__dict__ before being
|
# they'll be interpolated against the values of Field.__dict__ before being
|
||||||
|
|
@ -231,11 +236,57 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
|
|
||||||
return self._connection_pools[self.alias]
|
return self._connection_pools[self.alias]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def apool(self):
|
||||||
|
pool_options = self.settings_dict["OPTIONS"].get("pool")
|
||||||
|
if self.alias == NO_DB_ALIAS or not pool_options:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.alias not in self._aconnection_pools:
|
||||||
|
if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
"Pooling doesn't support persistent connections."
|
||||||
|
)
|
||||||
|
# Set the default options.
|
||||||
|
if pool_options is True:
|
||||||
|
pool_options = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
except ImportError as err:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
"Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
connect_kwargs = self.get_connection_params(for_async=True)
|
||||||
|
# Ensure we run in autocommit, Django properly sets it later on.
|
||||||
|
connect_kwargs["autocommit"] = True
|
||||||
|
enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
kwargs=connect_kwargs,
|
||||||
|
open=False, # Do not open the pool during startup.
|
||||||
|
configure=self._aconfigure_connection,
|
||||||
|
check=AsyncConnectionPool.check_connection if enable_checks else None,
|
||||||
|
**pool_options,
|
||||||
|
)
|
||||||
|
# setdefault() ensures that multiple threads don't set this in
|
||||||
|
# parallel. Since we do not open the pool during it's init above,
|
||||||
|
# this means that at worst during startup multiple threads generate
|
||||||
|
# pool objects and the first to set it wins.
|
||||||
|
self._aconnection_pools.setdefault(self.alias, pool)
|
||||||
|
|
||||||
|
return self._aconnection_pools[self.alias]
|
||||||
|
|
||||||
def close_pool(self):
|
def close_pool(self):
|
||||||
if self.pool:
|
if self.pool:
|
||||||
self.pool.close()
|
self.pool.close()
|
||||||
del self._connection_pools[self.alias]
|
del self._connection_pools[self.alias]
|
||||||
|
|
||||||
|
async def aclose_pool(self):
|
||||||
|
if self.apool:
|
||||||
|
await self.apool.close()
|
||||||
|
del self._aconnection_pools[self.alias]
|
||||||
|
|
||||||
def get_database_version(self):
|
def get_database_version(self):
|
||||||
"""
|
"""
|
||||||
Return a tuple of the database's version.
|
Return a tuple of the database's version.
|
||||||
|
|
@ -243,7 +294,38 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
"""
|
"""
|
||||||
return divmod(self.pg_version, 10000)
|
return divmod(self.pg_version, 10000)
|
||||||
|
|
||||||
def get_connection_params(self):
|
async def aget_database_version(self):
|
||||||
|
"""
|
||||||
|
Return a tuple of the database's version.
|
||||||
|
E.g. for pg_version 120004, return (12, 4).
|
||||||
|
"""
|
||||||
|
pg_version = await self.apg_version
|
||||||
|
return divmod(pg_version, 10000)
|
||||||
|
|
||||||
|
def _get_sync_cursor_factory(self, server_side_binding=None):
|
||||||
|
if is_psycopg3 and server_side_binding is True:
|
||||||
|
return ServerBindingCursor
|
||||||
|
else:
|
||||||
|
return Cursor
|
||||||
|
|
||||||
|
def _get_async_cursor_factory(self, server_side_binding=None):
|
||||||
|
if is_psycopg3 and server_side_binding is True:
|
||||||
|
return AsyncServerBindingCursor
|
||||||
|
else:
|
||||||
|
return AsyncCursor
|
||||||
|
|
||||||
|
def _get_cursor_factory(self, server_side_binding=None, for_async=False):
|
||||||
|
if for_async and not is_psycopg3:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
"Django requires psycopg >= 3 for ORM async support."
|
||||||
|
)
|
||||||
|
|
||||||
|
if for_async:
|
||||||
|
return self._get_async_cursor_factory(server_side_binding)
|
||||||
|
else:
|
||||||
|
return self._get_sync_cursor_factory(server_side_binding)
|
||||||
|
|
||||||
|
def get_connection_params(self, for_async=False):
|
||||||
settings_dict = self.settings_dict
|
settings_dict = self.settings_dict
|
||||||
# None may be used to connect to the default 'postgres' db
|
# None may be used to connect to the default 'postgres' db
|
||||||
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
|
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
|
||||||
|
|
@ -283,14 +365,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
|
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
|
||||||
|
|
||||||
server_side_binding = conn_params.pop("server_side_binding", None)
|
server_side_binding = conn_params.pop("server_side_binding", None)
|
||||||
conn_params.setdefault(
|
cursor_factory = self._get_cursor_factory(
|
||||||
"cursor_factory",
|
server_side_binding, for_async=for_async
|
||||||
(
|
|
||||||
ServerBindingCursor
|
|
||||||
if is_psycopg3 and server_side_binding is True
|
|
||||||
else Cursor
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
conn_params.setdefault("cursor_factory", cursor_factory)
|
||||||
if settings_dict["USER"]:
|
if settings_dict["USER"]:
|
||||||
conn_params["user"] = settings_dict["USER"]
|
conn_params["user"] = settings_dict["USER"]
|
||||||
if settings_dict["PASSWORD"]:
|
if settings_dict["PASSWORD"]:
|
||||||
|
|
@ -310,8 +388,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
)
|
)
|
||||||
return conn_params
|
return conn_params
|
||||||
|
|
||||||
@async_unsafe
|
def _get_isolation_level(self):
|
||||||
def get_new_connection(self, conn_params):
|
|
||||||
# self.isolation_level must be set:
|
# self.isolation_level must be set:
|
||||||
# - after connecting to the database in order to obtain the database's
|
# - after connecting to the database in order to obtain the database's
|
||||||
# default when no value is explicitly specified in options.
|
# default when no value is explicitly specified in options.
|
||||||
|
|
@ -322,17 +399,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
try:
|
try:
|
||||||
isolation_level_value = options["isolation_level"]
|
isolation_level_value = options["isolation_level"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.isolation_level = IsolationLevel.READ_COMMITTED
|
isolation_level = IsolationLevel.READ_COMMITTED
|
||||||
else:
|
else:
|
||||||
# Set the isolation level to the value from OPTIONS.
|
|
||||||
try:
|
try:
|
||||||
self.isolation_level = IsolationLevel(isolation_level_value)
|
isolation_level = IsolationLevel(isolation_level_value)
|
||||||
set_isolation_level = True
|
set_isolation_level = True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ImproperlyConfigured(
|
raise ImproperlyConfigured(
|
||||||
f"Invalid transaction isolation level {isolation_level_value} "
|
f"Invalid transaction isolation level {isolation_level_value} "
|
||||||
f"specified. Use one of the psycopg.IsolationLevel values."
|
f"specified. Use one of the psycopg.IsolationLevel values."
|
||||||
)
|
)
|
||||||
|
return isolation_level, set_isolation_level
|
||||||
|
|
||||||
|
@async_unsafe
|
||||||
|
def get_new_connection(self, conn_params):
|
||||||
|
isolation_level, set_isolation_level = self._get_isolation_level()
|
||||||
|
self.isolation_level = isolation_level
|
||||||
if self.pool:
|
if self.pool:
|
||||||
# If nothing else has opened the pool, open it now.
|
# If nothing else has opened the pool, open it now.
|
||||||
self.pool.open()
|
self.pool.open()
|
||||||
|
|
@ -340,7 +422,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
else:
|
else:
|
||||||
connection = self.Database.connect(**conn_params)
|
connection = self.Database.connect(**conn_params)
|
||||||
if set_isolation_level:
|
if set_isolation_level:
|
||||||
connection.isolation_level = self.isolation_level
|
connection.isolation_level = isolation_level
|
||||||
if not is_psycopg3:
|
if not is_psycopg3:
|
||||||
# Register dummy loads() to avoid a round trip from psycopg2's
|
# Register dummy loads() to avoid a round trip from psycopg2's
|
||||||
# decode to json.dumps() to json.loads(), when using a custom
|
# decode to json.dumps() to json.loads(), when using a custom
|
||||||
|
|
@ -350,6 +432,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
)
|
)
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
|
async def aget_new_connection(self, conn_params):
|
||||||
|
isolation_level, set_isolation_level = self._get_isolation_level()
|
||||||
|
self.isolation_level = isolation_level
|
||||||
|
if self.apool:
|
||||||
|
# If nothing else has opened the pool, open it now.
|
||||||
|
await self.apool.open()
|
||||||
|
connection = await self.apool.getconn()
|
||||||
|
else:
|
||||||
|
connection = await self.Database.AsyncConnection.connect(**conn_params)
|
||||||
|
if set_isolation_level:
|
||||||
|
connection.isolation_level = isolation_level
|
||||||
|
return connection
|
||||||
|
|
||||||
def ensure_timezone(self):
|
def ensure_timezone(self):
|
||||||
# Close the pool so new connections pick up the correct timezone.
|
# Close the pool so new connections pick up the correct timezone.
|
||||||
self.close_pool()
|
self.close_pool()
|
||||||
|
|
@ -357,6 +452,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
return False
|
return False
|
||||||
return self._configure_timezone(self.connection)
|
return self._configure_timezone(self.connection)
|
||||||
|
|
||||||
|
async def aensure_timezone(self):
|
||||||
|
# Close the pool so new connections pick up the correct timezone.
|
||||||
|
await self.aclose_pool()
|
||||||
|
if self.connection is None:
|
||||||
|
return False
|
||||||
|
return await self._aconfigure_timezone(self.connection)
|
||||||
|
|
||||||
def _configure_timezone(self, connection):
|
def _configure_timezone(self, connection):
|
||||||
conn_timezone_name = connection.info.parameter_status("TimeZone")
|
conn_timezone_name = connection.info.parameter_status("TimeZone")
|
||||||
timezone_name = self.timezone_name
|
timezone_name = self.timezone_name
|
||||||
|
|
@ -366,6 +468,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _aconfigure_timezone(self, connection):
|
||||||
|
conn_timezone_name = connection.info.parameter_status("TimeZone")
|
||||||
|
timezone_name = self.timezone_name
|
||||||
|
if timezone_name and conn_timezone_name != timezone_name:
|
||||||
|
async with connection.cursor() as cursor:
|
||||||
|
await cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _configure_role(self, connection):
|
def _configure_role(self, connection):
|
||||||
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
|
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
|
|
@ -374,6 +485,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _aconfigure_role(self, connection):
|
||||||
|
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
|
||||||
|
async with connection.acursor() as cursor:
|
||||||
|
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
|
||||||
|
await cursor.aaexecute(sql)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _configure_connection(self, connection):
|
def _configure_connection(self, connection):
|
||||||
# This function is called from init_connection_state and from the
|
# This function is called from init_connection_state and from the
|
||||||
# psycopg pool itself after a connection is opened.
|
# psycopg pool itself after a connection is opened.
|
||||||
|
|
@ -387,6 +506,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
|
|
||||||
return commit_role or commit_tz
|
return commit_role or commit_tz
|
||||||
|
|
||||||
|
async def _aconfigure_connection(self, connection):
|
||||||
|
# This function is called from init_connection_state and from the
|
||||||
|
# psycopg pool itself after a connection is opened.
|
||||||
|
|
||||||
|
# Commit after setting the time zone.
|
||||||
|
commit_tz = await self._aconfigure_timezone(connection)
|
||||||
|
# Set the role on the connection. This is useful if the credential used
|
||||||
|
# to login is not the same as the role that owns database resources. As
|
||||||
|
# can be the case when using temporary or ephemeral credentials.
|
||||||
|
commit_role = await self._aconfigure_role(connection)
|
||||||
|
|
||||||
|
return commit_role or commit_tz
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
# `wrap_database_errors` only works for `putconn` as long as there
|
# `wrap_database_errors` only works for `putconn` as long as there
|
||||||
|
|
@ -403,6 +535,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
else:
|
else:
|
||||||
return self.connection.close()
|
return self.connection.close()
|
||||||
|
|
||||||
|
async def _aclose(self):
|
||||||
|
if self.aconnection is not None:
|
||||||
|
# `wrap_database_errors` only works for `putconn` as long as there
|
||||||
|
# is no `reset` function set in the pool because it is deferred
|
||||||
|
# into a thread and not directly executed.
|
||||||
|
with self.wrap_database_errors:
|
||||||
|
if self.apool:
|
||||||
|
# Ensure the correct pool is returned. This is a workaround
|
||||||
|
# for tests so a pool can be changed on setting changes
|
||||||
|
# (e.g. USE_TZ, TIME_ZONE).
|
||||||
|
await self.aconnection._pool.putconn(self.aconnection)
|
||||||
|
# Connection can no longer be used.
|
||||||
|
self.aconnection = None
|
||||||
|
else:
|
||||||
|
return await self.aconnection.close()
|
||||||
|
|
||||||
def init_connection_state(self):
|
def init_connection_state(self):
|
||||||
super().init_connection_state()
|
super().init_connection_state()
|
||||||
|
|
||||||
|
|
@ -412,6 +560,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
if commit and not self.get_autocommit():
|
if commit and not self.get_autocommit():
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
|
async def ainit_connection_state(self):
|
||||||
|
await super().ainit_connection_state()
|
||||||
|
|
||||||
|
if self.aconnection is not None and not self.apool:
|
||||||
|
commit = await self._aconfigure_connection(self.aconnection)
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
autocommit = await self.aget_autocommit()
|
||||||
|
if not autocommit:
|
||||||
|
await self.aconnection.commit()
|
||||||
|
|
||||||
@async_unsafe
|
@async_unsafe
|
||||||
def create_cursor(self, name=None):
|
def create_cursor(self, name=None):
|
||||||
if name:
|
if name:
|
||||||
|
|
@ -447,6 +606,35 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
||||||
return cursor
|
return cursor
|
||||||
|
|
||||||
|
def create_async_cursor(self, name=None):
|
||||||
|
if name:
|
||||||
|
if self.settings_dict["OPTIONS"].get("server_side_binding") is not True:
|
||||||
|
# psycopg >= 3 forces the usage of server-side bindings for
|
||||||
|
# named cursors so a specialized class that implements
|
||||||
|
# server-side cursors while performing client-side bindings
|
||||||
|
# must be used if `server_side_binding` is disabled (default).
|
||||||
|
cursor = AsyncServerSideCursor(
|
||||||
|
self.aconnection,
|
||||||
|
name=name,
|
||||||
|
scrollable=False,
|
||||||
|
withhold=self.aconnection.autocommit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# In autocommit mode, the cursor will be used outside of a
|
||||||
|
# transaction, hence use a holdable cursor.
|
||||||
|
cursor = self.aconnection.cursor(
|
||||||
|
name, scrollable=False, withhold=self.aconnection.autocommit
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor = self.aconnection.cursor()
|
||||||
|
|
||||||
|
# Register the cursor timezone only if the connection disagrees, to
|
||||||
|
# avoid copying the adapter map.
|
||||||
|
tzloader = self.aconnection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
|
||||||
|
if self.timezone != tzloader.timezone:
|
||||||
|
register_tzloader(self.timezone, cursor)
|
||||||
|
return cursor
|
||||||
|
|
||||||
def tzinfo_factory(self, offset):
|
def tzinfo_factory(self, offset):
|
||||||
return self.timezone
|
return self.timezone
|
||||||
|
|
||||||
|
|
@ -478,10 +666,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def achunked_cursor(self):
|
||||||
|
self._named_cursor_idx += 1
|
||||||
|
# Get the current async task
|
||||||
|
try:
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
except RuntimeError:
|
||||||
|
current_task = None
|
||||||
|
# Current task can be none even if the current_task call didn't error
|
||||||
|
if current_task:
|
||||||
|
task_ident = str(id(current_task))
|
||||||
|
else:
|
||||||
|
task_ident = "sync"
|
||||||
|
# Use that and the thread ident to get a unique name
|
||||||
|
return self._acursor(
|
||||||
|
name="_django_curs_%d_%s_%d"
|
||||||
|
% (
|
||||||
|
# Avoid reusing name in other threads / tasks
|
||||||
|
threading.current_thread().ident,
|
||||||
|
task_ident,
|
||||||
|
self._named_cursor_idx,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _set_autocommit(self, autocommit):
|
def _set_autocommit(self, autocommit):
|
||||||
with self.wrap_database_errors:
|
with self.wrap_database_errors:
|
||||||
self.connection.autocommit = autocommit
|
self.connection.autocommit = autocommit
|
||||||
|
|
||||||
|
async def _aset_autocommit(self, autocommit):
|
||||||
|
with self.wrap_database_errors:
|
||||||
|
await self.aconnection.set_autocommit(autocommit)
|
||||||
|
|
||||||
def check_constraints(self, table_names=None):
|
def check_constraints(self, table_names=None):
|
||||||
"""
|
"""
|
||||||
Check constraints by setting them to immediate. Return them to deferred
|
Check constraints by setting them to immediate. Return them to deferred
|
||||||
|
|
@ -503,12 +718,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def ais_usable(self):
|
||||||
|
if self.aconnection is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
# Use a psycopg cursor directly, bypassing Django's utilities.
|
||||||
|
async with self.aconnection.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT 1")
|
||||||
|
except Database.Error:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
def close_if_health_check_failed(self):
|
def close_if_health_check_failed(self):
|
||||||
if self.pool:
|
if self.pool:
|
||||||
# The pool only returns healthy connections.
|
# The pool only returns healthy connections.
|
||||||
return
|
return
|
||||||
return super().close_if_health_check_failed()
|
return super().close_if_health_check_failed()
|
||||||
|
|
||||||
|
async def aclose_if_health_check_failed(self):
|
||||||
|
if self.apool:
|
||||||
|
# The pool only returns healthy connections.
|
||||||
|
return
|
||||||
|
return await super().aclose_if_health_check_failed()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _nodb_cursor(self):
|
def _nodb_cursor(self):
|
||||||
cursor = None
|
cursor = None
|
||||||
|
|
@ -549,8 +782,17 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def pg_version(self):
|
def pg_version(self):
|
||||||
with self.temporary_connection():
|
if self._pg_version is None:
|
||||||
return self.connection.info.server_version
|
with self.temporary_connection():
|
||||||
|
self._pg_version = self.connection.info.server_version
|
||||||
|
return self._pg_version
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
async def apg_version(self):
|
||||||
|
if self._pg_version is None:
|
||||||
|
async with self.atemporary_connection():
|
||||||
|
self._pg_version = self.aconnection.info.server_version
|
||||||
|
return self._pg_version
|
||||||
|
|
||||||
def make_debug_cursor(self, cursor):
|
def make_debug_cursor(self, cursor):
|
||||||
return CursorDebugWrapper(cursor, self)
|
return CursorDebugWrapper(cursor, self)
|
||||||
|
|
@ -607,6 +849,36 @@ if is_psycopg3:
|
||||||
with self.debug_sql(statement):
|
with self.debug_sql(statement):
|
||||||
return self.cursor.copy(statement)
|
return self.cursor.copy(statement)
|
||||||
|
|
||||||
|
class AsyncServerBindingCursor(CursorMixin, Database.AsyncClientCursor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AsyncCursor(CursorMixin, Database.AsyncClientCursor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AsyncServerSideCursor(
|
||||||
|
CursorMixin,
|
||||||
|
Database.client_cursor.ClientCursorMixin,
|
||||||
|
Database.AsyncServerCursor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
psycopg >= 3 forces the usage of server-side bindings when using named
|
||||||
|
cursors but the ORM doesn't yet support the systematic generation of
|
||||||
|
prepareable SQL (#20516).
|
||||||
|
|
||||||
|
ClientCursorMixin forces the usage of client-side bindings while
|
||||||
|
AsyncServerCursor implements the logic required to declare and scroll
|
||||||
|
through named cursors.
|
||||||
|
|
||||||
|
Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
|
||||||
|
specify how parameters should be bound instead, which AsyncServerCursor
|
||||||
|
would inherit, but that's not the case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class AsyncCursorDebugWrapper(AsyncBaseCursorDebugWrapper):
|
||||||
|
def copy(self, statement):
|
||||||
|
with self.debug_sql(statement):
|
||||||
|
return self.cursor.copy(statement)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
Cursor = psycopg2.extensions.cursor
|
Cursor = psycopg2.extensions.cursor
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
END;
|
END;
|
||||||
$$ LANGUAGE plpgsql;"""
|
$$ LANGUAGE plpgsql;"""
|
||||||
requires_casted_case_in_updates = True
|
requires_casted_case_in_updates = True
|
||||||
|
supports_async = is_psycopg3
|
||||||
supports_over_clause = True
|
supports_over_clause = True
|
||||||
supports_frame_exclusion = True
|
supports_frame_exclusion = True
|
||||||
only_supports_unbounded_with_preceding_and_following = True
|
only_supports_unbounded_with_preceding_and_following = True
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,98 @@ class CursorWrapper:
|
||||||
return self.cursor.executemany(sql, param_list)
|
return self.cursor.executemany(sql, param_list)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCursorCtx:
|
||||||
|
"""
|
||||||
|
Asynchronous context manager to hold an async cursor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db, name=None):
|
||||||
|
self.db = db
|
||||||
|
self.name = name
|
||||||
|
self.wrap_database_errors = self.db.wrap_database_errors
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
await self.db.aclose_if_health_check_failed()
|
||||||
|
await self.db.aensure_connection()
|
||||||
|
self.wrap_database_errors.__enter__()
|
||||||
|
return self.db._aprepare_cursor(self.db.create_async_cursor(self.name))
|
||||||
|
|
||||||
|
async def __aexit__(self, type, value, traceback):
|
||||||
|
self.wrap_database_errors.__exit__(type, value, traceback)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCursorWrapper(CursorWrapper):
|
||||||
|
async def _aexecute(self, sql, params, *ignored_wrapper_args):
|
||||||
|
# Raise a warning during app initialization (stored_app_configs is only
|
||||||
|
# ever set during testing).
|
||||||
|
if not apps.ready and not apps.stored_app_configs:
|
||||||
|
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
|
||||||
|
self.db.validate_no_broken_transaction()
|
||||||
|
with self.db.wrap_database_errors:
|
||||||
|
if params is None:
|
||||||
|
# params default might be backend specific.
|
||||||
|
return await self.cursor.execute(sql)
|
||||||
|
else:
|
||||||
|
return await self.cursor.execute(sql, params)
|
||||||
|
|
||||||
|
async def _aexecute_with_wrappers(self, sql, params, many, executor):
|
||||||
|
context = {"connection": self.db, "cursor": self}
|
||||||
|
for wrapper in reversed(self.db.execute_wrappers):
|
||||||
|
executor = functools.partial(wrapper, executor)
|
||||||
|
return await executor(sql, params, many, context)
|
||||||
|
|
||||||
|
async def aexecute(self, sql, params=None):
|
||||||
|
return await self._aexecute_with_wrappers(
|
||||||
|
sql, params, many=False, executor=self._aexecute
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _aexecutemany(self, sql, param_list, *ignored_wrapper_args):
|
||||||
|
# Raise a warning during app initialization (stored_app_configs is only
|
||||||
|
# ever set during testing).
|
||||||
|
if not apps.ready and not apps.stored_app_configs:
|
||||||
|
warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
|
||||||
|
self.db.validate_no_broken_transaction()
|
||||||
|
with self.db.wrap_database_errors:
|
||||||
|
return await self.cursor.executemany(sql, param_list)
|
||||||
|
|
||||||
|
async def aexecutemany(self, sql, param_list):
|
||||||
|
return await self._aexecute_with_wrappers(
|
||||||
|
sql, param_list, many=True, executor=self._aexecutemany
|
||||||
|
)
|
||||||
|
|
||||||
|
async def afetchone(self, *args, **kwargs):
|
||||||
|
return await self.cursor.fetchone(*args, **kwargs)
|
||||||
|
|
||||||
|
async def afetchmany(self, *args, **kwargs):
|
||||||
|
return await self.cursor.fetchmany(*args, **kwargs)
|
||||||
|
|
||||||
|
async def afetchall(self, *args, **kwargs):
|
||||||
|
return await self.cursor.fetchall(*args, **kwargs)
|
||||||
|
|
||||||
|
def acopy(self, *args, **kwargs):
|
||||||
|
return self.cursor.copy(*args, **kwargs)
|
||||||
|
|
||||||
|
def astream(self, *args, **kwargs):
|
||||||
|
return self.cursor.stream(*args, **kwargs)
|
||||||
|
|
||||||
|
async def ascroll(self, *args, **kwargs):
|
||||||
|
return await self.cursor.scroll(*args, **kwargs)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, type, value, traceback):
|
||||||
|
try:
|
||||||
|
await self.close()
|
||||||
|
except self.db.Database.Error:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aiter__(self):
|
||||||
|
with self.db.wrap_database_errors:
|
||||||
|
async for item in self.cursor:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
class CursorDebugWrapper(CursorWrapper):
|
class CursorDebugWrapper(CursorWrapper):
|
||||||
# XXX callproc isn't instrumented at this time.
|
# XXX callproc isn't instrumented at this time.
|
||||||
|
|
||||||
|
|
@ -163,6 +255,57 @@ class CursorDebugWrapper(CursorWrapper):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCursorDebugWrapper(AsyncCursorWrapper):
|
||||||
|
# XXX callproc isn't instrumented at this time.
|
||||||
|
|
||||||
|
async def aexecute(self, sql, params=None):
|
||||||
|
with self.debug_sql(sql, params, use_last_executed_query=True):
|
||||||
|
return await super().aexecute(sql, params)
|
||||||
|
|
||||||
|
async def aexecutemany(self, sql, param_list):
|
||||||
|
with self.debug_sql(sql, param_list, many=True):
|
||||||
|
return await super().aexecutemany(sql, param_list)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def debug_sql(
|
||||||
|
self, sql=None, params=None, use_last_executed_query=False, many=False
|
||||||
|
):
|
||||||
|
start = time.monotonic()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
stop = time.monotonic()
|
||||||
|
duration = stop - start
|
||||||
|
if use_last_executed_query:
|
||||||
|
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
|
||||||
|
try:
|
||||||
|
times = len(params) if many else ""
|
||||||
|
except TypeError:
|
||||||
|
# params could be an iterator.
|
||||||
|
times = "?"
|
||||||
|
self.db.queries_log.append(
|
||||||
|
{
|
||||||
|
"sql": "%s times: %s" % (times, sql) if many else sql,
|
||||||
|
"time": "%.3f" % duration,
|
||||||
|
"async": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"(%.3f) %s; args=%s; alias=%s; async=True",
|
||||||
|
duration,
|
||||||
|
sql,
|
||||||
|
params,
|
||||||
|
self.db.alias,
|
||||||
|
extra={
|
||||||
|
"duration": duration,
|
||||||
|
"sql": sql,
|
||||||
|
"params": params,
|
||||||
|
"alias": self.db.alias,
|
||||||
|
"async": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def debug_transaction(connection, sql):
|
def debug_transaction(connection, sql):
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
|
|
@ -176,18 +319,21 @@ def debug_transaction(connection, sql):
|
||||||
{
|
{
|
||||||
"sql": "%s" % sql,
|
"sql": "%s" % sql,
|
||||||
"time": "%.3f" % duration,
|
"time": "%.3f" % duration,
|
||||||
|
"async": connection.features.supports_async,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"(%.3f) %s; args=%s; alias=%s",
|
"(%.3f) %s; args=%s; alias=%s; async=%s",
|
||||||
duration,
|
duration,
|
||||||
sql,
|
sql,
|
||||||
None,
|
None,
|
||||||
connection.alias,
|
connection.alias,
|
||||||
|
connection.features.supports_async,
|
||||||
extra={
|
extra={
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
"sql": sql,
|
"sql": sql,
|
||||||
"alias": connection.alias,
|
"alias": connection.alias,
|
||||||
|
"async": connection.features.supports_async,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
from asgiref.local import Local
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
|
@ -197,6 +199,89 @@ class ConnectionHandler(BaseConnectionHandler):
|
||||||
return backend.DatabaseWrapper(db, alias)
|
return backend.DatabaseWrapper(db, alias)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAlias:
|
||||||
|
"""
|
||||||
|
A Context-aware list of connections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._connections = Local()
|
||||||
|
setattr(self._connections, "_stack", [])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connections(self):
|
||||||
|
return getattr(self._connections, "_stack", [])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.connections)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.connections)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return ", ".join([str(id(conn)) for conn in self.connections])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<{self.__class__.__name__}: {len(self.connections)} connections>"
|
||||||
|
|
||||||
|
def add_connection(self, connection):
|
||||||
|
setattr(self._connections, "_stack", self.connections + [connection])
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
conns = self.connections
|
||||||
|
conns.pop()
|
||||||
|
setattr(self._connections, "_stack", conns)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncConnectionHandler:
|
||||||
|
"""
|
||||||
|
Context-aware class to store async connections, mapped by alias name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_from_testcase = False
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._aliases = Local()
|
||||||
|
self._connection_count = Local()
|
||||||
|
setattr(self._connection_count, "value", 0)
|
||||||
|
|
||||||
|
def __getitem__(self, alias):
|
||||||
|
try:
|
||||||
|
async_alias = getattr(self._aliases, alias)
|
||||||
|
except AttributeError:
|
||||||
|
async_alias = AsyncAlias()
|
||||||
|
setattr(self._aliases, alias, async_alias)
|
||||||
|
return async_alias
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}: {self.count} connections>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self):
|
||||||
|
return getattr(self._connection_count, "value", 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def empty(self):
|
||||||
|
return self.count == 0
|
||||||
|
|
||||||
|
def add_connection(self, using, connection):
|
||||||
|
self[using].add_connection(connection)
|
||||||
|
setattr(self._connection_count, "value", self.count + 1)
|
||||||
|
|
||||||
|
async def pop_connection(self, using):
|
||||||
|
await self[using].connections[-1].aclose_pool()
|
||||||
|
self[using].connections.pop()
|
||||||
|
setattr(self._connection_count, "value", self.count - 1)
|
||||||
|
|
||||||
|
def get_connection(self, using):
|
||||||
|
alias = self[using]
|
||||||
|
if len(alias.connections) == 0:
|
||||||
|
raise ConnectionDoesNotExist(
|
||||||
|
f"There are no async connections using the '{using}' alias."
|
||||||
|
)
|
||||||
|
return alias.connections[-1]
|
||||||
|
|
||||||
|
|
||||||
class ConnectionRouter:
|
class ConnectionRouter:
|
||||||
def __init__(self, routers=None):
|
def __init__(self, routers=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,13 @@ from django.core.management.color import no_style
|
||||||
from django.core.management.sql import emit_post_migrate_signal
|
from django.core.management.sql import emit_post_migrate_signal
|
||||||
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
|
from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler
|
||||||
from django.core.signals import setting_changed
|
from django.core.signals import setting_changed
|
||||||
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
|
from django.db import (
|
||||||
|
DEFAULT_DB_ALIAS,
|
||||||
|
async_connections,
|
||||||
|
connection,
|
||||||
|
connections,
|
||||||
|
transaction,
|
||||||
|
)
|
||||||
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
|
||||||
from django.forms.fields import CharField
|
from django.forms.fields import CharField
|
||||||
from django.http import QueryDict
|
from django.http import QueryDict
|
||||||
|
|
@ -1415,6 +1421,7 @@ class TestCase(TransactionTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super().setUpClass()
|
super().setUpClass()
|
||||||
|
async_connections._from_testcase = True
|
||||||
if not (
|
if not (
|
||||||
cls._databases_support_transactions()
|
cls._databases_support_transactions()
|
||||||
and cls._databases_support_savepoints()
|
and cls._databases_support_savepoints()
|
||||||
|
|
|
||||||
|
|
@ -211,7 +211,6 @@ Database backends
|
||||||
* MySQL connections now default to using the ``utf8mb4`` character set,
|
* MySQL connections now default to using the ``utf8mb4`` character set,
|
||||||
instead of ``utf8``, which is an alias for the deprecated character set
|
instead of ``utf8``, which is an alias for the deprecated character set
|
||||||
``utf8mb3``.
|
``utf8mb3``.
|
||||||
|
|
||||||
* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
|
* Oracle backends now support :ref:`connection pools <oracle-pool>`, by setting
|
||||||
``"pool"`` in the :setting:`OPTIONS` part of your database configuration.
|
``"pool"`` in the :setting:`OPTIONS` part of your database configuration.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,19 @@ instance of those now-deprecated classes.
|
||||||
Minor features
|
Minor features
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
Database backends
|
||||||
|
~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
* It is now possible to perform asynchronous raw SQL queries using an async
|
||||||
|
cursor.
|
||||||
|
This is only possible on backends that support async-native connections.
|
||||||
|
Currently only supported in PostreSQL with the
|
||||||
|
``django.db.backends.postgresql`` backend.
|
||||||
|
* It is now possible to perform asynchronous raw SQL queries using an async
|
||||||
|
cursor, if the backend supports async-native connections. This is only
|
||||||
|
supported on PostgreSQL with ``psycopg`` 3.1.8+. See
|
||||||
|
:ref:`async-connection-cursor` for more details.
|
||||||
|
|
||||||
:mod:`django.contrib.admin`
|
:mod:`django.contrib.admin`
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -404,6 +404,37 @@ is equivalent to::
|
||||||
finally:
|
finally:
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
|
.. _async-connection-cursor:
|
||||||
|
|
||||||
|
Async Connections and cursors
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. versionadded:: 6.0
|
||||||
|
|
||||||
|
On backends that support async-native connections, you can request an async
|
||||||
|
cursor::
|
||||||
|
|
||||||
|
from django.db import new_connection
|
||||||
|
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with connection.acursor() as c:
|
||||||
|
await c.aexecute(...)
|
||||||
|
|
||||||
|
Async cursors provide the following methods:
|
||||||
|
|
||||||
|
* ``.aexecute()``
|
||||||
|
* ``.aexecutemany()``
|
||||||
|
* ``.afetchone()``
|
||||||
|
* ``.afetchmany()``
|
||||||
|
* ``.afetchall()``
|
||||||
|
* ``.acopy()``
|
||||||
|
* ``.astream()``
|
||||||
|
* ``.ascroll()``
|
||||||
|
|
||||||
|
Currently, Django ships with the following async-enabled backend:
|
||||||
|
|
||||||
|
* ``django.db.backends.postgresql`` with ``psycopg3``.
|
||||||
|
|
||||||
Calling stored procedures
|
Calling stored procedures
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
||||||
26
tests/async/test_async_connections.py
Normal file
26
tests/async/test_async_connections.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
from django.db import new_connection
|
||||||
|
from django.test import TransactionTestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
|
from .models import SimpleModel
|
||||||
|
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
class AsyncSyncCominglingTest(TransactionTestCase):
|
||||||
|
|
||||||
|
available_apps = ["async"]
|
||||||
|
|
||||||
|
async def change_model_with_async(self, obj):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"""UPDATE "async_simplemodel" SET "field" = 10,"""
|
||||||
|
""" "created" = '2024-11-19 14:12:50.606384'::timestamp"""
|
||||||
|
""" WHERE "async_simplemodel"."id" = 1"""
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_transaction_async_comingling(self):
|
||||||
|
s1 = await sync_to_async(SimpleModel.objects.create)(field=0)
|
||||||
|
# with transaction.atomic():
|
||||||
|
await self.change_model_with_async(s1)
|
||||||
133
tests/async/test_async_cursor.py
Normal file
133
tests/async/test_async_cursor.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
from django.db import new_connection
|
||||||
|
from django.test import SimpleTestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
class AsyncCursorTests(SimpleTestCase):
|
||||||
|
databases = {"default", "other"}
|
||||||
|
|
||||||
|
async def test_aexecute(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute("SELECT 1")
|
||||||
|
|
||||||
|
async def test_aexecutemany(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute("CREATE TABLE numbers (number SMALLINT)")
|
||||||
|
await cursor.aexecutemany(
|
||||||
|
"INSERT INTO numbers VALUES (%s)", [(1,), (2,), (3,)]
|
||||||
|
)
|
||||||
|
await cursor.aexecute("SELECT * FROM numbers")
|
||||||
|
result = await cursor.afetchall()
|
||||||
|
self.assertEqual(result, [(1,), (2,), (3,)])
|
||||||
|
|
||||||
|
async def test_afetchone(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute("SELECT 1")
|
||||||
|
result = await cursor.afetchone()
|
||||||
|
self.assertEqual(result, (1,))
|
||||||
|
|
||||||
|
async def test_afetchmany(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)"""
|
||||||
|
)
|
||||||
|
result = await cursor.afetchmany(size=2)
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",)])
|
||||||
|
|
||||||
|
async def test_afetchall(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)"""
|
||||||
|
)
|
||||||
|
result = await cursor.afetchall()
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||||
|
|
||||||
|
async def test_aiter(self):
|
||||||
|
result = []
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)"""
|
||||||
|
)
|
||||||
|
async for record in cursor:
|
||||||
|
result.append(record)
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||||
|
|
||||||
|
async def test_acopy(self):
|
||||||
|
result = []
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
async with cursor.acopy(
|
||||||
|
"""
|
||||||
|
COPY (
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)
|
||||||
|
) TO STDOUT"""
|
||||||
|
) as copy:
|
||||||
|
async for row in copy.rows():
|
||||||
|
result.append(row)
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||||
|
|
||||||
|
async def test_astream(self):
|
||||||
|
result = []
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
async for record in cursor.astream(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)"""
|
||||||
|
):
|
||||||
|
result.append(record)
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||||
|
|
||||||
|
async def test_ascroll(self):
|
||||||
|
result = []
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM (VALUES
|
||||||
|
('BANANA'),
|
||||||
|
('STRAWBERRY'),
|
||||||
|
('MELON')
|
||||||
|
) AS v (NAME)"""
|
||||||
|
)
|
||||||
|
await cursor.ascroll(1, "absolute")
|
||||||
|
|
||||||
|
result = await cursor.afetchall()
|
||||||
|
self.assertEqual(result, [("STRAWBERRY",), ("MELON",)])
|
||||||
|
await cursor.ascroll(0, "absolute")
|
||||||
|
result = await cursor.afetchall()
|
||||||
|
self.assertEqual(result, [("BANANA",), ("STRAWBERRY",), ("MELON",)])
|
||||||
21
tests/backends/base/test_base_async.py
Normal file
21
tests/backends/base/test_base_async.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from django.db import new_connection
|
||||||
|
from django.test import SimpleTestCase, skipUnlessDBFeature
|
||||||
|
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
class AsyncDatabaseWrapperTests(SimpleTestCase):
|
||||||
|
databases = {"default", "other"}
|
||||||
|
|
||||||
|
async def test_async_cursor(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.execute("SELECT 1")
|
||||||
|
result = (await cursor.fetchone())[0]
|
||||||
|
self.assertEqual(result, 1)
|
||||||
|
|
||||||
|
async def test_async_cursor_alias(self):
|
||||||
|
async with new_connection() as conn:
|
||||||
|
async with conn.acursor() as cursor:
|
||||||
|
await cursor.aexecute("SELECT 1")
|
||||||
|
result = (await cursor.afetchone())[0]
|
||||||
|
self.assertEqual(result, 1)
|
||||||
|
|
@ -1,11 +1,26 @@
|
||||||
"""Tests for django.db.utils."""
|
"""Tests for django.db.utils."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection
|
from django.db import (
|
||||||
from django.db.utils import ConnectionHandler, load_backend
|
DEFAULT_DB_ALIAS,
|
||||||
from django.test import SimpleTestCase, TestCase
|
NotSupportedError,
|
||||||
|
ProgrammingError,
|
||||||
|
async_connections,
|
||||||
|
connection,
|
||||||
|
new_connection,
|
||||||
|
)
|
||||||
|
from django.db.utils import (
|
||||||
|
AsyncAlias,
|
||||||
|
AsyncConnectionHandler,
|
||||||
|
ConnectionHandler,
|
||||||
|
load_backend,
|
||||||
|
)
|
||||||
|
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||||
from django.utils.connection import ConnectionDoesNotExist
|
from django.utils.connection import ConnectionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,3 +105,82 @@ class LoadBackendTests(SimpleTestCase):
|
||||||
with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm:
|
with self.assertRaisesMessage(ImproperlyConfigured, msg) as cm:
|
||||||
load_backend("foo")
|
load_backend("foo")
|
||||||
self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'")
|
self.assertEqual(str(cm.exception.__cause__), "No module named 'foo'")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncConnectionTests(SimpleTestCase):
|
||||||
|
databases = {"default", "other"}
|
||||||
|
|
||||||
|
def run_pool(self, coro, count=2):
|
||||||
|
def fn():
|
||||||
|
asyncio.run(coro())
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
futures = []
|
||||||
|
for _ in range(count):
|
||||||
|
futures.append(executor.submit(fn))
|
||||||
|
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
exc = future.exception()
|
||||||
|
if exc is not None:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
def test_async_alias(self):
|
||||||
|
alias = AsyncAlias()
|
||||||
|
assert len(alias) == 0
|
||||||
|
assert alias.connections == []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
assert len(alias) == 0
|
||||||
|
alias.add_connection(mock.Mock())
|
||||||
|
alias.pop()
|
||||||
|
|
||||||
|
self.run_pool(coro)
|
||||||
|
|
||||||
|
def test_async_connection_handler(self):
|
||||||
|
aconns = AsyncConnectionHandler()
|
||||||
|
assert aconns.empty is True
|
||||||
|
assert aconns["default"].connections == []
|
||||||
|
|
||||||
|
async def coro():
|
||||||
|
assert aconns["default"].connections == []
|
||||||
|
aconns.add_connection("default", mock.Mock())
|
||||||
|
aconns.pop_connection("default")
|
||||||
|
|
||||||
|
self.run_pool(coro)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
def test_new_connection_threading(self):
|
||||||
|
async def coro():
|
||||||
|
assert async_connections.empty is True
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with connection.acursor() as c:
|
||||||
|
await c.execute("SELECT 1")
|
||||||
|
|
||||||
|
self.run_pool(coro)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
async def test_new_connection(self):
|
||||||
|
with self.assertRaises(ConnectionDoesNotExist):
|
||||||
|
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||||
|
|
||||||
|
async with new_connection():
|
||||||
|
conn1 = async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||||
|
self.assertIsNotNone(conn1.aconnection)
|
||||||
|
async with new_connection():
|
||||||
|
conn2 = async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||||
|
self.assertIsNotNone(conn1.aconnection)
|
||||||
|
self.assertIsNotNone(conn2.aconnection)
|
||||||
|
self.assertNotEqual(conn1.aconnection, conn2.aconnection)
|
||||||
|
|
||||||
|
self.assertIsNotNone(conn1.aconnection)
|
||||||
|
self.assertIsNone(conn2.aconnection)
|
||||||
|
self.assertIsNone(conn1.aconnection)
|
||||||
|
|
||||||
|
with self.assertRaises(ConnectionDoesNotExist):
|
||||||
|
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_async")
|
||||||
|
async def test_new_connection_on_sync(self):
|
||||||
|
with self.assertRaises(NotSupportedError):
|
||||||
|
async with new_connection():
|
||||||
|
async_connections.get_connection(DEFAULT_DB_ALIAS)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from django.db import (
|
||||||
IntegrityError,
|
IntegrityError,
|
||||||
OperationalError,
|
OperationalError,
|
||||||
connection,
|
connection,
|
||||||
|
new_connection,
|
||||||
transaction,
|
transaction,
|
||||||
)
|
)
|
||||||
from django.test import (
|
from django.test import (
|
||||||
|
|
@ -586,3 +587,93 @@ class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
|
||||||
|
|
||||||
class DurableTests(DurableTestsBase, TestCase):
|
class DurableTests(DurableTestsBase, TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("uses_savepoints", "supports_async")
|
||||||
|
class AsyncTransactionTestCase(TransactionTestCase):
|
||||||
|
available_apps = ["transactions"]
|
||||||
|
|
||||||
|
async def test_new_connection_nested(self):
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with new_connection() as connection2:
|
||||||
|
await connection2.aset_autocommit(False)
|
||||||
|
async with connection2.acursor() as cursor2:
|
||||||
|
await cursor2.aexecute(
|
||||||
|
"INSERT INTO transactions_reporter "
|
||||||
|
"(first_name, last_name, email) "
|
||||||
|
"VALUES (%s, %s, %s)",
|
||||||
|
("Sarah", "Hatoff", ""),
|
||||||
|
)
|
||||||
|
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor2.afetchmany()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
async with connection.acursor() as cursor:
|
||||||
|
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor.afetchmany()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
async def test_new_connection_nested2(self):
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with connection.acursor() as cursor:
|
||||||
|
await cursor.aexecute(
|
||||||
|
"INSERT INTO transactions_reporter (first_name, last_name, email) "
|
||||||
|
"VALUES (%s, %s, %s)",
|
||||||
|
("Sarah", "Hatoff", ""),
|
||||||
|
)
|
||||||
|
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor.afetchmany()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
async with new_connection() as connection2:
|
||||||
|
await connection2.aset_autocommit(False)
|
||||||
|
async with connection2.acursor() as cursor2:
|
||||||
|
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor2.afetchmany()
|
||||||
|
# This connection won't see any rows, because the outer one
|
||||||
|
# hasn't committed yet.
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
async def test_new_connection_nested3(self):
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with new_connection() as connection2:
|
||||||
|
await connection2.aset_autocommit(False)
|
||||||
|
assert id(connection) != id(connection2)
|
||||||
|
async with connection2.acursor() as cursor2:
|
||||||
|
await cursor2.aexecute(
|
||||||
|
"INSERT INTO transactions_reporter "
|
||||||
|
"(first_name, last_name, email) "
|
||||||
|
"VALUES (%s, %s, %s)",
|
||||||
|
("Sarah", "Hatoff", ""),
|
||||||
|
)
|
||||||
|
await cursor2.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor2.afetchmany()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
# Outermost connection doesn't see what the innermost did,
|
||||||
|
# because the innermost connection hasn't exited yet.
|
||||||
|
async with connection.acursor() as cursor:
|
||||||
|
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor.afetchmany()
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
async def test_asavepoint(self):
|
||||||
|
async with new_connection() as connection:
|
||||||
|
async with connection.acursor() as cursor:
|
||||||
|
sid = await connection.asavepoint()
|
||||||
|
assert sid is not None
|
||||||
|
|
||||||
|
await cursor.aexecute(
|
||||||
|
"INSERT INTO transactions_reporter (first_name, last_name, email) "
|
||||||
|
"VALUES (%s, %s, %s)",
|
||||||
|
("Archibald", "Haddock", ""),
|
||||||
|
)
|
||||||
|
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor.afetchmany(size=5)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][1:] == ("Archibald", "Haddock", "")
|
||||||
|
|
||||||
|
await connection.asavepoint_rollback(sid)
|
||||||
|
await cursor.aexecute("SELECT * FROM transactions_reporter")
|
||||||
|
result = await cursor.fetchmany(size=5)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue