mirror of
https://github.com/django/django.git
synced 2025-09-26 12:09:19 +00:00
Fixed #14030 -- Allowed annotations to accept all expressions
This commit is contained in:
parent
39e3ef88c2
commit
f59fd15c49
43 changed files with 2572 additions and 801 deletions
|
@ -154,8 +154,7 @@ class QuerySet(object):
|
|||
2. sql/compiler.results_iter()
|
||||
- Returns one row at time. At this point the rows are still just
|
||||
tuples. In some cases the return values are converted to
|
||||
Python values at this location (see resolve_columns(),
|
||||
resolve_aggregate()).
|
||||
Python values at this location.
|
||||
3. self.iterator()
|
||||
- Responsible for turning the rows into model objects.
|
||||
"""
|
||||
|
@ -241,7 +240,7 @@ class QuerySet(object):
|
|||
max_depth = self.query.max_depth
|
||||
|
||||
extra_select = list(self.query.extra_select)
|
||||
aggregate_select = list(self.query.aggregate_select)
|
||||
annotation_select = list(self.query.annotation_select)
|
||||
|
||||
only_load = self.query.get_loaded_field_names()
|
||||
fields = self.model._meta.concrete_fields
|
||||
|
@ -282,7 +281,7 @@ class QuerySet(object):
|
|||
db = self.db
|
||||
compiler = self.query.get_compiler(using=db)
|
||||
index_start = len(extra_select)
|
||||
aggregate_start = index_start + len(init_list)
|
||||
annotation_start = index_start + len(init_list)
|
||||
|
||||
if fill_cache:
|
||||
klass_info = get_klass_info(model_cls, max_depth=max_depth,
|
||||
|
@ -290,18 +289,18 @@ class QuerySet(object):
|
|||
for row in compiler.results_iter():
|
||||
if fill_cache:
|
||||
obj, _ = get_cached_row(row, index_start, db, klass_info,
|
||||
offset=len(aggregate_select))
|
||||
offset=len(annotation_select))
|
||||
else:
|
||||
obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start])
|
||||
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
|
||||
|
||||
if extra_select:
|
||||
for i, k in enumerate(extra_select):
|
||||
setattr(obj, k, row[i])
|
||||
|
||||
# Add the aggregates to the model
|
||||
if aggregate_select:
|
||||
for i, aggregate in enumerate(aggregate_select):
|
||||
setattr(obj, aggregate, row[i + aggregate_start])
|
||||
# Add the annotations to the model
|
||||
if annotation_select:
|
||||
for i, annotation in enumerate(annotation_select):
|
||||
setattr(obj, annotation, row[i + annotation_start])
|
||||
|
||||
# Add the known related objects to the model, if there are any
|
||||
if self._known_related_objects:
|
||||
|
@ -330,13 +329,16 @@ class QuerySet(object):
|
|||
if self.query.distinct_fields:
|
||||
raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
|
||||
for arg in args:
|
||||
if not hasattr(arg, 'default_alias'):
|
||||
raise TypeError("Complex aggregates require an alias")
|
||||
kwargs[arg.default_alias] = arg
|
||||
|
||||
query = self.query.clone()
|
||||
force_subq = query.low_mark != 0 or query.high_mark is not None
|
||||
for (alias, aggregate_expr) in kwargs.items():
|
||||
query.add_aggregate(aggregate_expr, self.model, alias,
|
||||
is_summary=True)
|
||||
query.add_annotation(aggregate_expr, self.model, alias, is_summary=True)
|
||||
if not query.annotations[alias].contains_aggregate:
|
||||
raise TypeError("%s is not an aggregate expression" % alias)
|
||||
return query.get_aggregation(using=self.db, force_subq=force_subq)
|
||||
|
||||
def count(self):
|
||||
|
@ -787,33 +789,40 @@ class QuerySet(object):
|
|||
def annotate(self, *args, **kwargs):
|
||||
"""
|
||||
Return a query set in which the returned objects have been annotated
|
||||
with data aggregated from related fields.
|
||||
with extra data or aggregations.
|
||||
"""
|
||||
aggrs = OrderedDict() # To preserve ordering of args
|
||||
annotations = OrderedDict() # To preserve ordering of args
|
||||
for arg in args:
|
||||
if arg.default_alias in kwargs:
|
||||
raise ValueError("The named annotation '%s' conflicts with the "
|
||||
"default name for another annotation."
|
||||
% arg.default_alias)
|
||||
aggrs[arg.default_alias] = arg
|
||||
aggrs.update(kwargs)
|
||||
try:
|
||||
# we can't do an hasattr here because py2 returns False
|
||||
# if default_alias exists but throws a TypeError
|
||||
if arg.default_alias in kwargs:
|
||||
raise ValueError("The named annotation '%s' conflicts with the "
|
||||
"default name for another annotation."
|
||||
% arg.default_alias)
|
||||
except AttributeError: # default_alias
|
||||
raise TypeError("Complex annotations require an alias")
|
||||
annotations[arg.default_alias] = arg
|
||||
annotations.update(kwargs)
|
||||
|
||||
obj = self._clone()
|
||||
names = getattr(self, '_fields', None)
|
||||
if names is None:
|
||||
names = set(self.model._meta.get_all_field_names())
|
||||
for aggregate in aggrs:
|
||||
if aggregate in names:
|
||||
|
||||
# Add the annotations to the query
|
||||
for alias, annotation in annotations.items():
|
||||
if alias in names:
|
||||
raise ValueError("The annotation '%s' conflicts with a field on "
|
||||
"the model." % aggregate)
|
||||
|
||||
obj = self._clone()
|
||||
|
||||
obj._setup_aggregate_query(list(aggrs))
|
||||
|
||||
# Add the aggregates to the query
|
||||
for (alias, aggregate_expr) in aggrs.items():
|
||||
obj.query.add_aggregate(aggregate_expr, self.model, alias,
|
||||
is_summary=False)
|
||||
"the model." % alias)
|
||||
obj.query.add_annotation(annotation, self.model, alias, is_summary=False)
|
||||
# expressions need to be added to the query before we know if they contain aggregates
|
||||
added_aggregates = []
|
||||
for alias, annotation in obj.query.annotations.items():
|
||||
if alias in annotations and annotation.contains_aggregate:
|
||||
added_aggregates.append(alias)
|
||||
if added_aggregates:
|
||||
obj._setup_aggregate_query(list(added_aggregates))
|
||||
|
||||
return obj
|
||||
|
||||
|
@ -1096,9 +1105,9 @@ class ValuesQuerySet(QuerySet):
|
|||
# Purge any extra columns that haven't been explicitly asked for
|
||||
extra_names = list(self.query.extra_select)
|
||||
field_names = self.field_names
|
||||
aggregate_names = list(self.query.aggregate_select)
|
||||
annotation_names = list(self.query.annotation_select)
|
||||
|
||||
names = extra_names + field_names + aggregate_names
|
||||
names = extra_names + field_names + annotation_names
|
||||
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield dict(zip(names, row))
|
||||
|
@ -1122,9 +1131,9 @@ class ValuesQuerySet(QuerySet):
|
|||
|
||||
if self._fields:
|
||||
self.extra_names = []
|
||||
self.aggregate_names = []
|
||||
if not self.query._extra and not self.query._aggregates:
|
||||
# Short cut - if there are no extra or aggregates, then
|
||||
self.annotation_names = []
|
||||
if not self.query._extra and not self.query._annotations:
|
||||
# Short cut - if there are no extra or annotations, then
|
||||
# the values() clause must be just field names.
|
||||
self.field_names = list(self._fields)
|
||||
else:
|
||||
|
@ -1136,22 +1145,22 @@ class ValuesQuerySet(QuerySet):
|
|||
# had selected previously.
|
||||
if self.query._extra and f in self.query._extra:
|
||||
self.extra_names.append(f)
|
||||
elif f in self.query.aggregate_select:
|
||||
self.aggregate_names.append(f)
|
||||
elif f in self.query.annotation_select:
|
||||
self.annotation_names.append(f)
|
||||
else:
|
||||
self.field_names.append(f)
|
||||
else:
|
||||
# Default to all fields.
|
||||
self.extra_names = None
|
||||
self.field_names = [f.attname for f in self.model._meta.concrete_fields]
|
||||
self.aggregate_names = None
|
||||
self.annotation_names = None
|
||||
|
||||
self.query.select = []
|
||||
if self.extra_names is not None:
|
||||
self.query.set_extra_mask(self.extra_names)
|
||||
self.query.add_fields(self.field_names, True)
|
||||
if self.aggregate_names is not None:
|
||||
self.query.set_aggregate_mask(self.aggregate_names)
|
||||
if self.annotation_names is not None:
|
||||
self.query.set_annotation_mask(self.annotation_names)
|
||||
|
||||
def _clone(self, klass=None, setup=False, **kwargs):
|
||||
"""
|
||||
|
@ -1164,7 +1173,7 @@ class ValuesQuerySet(QuerySet):
|
|||
c._fields = self._fields[:]
|
||||
c.field_names = self.field_names
|
||||
c.extra_names = self.extra_names
|
||||
c.aggregate_names = self.aggregate_names
|
||||
c.annotation_names = self.annotation_names
|
||||
if setup and hasattr(c, '_setup_query'):
|
||||
c._setup_query()
|
||||
return c
|
||||
|
@ -1173,7 +1182,7 @@ class ValuesQuerySet(QuerySet):
|
|||
super(ValuesQuerySet, self)._merge_sanity_check(other)
|
||||
if (set(self.extra_names) != set(other.extra_names) or
|
||||
set(self.field_names) != set(other.field_names) or
|
||||
self.aggregate_names != other.aggregate_names):
|
||||
self.annotation_names != other.annotation_names):
|
||||
raise TypeError("Merging '%s' classes must involve the same values in each case."
|
||||
% self.__class__.__name__)
|
||||
|
||||
|
@ -1183,9 +1192,9 @@ class ValuesQuerySet(QuerySet):
|
|||
"""
|
||||
self.query.set_group_by()
|
||||
|
||||
if self.aggregate_names is not None:
|
||||
self.aggregate_names.extend(aggregates)
|
||||
self.query.set_aggregate_mask(self.aggregate_names)
|
||||
if self.annotation_names is not None:
|
||||
self.annotation_names.extend(aggregates)
|
||||
self.query.set_annotation_mask(self.annotation_names)
|
||||
|
||||
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
|
||||
|
||||
|
@ -1231,7 +1240,7 @@ class ValuesListQuerySet(ValuesQuerySet):
|
|||
if self.flat and len(self._fields) == 1:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield row[0]
|
||||
elif not self.query.extra_select and not self.query.aggregate_select:
|
||||
elif not self.query.extra_select and not self.query.annotation_select:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
yield tuple(row)
|
||||
else:
|
||||
|
@ -1240,14 +1249,14 @@ class ValuesListQuerySet(ValuesQuerySet):
|
|||
# the fields to match the order in self._fields.
|
||||
extra_names = list(self.query.extra_select)
|
||||
field_names = self.field_names
|
||||
aggregate_names = list(self.query.aggregate_select)
|
||||
annotation_names = list(self.query.annotation_select)
|
||||
|
||||
names = extra_names + field_names + aggregate_names
|
||||
names = extra_names + field_names + annotation_names
|
||||
|
||||
# If a field list has been specified, use it. Otherwise, use the
|
||||
# full list of fields, including extras and aggregates.
|
||||
# full list of fields, including extras and annotations.
|
||||
if self._fields:
|
||||
fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields]
|
||||
fields = list(self._fields) + [f for f in annotation_names if f not in self._fields]
|
||||
else:
|
||||
fields = names
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue