diff --git a/django/db/__init__.py b/django/db/__init__.py index 1f6123e3a4..c8ab060027 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -51,7 +51,7 @@ class new_connection: async def __aenter__(self): conn = connections.create_connection(self.using) - if conn.supports_async is False: + if conn.features.supports_async is False: raise NotSupportedError( "The database backend does not support asynchronous execution." ) diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index f54da845d4..cf232c6dc7 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -39,7 +39,6 @@ class BaseDatabaseWrapper: ops = None vendor = "unknown" display_name = "unknown" - supports_async = False SchemaEditorClass = None # Classes instantiated in __init__(). diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 0c79e5c133..2734d66999 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -358,6 +358,9 @@ class BaseDatabaseFeatures: # Does the backend support negative JSON array indexing? supports_json_negative_indexing = True + # Asynchronous database operations + supports_async = False + # Does the backend support column collations? supports_collation_on_charfield = True supports_collation_on_textfield = True diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 4743877837..25cd07ac36 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -94,7 +94,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): vendor = "postgresql" display_name = "PostgreSQL" _pg_version = None - supports_async = is_psycopg3 # This dictionary maps Field objects to their associated PostgreSQL column # types, as strings. Column-type strings can contain format strings; diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 5f63b6c713..c21ce7a582 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -54,6 +54,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): END; $$ LANGUAGE plpgsql;""" requires_casted_case_in_updates = True + supports_async = is_psycopg3 supports_over_clause = True supports_frame_exclusion = True only_supports_unbounded_with_preceding_and_following = True diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index b4d956dfe9..e259d54e48 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -319,7 +319,7 @@ def debug_transaction(connection, sql): { "sql": "%s" % sql, "time": "%.3f" % duration, - "async": connection.supports_async, + "async": connection.features.supports_async, } ) logger.debug( @@ -328,12 +328,12 @@ def debug_transaction(connection, sql): sql, None, connection.alias, - connection.supports_async, + connection.features.supports_async, extra={ "duration": duration, "sql": sql, "alias": connection.alias, - "async": connection.supports_async, + "async": connection.features.supports_async, }, ) diff --git a/tests/async/test_async_connections.py b/tests/async/test_async_connections.py index 625e266779..c28745c1dd 100644 --- a/tests/async/test_async_connections.py +++ b/tests/async/test_async_connections.py @@ -1,14 +1,12 @@ -import unittest - from asgiref.sync import sync_to_async -from django.db import connection, new_connection -from django.test import TransactionTestCase +from django.db import new_connection +from django.test import TransactionTestCase, skipUnlessDBFeature from .models import SimpleModel -@unittest.skipUnless(connection.supports_async is True, "Async DB test") +@skipUnlessDBFeature("supports_async") class AsyncSyncCominglingTest(TransactionTestCase): available_apps = ["async"] diff --git a/tests/async/test_async_cursor.py b/tests/async/test_async_cursor.py index 7b11864cf9..b1beea8568 100644 --- a/tests/async/test_async_cursor.py +++ b/tests/async/test_async_cursor.py @@ -1,11 +1,11 @@ -import unittest - -from django.db import connection, new_connection -from django.test import SimpleTestCase +from django.db import new_connection +from django.test import SimpleTestCase, skipUnlessDBFeature -@unittest.skipUnless(connection.supports_async is True, "Async DB test") +@skipUnlessDBFeature("supports_async") class AsyncCursorTests(SimpleTestCase): + databases = {"default", "other"} + async def test_aexecute(self): async with new_connection() as conn: async with conn.acursor() as cursor: diff --git a/tests/backends/base/test_base_async.py b/tests/backends/base/test_base_async.py index 5977776957..633c88aa69 100644 --- a/tests/backends/base/test_base_async.py +++ b/tests/backends/base/test_base_async.py @@ -1,11 +1,11 @@ -import unittest - -from django.db import connection, new_connection -from django.test import SimpleTestCase +from django.db import new_connection +from django.test import SimpleTestCase, skipUnlessDBFeature -@unittest.skipUnless(connection.supports_async is True, "Async DB test") +@skipUnlessDBFeature("supports_async") class AsyncDatabaseWrapperTests(SimpleTestCase): + databases = {"default", "other"} + async def test_async_cursor(self): async with new_connection() as conn: async with conn.acursor() as cursor: diff --git a/tests/db_utils/tests.py b/tests/db_utils/tests.py index 9f01dc1a40..aa10286904 100644 --- a/tests/db_utils/tests.py +++ b/tests/db_utils/tests.py @@ -20,7 +20,7 @@ from django.db.utils import ( ConnectionHandler, load_backend, ) -from django.test import SimpleTestCase, TestCase +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.utils.connection import ConnectionDoesNotExist @@ -108,6 +108,8 @@ class LoadBackendTests(SimpleTestCase): class AsyncConnectionTests(SimpleTestCase): + databases = {"default", "other"} + def run_pool(self, coro, count=2): def fn(): asyncio.run(coro()) @@ -146,7 +148,7 @@ class AsyncConnectionTests(SimpleTestCase): self.run_pool(coro) - @unittest.skipUnless(connection.supports_async is True, "Async DB test") + @skipUnlessDBFeature("supports_async") def test_new_connection_threading(self): async def coro(): assert async_connections.empty is True @@ -156,7 +158,7 @@ class AsyncConnectionTests(SimpleTestCase): self.run_pool(coro) - @unittest.skipUnless(connection.supports_async is True, "Async DB test") + @skipUnlessDBFeature("supports_async") async def test_new_connection(self): with self.assertRaises(ConnectionDoesNotExist): async_connections.get_connection(DEFAULT_DB_ALIAS) @@ -177,7 +179,7 @@ class AsyncConnectionTests(SimpleTestCase): with self.assertRaises(ConnectionDoesNotExist): async_connections.get_connection(DEFAULT_DB_ALIAS) - @unittest.skipUnless(connection.supports_async is False, "Sync DB test") + @skipUnlessDBFeature("supports_async") async def test_new_connection_on_sync(self): with self.assertRaises(NotSupportedError): async with new_connection(): diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py index fa31267857..e80cb36eb1 100644 --- a/tests/transactions/tests.py +++ b/tests/transactions/tests.py @@ -589,8 +589,7 @@ class DurableTests(DurableTestsBase, TestCase): pass -@skipUnlessDBFeature("uses_savepoints") -@skipUnless(connection.supports_async is True, "Async DB test") +@skipUnlessDBFeature("uses_savepoints", "supports_async") class AsyncTransactionTestCase(TransactionTestCase): available_apps = ["transactions"]