Fixed #10929 -- Added default argument to aggregates.

Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
Nick Pope 2021-02-21 01:38:55 +00:00 committed by Mariusz Felisiak
parent 59942a66ce
commit 501a8db465
11 changed files with 393 additions and 64 deletions

View file

@ -1,15 +1,19 @@
import datetime
import math
import re
from decimal import Decimal
from django.core.exceptions import FieldError
from django.db import connection
from django.db.models import (
Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
Avg, Case, Count, DateField, DateTimeField, DecimalField, DurationField,
Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Q, StdDev,
Subquery, Sum, TimeField, Value, Variance, When,
)
from django.db.models.expressions import Func, RawSQL
from django.db.models.functions import Coalesce, Greatest
from django.db.models.functions import (
Cast, Coalesce, Greatest, Now, Pi, TruncDate, TruncHour,
)
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext
@ -18,6 +22,20 @@ from django.utils import timezone
from .models import Author, Book, Publisher, Store
class NowUTC(Now):
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_mysql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template='UTC_TIMESTAMP', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="CURRENT_TIMESTAMP AT TIME ZONE 'UTC'", **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'", **extra_context)
class AggregateTestCase(TestCase):
@classmethod
@ -1402,3 +1420,190 @@ class AggregateTestCase(TestCase):
)['latest_opening'],
datetime.datetime,
)
def test_aggregation_default_unsupported_by_count(self):
msg = 'Count does not allow default.'
with self.assertRaisesMessage(TypeError, msg):
Count('age', default=0)
def test_aggregation_default_unset(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age'),
)
self.assertIsNone(result['value'])
def test_aggregation_default_zero(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=0),
)
self.assertEqual(result['value'], 0)
def test_aggregation_default_integer(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=21),
)
self.assertEqual(result['value'], 21)
def test_aggregation_default_expression(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=Value(5) * Value(7)),
)
self.assertEqual(result['value'], 35)
def test_aggregation_default_group_by(self):
qs = Publisher.objects.values('name').annotate(
books=Count('book'),
pages=Sum('book__pages', default=0),
).filter(books=0)
self.assertSequenceEqual(
qs,
[{'name': "Jonno's House of Books", 'books': 0, 'pages': 0}],
)
def test_aggregation_default_compound_expression(self):
# Scale rating to a percentage; default to 50% if no books published.
formula = Avg('book__rating', default=2.5) * 20.0
queryset = Publisher.objects.annotate(rating=formula).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'rating'), [
{'name': 'Apress', 'rating': 85.0},
{'name': "Jonno's House of Books", 'rating': 50.0},
{'name': 'Morgan Kaufmann', 'rating': 100.0},
{'name': 'Prentice Hall', 'rating': 80.0},
{'name': 'Sams', 'rating': 60.0},
])
def test_aggregation_default_using_time_from_python(self):
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=datetime.time(17),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, TimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(17)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_time_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=TimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(now.hour)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_date_from_python(self):
expr = Min('book__pubdate', default=datetime.date(1970, 1, 1))
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 5.7+ & MariaDB.
expr.default = Cast(expr.default, DateField())
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': datetime.date(1970, 1, 1)},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_date_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min('book__pubdate', default=TruncDate(NowUTC()))
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': now.date()},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_datetime_from_python(self):
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=datetime.datetime(1970, 1, 1),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, DateTimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': datetime.datetime(1970, 1, 1)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_datetime_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=DateTimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': now.replace(minute=0, second=0, microsecond=0, tzinfo=None)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_duration_from_python(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=datetime.timedelta(0)),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_duration_from_database(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=Now() - Now()),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_decimal_from_python(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Decimal('0.00')),
)
self.assertEqual(result['value'], Decimal('0.00'))
def test_aggregation_default_using_decimal_from_database(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Pi()),
)
self.assertAlmostEqual(result['value'], Decimal.from_float(math.pi), places=6)
def test_aggregation_default_passed_another_aggregate(self):
result = Book.objects.aggregate(
value=Sum('price', filter=Q(rating__lt=3.0), default=Avg('pages') / 10.0),
)
self.assertAlmostEqual(result['value'], Decimal('61.72'), places=2)