diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 2ada5177be..c3faa920fe 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -339,6 +339,8 @@ class BaseDatabaseFeatures: # Does the backend support JSONField? supports_json_field = True + # Does the backend implement support for ABSENT ON NULL clause? + supports_json_absent_on_null = True # Can the backend introspect a JSONField? can_introspect_json_field = True # Does the backend support primitives in JSONField? diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index b289d0af2f..e0ffc04b77 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -60,6 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_virtual_generated_columns = True supports_json_negative_indexing = False + supports_json_absent_on_null = False @cached_property def minimum_database_version(self): diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 4fa6ab831b..b17ad9b513 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -36,6 +36,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_aggregate_distinct_multiple_argument = False supports_any_value = True order_by_nulls_first = True + supports_json_absent_on_null = False supports_json_field_contains = False supports_update_conflicts = True supports_update_conflicts_with_target = True diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 89ad23909d..421bf93d9d 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -20,6 +20,7 @@ from django.db.models.functions.mixins import ( FixDurationInputMixin, NumericOutputFieldMixin, ) +from django.db.models.lookups import IsNull __all__ = [ "Aggregate", @@ -407,11 +408,20 @@ class JSONArrayAgg(Aggregate): allow_order_by = True arity = 1 + def __init__(self, *expressions, absent_on_null=False, **extra): + self.absent_on_null = absent_on_null + super().__init__(*expressions, **extra) + def as_sql(self, compiler, connection, **extra_context): if self.filter and not connection.features.supports_aggregate_filter_clause: raise NotSupportedError( "JSONArrayAgg(filter) is not supported on this database backend." ) + if self.absent_on_null and not connection.features.supports_json_absent_on_null: + raise NotSupportedError( + "JSONArrayAgg(absent_on_null) is not supported on this database " + "backend." + ) return super().as_sql(compiler, connection, **extra_context) def as_mysql(self, compiler, connection, **extra_context): @@ -444,21 +454,60 @@ class JSONArrayAgg(Aggregate): sql = f"(CASE WHEN {count_sql} > 0 THEN {sql}{default_sql} END)" return sql, count_params + params + default_params + def as_native(self, compiler, connection, *, returning=None, **extra_context): + # Oracle and PostgreSQL 16+ default to removing SQL null values from + # the returned array. This adds the NULL ON NULL clause to preserve + # the null values in the array as default behaviour Similar to that + # of SQLite, also removes the null values from the array when + # specified via ABSENT ON NULL. + if len(self.get_source_expressions()) == 0: + on_null_clause = "" + elif self.absent_on_null: + on_null_clause = "ABSENT ON NULL" + else: + on_null_clause = "NULL ON NULL" + if returning: + extra_context.setdefault( + "template", + "%(function)s(%(distinct)s%(expressions)s%(order_by)s " + f"{on_null_clause} RETURNING {returning}) %(filter)s", + ) + else: + extra_context.setdefault( + "template", + "%(function)s(%(distinct)s%(expressions)s%(order_by)s " + f"{on_null_clause}) %(filter)s", + ) + return self.as_sql(compiler, connection, **extra_context) + def as_postgresql(self, compiler, connection, **extra_context): if not connection.features.is_postgresql_16: - sql, params = super().as_sql( + sql, params = self.as_sql( compiler, connection, function="ARRAY_AGG", **extra_context, ) - return f"TO_JSONB({sql})", params - extra_context.setdefault( - "template", - "%(function)s(%(distinct)s%(expressions)s%(order_by)s RETURNING JSONB)\ - %(filter)s", - ) - return self.as_sql(compiler, connection, **extra_context) + # Use a filter to cleanly remove null values from the array to + # match the behaviour of ABSENT ON NULL on Oracle and + # PostgreSQL 16+. + if self.absent_on_null: + expression = self.get_source_expressions()[0] + if self.filter: + not_null_condition = IsNull(expression, False) + copy = self.copy() + copy.filter.source_expressions[0].children += [not_null_condition] + sql, params = copy.as_sql( + compiler, connection, function="ARRAY_AGG", **extra_context + ) + return f"TO_JSONB({sql})", params + else: + expr, _ = compiler.compile(expression) + filter_sql = f"FILTER (WHERE {expr} IS NOT NULL)" + return f"TO_JSONB({sql} {filter_sql})", params + else: + return f"TO_JSONB({sql})", params + return self.as_native(compiler, connection, returning="JSONB", **extra_context) def as_oracle(self, compiler, connection, **extra_context): # Oracle turns DATE columns into ISO 8601 timestamp including T00:00:00 @@ -474,5 +523,5 @@ class JSONArrayAgg(Aggregate): *source_expressions[1:], ] ) - return clone.as_sql(compiler, connection, **extra_context) - return self.as_sql(compiler, connection, **extra_context) + return clone.as_native(compiler, connection, **extra_context) + return self.as_native(compiler, connection, **extra_context) diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index bbe4fef8ee..2b735ab2b7 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -2737,7 +2737,7 @@ class AggregateAnnotationPruningTests(TestCase): class JSONArrayAggTests(TestCase): @classmethod def setUpTestData(cls): - cls.a1 = Author.objects.create(name="Adrian Holovaty", age=34) + cls.a1 = Author.objects.create(name="Adrian Holovaty", age=34, rating=1.5) cls.a2 = Author.objects.create(name="Jacob Kaplan-Moss", age=45) cls.a3 = Author.objects.create(name="Brad Dayley", age=40) cls.p1 = Publisher.objects.create(num_awards=3) @@ -2799,6 +2799,17 @@ class JSONArrayAggTests(TestCase): vals = Author.objects.aggregate(jsonarrayagg=JSONArrayAgg("book__pages")) self.assertEqual(vals, {"jsonarrayagg": [447, 528, 300]}) + def test_null_on_null(self): + vals = Author.objects.aggregate(jsonarrayagg=JSONArrayAgg("rating")) + self.assertEqual(vals, {"jsonarrayagg": [1.5, None, None]}) + + @skipUnlessDBFeature("supports_json_absent_on_null") + def test_absent_on_null(self): + vals = Author.objects.aggregate( + jsonarrayagg=JSONArrayAgg("rating", absent_on_null=True) + ) + self.assertEqual(vals, {"jsonarrayagg": [1.5]}) + @skipUnlessDBFeature("supports_aggregate_filter_clause") def test_filter(self): vals = Book.objects.aggregate( @@ -2842,6 +2853,14 @@ class JSONArrayAggTests(TestCase): with self.assertRaisesMessage(NotSupportedError, msg): Author.objects.aggregate(arrayagg=JSONArrayAgg("age", order_by="-name")) + @skipIfDBFeature("supports_json_absent_on_null") + def test_absent_on_null_not_supported(self): + msg = "JSONArrayAgg(absent_on_null) is not supported on this database backend." + with self.assertRaisesMessage(NotSupportedError, msg): + Author.objects.aggregate( + arrayagg=JSONArrayAgg("rating", absent_on_null=True) + ) + def test_distinct_true(self): msg = "JSONArrayAgg does not allow distinct." with self.assertRaisesMessage(TypeError, msg):