Fixed #14030 -- Allowed annotations to accept all expressions

This commit is contained in:
Josh Smeaton 2013-12-26 00:13:18 +11:00 committed by Marc Tamlyn
parent 39e3ef88c2
commit f59fd15c49
43 changed files with 2572 additions and 801 deletions

View file

@ -154,8 +154,7 @@ class QuerySet(object):
2. sql/compiler.results_iter()
- Returns one row at time. At this point the rows are still just
tuples. In some cases the return values are converted to
Python values at this location (see resolve_columns(),
resolve_aggregate()).
Python values at this location.
3. self.iterator()
- Responsible for turning the rows into model objects.
"""
@ -241,7 +240,7 @@ class QuerySet(object):
max_depth = self.query.max_depth
extra_select = list(self.query.extra_select)
aggregate_select = list(self.query.aggregate_select)
annotation_select = list(self.query.annotation_select)
only_load = self.query.get_loaded_field_names()
fields = self.model._meta.concrete_fields
@ -282,7 +281,7 @@ class QuerySet(object):
db = self.db
compiler = self.query.get_compiler(using=db)
index_start = len(extra_select)
aggregate_start = index_start + len(init_list)
annotation_start = index_start + len(init_list)
if fill_cache:
klass_info = get_klass_info(model_cls, max_depth=max_depth,
@ -290,18 +289,18 @@ class QuerySet(object):
for row in compiler.results_iter():
if fill_cache:
obj, _ = get_cached_row(row, index_start, db, klass_info,
offset=len(aggregate_select))
offset=len(annotation_select))
else:
obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start])
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
if extra_select:
for i, k in enumerate(extra_select):
setattr(obj, k, row[i])
# Add the aggregates to the model
if aggregate_select:
for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i + aggregate_start])
# Add the annotations to the model
if annotation_select:
for i, annotation in enumerate(annotation_select):
setattr(obj, annotation, row[i + annotation_start])
# Add the known related objects to the model, if there are any
if self._known_related_objects:
@ -330,13 +329,16 @@ class QuerySet(object):
if self.query.distinct_fields:
raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
for arg in args:
if not hasattr(arg, 'default_alias'):
raise TypeError("Complex aggregates require an alias")
kwargs[arg.default_alias] = arg
query = self.query.clone()
force_subq = query.low_mark != 0 or query.high_mark is not None
for (alias, aggregate_expr) in kwargs.items():
query.add_aggregate(aggregate_expr, self.model, alias,
is_summary=True)
query.add_annotation(aggregate_expr, self.model, alias, is_summary=True)
if not query.annotations[alias].contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
return query.get_aggregation(using=self.db, force_subq=force_subq)
def count(self):
@ -787,33 +789,40 @@ class QuerySet(object):
def annotate(self, *args, **kwargs):
"""
Return a query set in which the returned objects have been annotated
with data aggregated from related fields.
with extra data or aggregations.
"""
aggrs = OrderedDict() # To preserve ordering of args
annotations = OrderedDict() # To preserve ordering of args
for arg in args:
if arg.default_alias in kwargs:
raise ValueError("The named annotation '%s' conflicts with the "
"default name for another annotation."
% arg.default_alias)
aggrs[arg.default_alias] = arg
aggrs.update(kwargs)
try:
# we can't do an hasattr here because py2 returns False
# if default_alias exists but throws a TypeError
if arg.default_alias in kwargs:
raise ValueError("The named annotation '%s' conflicts with the "
"default name for another annotation."
% arg.default_alias)
except AttributeError: # default_alias
raise TypeError("Complex annotations require an alias")
annotations[arg.default_alias] = arg
annotations.update(kwargs)
obj = self._clone()
names = getattr(self, '_fields', None)
if names is None:
names = set(self.model._meta.get_all_field_names())
for aggregate in aggrs:
if aggregate in names:
# Add the annotations to the query
for alias, annotation in annotations.items():
if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % aggregate)
obj = self._clone()
obj._setup_aggregate_query(list(aggrs))
# Add the aggregates to the query
for (alias, aggregate_expr) in aggrs.items():
obj.query.add_aggregate(aggregate_expr, self.model, alias,
is_summary=False)
"the model." % alias)
obj.query.add_annotation(annotation, self.model, alias, is_summary=False)
# expressions need to be added to the query before we know if they contain aggregates
added_aggregates = []
for alias, annotation in obj.query.annotations.items():
if alias in annotations and annotation.contains_aggregate:
added_aggregates.append(alias)
if added_aggregates:
obj._setup_aggregate_query(list(added_aggregates))
return obj
@ -1096,9 +1105,9 @@ class ValuesQuerySet(QuerySet):
# Purge any extra columns that haven't been explicitly asked for
extra_names = list(self.query.extra_select)
field_names = self.field_names
aggregate_names = list(self.query.aggregate_select)
annotation_names = list(self.query.annotation_select)
names = extra_names + field_names + aggregate_names
names = extra_names + field_names + annotation_names
for row in self.query.get_compiler(self.db).results_iter():
yield dict(zip(names, row))
@ -1122,9 +1131,9 @@ class ValuesQuerySet(QuerySet):
if self._fields:
self.extra_names = []
self.aggregate_names = []
if not self.query._extra and not self.query._aggregates:
# Short cut - if there are no extra or aggregates, then
self.annotation_names = []
if not self.query._extra and not self.query._annotations:
# Short cut - if there are no extra or annotations, then
# the values() clause must be just field names.
self.field_names = list(self._fields)
else:
@ -1136,22 +1145,22 @@ class ValuesQuerySet(QuerySet):
# had selected previously.
if self.query._extra and f in self.query._extra:
self.extra_names.append(f)
elif f in self.query.aggregate_select:
self.aggregate_names.append(f)
elif f in self.query.annotation_select:
self.annotation_names.append(f)
else:
self.field_names.append(f)
else:
# Default to all fields.
self.extra_names = None
self.field_names = [f.attname for f in self.model._meta.concrete_fields]
self.aggregate_names = None
self.annotation_names = None
self.query.select = []
if self.extra_names is not None:
self.query.set_extra_mask(self.extra_names)
self.query.add_fields(self.field_names, True)
if self.aggregate_names is not None:
self.query.set_aggregate_mask(self.aggregate_names)
if self.annotation_names is not None:
self.query.set_annotation_mask(self.annotation_names)
def _clone(self, klass=None, setup=False, **kwargs):
"""
@ -1164,7 +1173,7 @@ class ValuesQuerySet(QuerySet):
c._fields = self._fields[:]
c.field_names = self.field_names
c.extra_names = self.extra_names
c.aggregate_names = self.aggregate_names
c.annotation_names = self.annotation_names
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
@ -1173,7 +1182,7 @@ class ValuesQuerySet(QuerySet):
super(ValuesQuerySet, self)._merge_sanity_check(other)
if (set(self.extra_names) != set(other.extra_names) or
set(self.field_names) != set(other.field_names) or
self.aggregate_names != other.aggregate_names):
self.annotation_names != other.annotation_names):
raise TypeError("Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__)
@ -1183,9 +1192,9 @@ class ValuesQuerySet(QuerySet):
"""
self.query.set_group_by()
if self.aggregate_names is not None:
self.aggregate_names.extend(aggregates)
self.query.set_aggregate_mask(self.aggregate_names)
if self.annotation_names is not None:
self.annotation_names.extend(aggregates)
self.query.set_annotation_mask(self.annotation_names)
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
@ -1231,7 +1240,7 @@ class ValuesListQuerySet(ValuesQuerySet):
if self.flat and len(self._fields) == 1:
for row in self.query.get_compiler(self.db).results_iter():
yield row[0]
elif not self.query.extra_select and not self.query.aggregate_select:
elif not self.query.extra_select and not self.query.annotation_select:
for row in self.query.get_compiler(self.db).results_iter():
yield tuple(row)
else:
@ -1240,14 +1249,14 @@ class ValuesListQuerySet(ValuesQuerySet):
# the fields to match the order in self._fields.
extra_names = list(self.query.extra_select)
field_names = self.field_names
aggregate_names = list(self.query.aggregate_select)
annotation_names = list(self.query.annotation_select)
names = extra_names + field_names + aggregate_names
names = extra_names + field_names + annotation_names
# If a field list has been specified, use it. Otherwise, use the
# full list of fields, including extras and aggregates.
# full list of fields, including extras and annotations.
if self._fields:
fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields]
fields = list(self._fields) + [f for f in annotation_names if f not in self._fields]
else:
fields = names