Fixed #3566 -- Added support for aggregation to the ORM. See the documentation for details on usage.

Many thanks to:
 * Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
 * Alex Gaynor for his help debugging and fixing a number of issues.
 * Justin Bronn for his help integrating with contrib.gis.
 * Karen Tracey for her help with cross-platform testing.
 * Ian Kelly for his help testing and fixing Oracle support.
 * Malcolm Tredinnick for his invaluable review notes.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9742 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2009-01-15 11:06:34 +00:00
parent 50a293a0c3
commit cc4e4d9aee
30 changed files with 2357 additions and 325 deletions

View file

@ -4,6 +4,7 @@ except NameError:
from sets import Set as set # Python 2.3 fallback
from django.db import connection, transaction, IntegrityError
from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend
from django.db.models import signals, sql
@ -270,18 +271,47 @@ class QuerySet(object):
else:
requested = None
max_depth = self.query.max_depth
extra_select = self.query.extra_select.keys()
aggregate_select = self.query.aggregate_select.keys()
index_start = len(extra_select)
aggregate_start = index_start + len(self.model._meta.fields)
for row in self.query.results_iter():
if fill_cache:
obj, _ = get_cached_row(self.model, row, index_start,
max_depth, requested=requested)
obj, aggregate_start = get_cached_row(self.model, row,
index_start, max_depth, requested=requested)
else:
obj = self.model(*row[index_start:])
# omit aggregates in object creation
obj = self.model(*row[index_start:aggregate_start])
for i, k in enumerate(extra_select):
setattr(obj, k, row[i])
# Add the aggregates to the model
for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i+aggregate_start])
yield obj
def aggregate(self, *args, **kwargs):
    """
    Returns a dictionary containing the calculations (aggregation)
    over the current queryset.

    If args is present, each expression is passed as a kwarg using
    the Aggregate object's default alias.

    The aggregates are added to a clone of this queryset, so the
    queryset itself is left unmodified and can be reused afterwards.
    """
    for arg in args:
        kwargs[arg.default_alias] = arg

    # Work on a copy: adding summary aggregates directly to self.query
    # would leak them into every later use of this queryset.
    query = self._clone().query
    for (alias, aggregate_expr) in kwargs.items():
        query.add_aggregate(aggregate_expr, self.model, alias,
            is_summary=True)

    return query.get_aggregation()
def count(self):
"""
Performs a SELECT COUNT() and returns the number of records as an
@ -553,6 +583,25 @@ class QuerySet(object):
"""
self.query.select_related = other.query.select_related
def annotate(self, *args, **kwargs):
    """
    Return a query set in which the returned objects have been annotated
    with data aggregated from related fields.

    Positional aggregates are registered under their ``default_alias``;
    keyword aggregates are registered under the keyword name.
    """
    # Fold the positional aggregates into the keyword mapping.
    kwargs.update(dict([(arg.default_alias, arg) for arg in args]))

    clone = self._clone()
    clone._setup_aggregate_query()

    # Register every aggregate on the clone's query, never on self.
    for alias in kwargs:
        clone.query.add_aggregate(kwargs[alias], self.model, alias,
            is_summary=False)
    return clone
def order_by(self, *field_names):
"""
Returns a new QuerySet instance with the ordering changed.
@ -641,6 +690,16 @@ class QuerySet(object):
"""
pass
def _setup_aggregate_query(self):
"""
Prepare the query for computing a result that contains aggregate annotations.
"""
# NOTE(review): indentation is not preserved in this view. Presumably the
# three statements following the `if` all belong to its body -- confirm
# against the real file before relying on that.
opts = self.model._meta
# Guarded on the query not having a (truthy) group_by yet.
if not self.query.group_by:
# Hand the attname of every field in the model's _meta.fields to
# add_fields(), then ask the query to set up its grouping.
field_names = [f.attname for f in opts.fields]
self.query.add_fields(field_names, False)
self.query.set_group_by()
def as_sql(self):
"""
Returns the internal query's SQL and parameters (as a tuple).
@ -669,6 +728,8 @@ class ValuesQuerySet(QuerySet):
len(self.field_names) != len(self.model._meta.fields)):
self.query.trim_extra_select(self.extra_names)
names = self.query.extra_select.keys() + self.field_names
names.extend(self.query.aggregate_select.keys())
for row in self.query.results_iter():
yield dict(zip(names, row))
@ -682,20 +743,25 @@ class ValuesQuerySet(QuerySet):
"""
self.query.clear_select_fields()
self.extra_names = []
self.aggregate_names = []
if self._fields:
if not self.query.extra_select:
if not self.query.extra_select and not self.query.aggregate_select:
field_names = list(self._fields)
else:
field_names = []
for f in self._fields:
if self.query.extra_select.has_key(f):
self.extra_names.append(f)
elif self.query.aggregate_select.has_key(f):
self.aggregate_names.append(f)
else:
field_names.append(f)
else:
# Default to all fields.
field_names = [f.attname for f in self.model._meta.fields]
self.query.select = []
self.query.add_fields(field_names, False)
self.query.default_cols = False
self.field_names = field_names
@ -711,6 +777,7 @@ class ValuesQuerySet(QuerySet):
c._fields = self._fields[:]
c.field_names = self.field_names
c.extra_names = self.extra_names
c.aggregate_names = self.aggregate_names
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
@ -718,10 +785,18 @@ class ValuesQuerySet(QuerySet):
def _merge_sanity_check(self, other):
# Runs the parent class's check first, then raises TypeError unless both
# querysets carry the same extra names, field names and aggregate names.
super(ValuesQuerySet, self)._merge_sanity_check(other)
# extra_names and field_names are compared as sets, i.e. ignoring order.
# NOTE(review): the two consecutive field_names lines below appear to be
# the pre-change and post-change versions of the same line shown together.
if (set(self.extra_names) != set(other.extra_names) or
set(self.field_names) != set(other.field_names)):
set(self.field_names) != set(other.field_names) or
# aggregate_names is compared directly rather than through set(), so
# for list values the order of the names matters here.
self.aggregate_names != other.aggregate_names):
raise TypeError("Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__)
def _setup_aggregate_query(self):
    """
    Prepare the query for computing a result that contains aggregate
    annotations.

    Calls set_group_by() on the query first, then defers to the parent
    class's implementation.
    """
    query = self.query
    query.set_group_by()
    super(ValuesQuerySet, self)._setup_aggregate_query()
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
@ -729,14 +804,14 @@ class ValuesListQuerySet(ValuesQuerySet):
if self.flat and len(self._fields) == 1:
for row in self.query.results_iter():
yield row[0]
elif not self.query.extra_select:
elif not self.query.extra_select and not self.query.aggregate_select:
for row in self.query.results_iter():
yield tuple(row)
else:
# When extra(select=...) or aggregates are involved, the extra columns
# always come at the start of the row and the aggregate columns at the
# end, so we need to reorder the values to match the order in self._fields.
names = self.query.extra_select.keys() + self.field_names
names = self.query.extra_select.keys() + self.field_names + self.query.aggregate_select.keys()
for row in self.query.results_iter():
data = dict(zip(names, row))
yield tuple([data[f] for f in self._fields])