mirror of
https://github.com/django/django.git
synced 2025-08-31 15:57:45 +00:00
Fixed #14030 -- Allowed annotations to accept all expressions
This commit is contained in:
parent
39e3ef88c2
commit
f59fd15c49
43 changed files with 2572 additions and 801 deletions
|
@ -1,14 +1,20 @@
|
|||
import copy
|
||||
import datetime
|
||||
|
||||
from django.db.models.aggregates import refs_aggregate
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.models import fields
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
from django.db.models.query_utils import refs_aggregate
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class ExpressionNode(tree.Node):
|
||||
class CombinableMixin(object):
|
||||
"""
|
||||
Base class for all query expressions.
|
||||
Provides the ability to combine one or two objects with
|
||||
some connector. For example F('foo') + F('bar').
|
||||
"""
|
||||
|
||||
# Arithmetic connectors
|
||||
ADD = '+'
|
||||
SUB = '-'
|
||||
|
@ -25,44 +31,17 @@ class ExpressionNode(tree.Node):
|
|||
BITAND = '&'
|
||||
BITOR = '|'
|
||||
|
||||
def __init__(self, children=None, connector=None, negated=False):
|
||||
if children is not None and len(children) > 1 and connector is None:
|
||||
raise TypeError('You have to specify a connector.')
|
||||
super(ExpressionNode, self).__init__(children, connector, negated)
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if isinstance(other, datetime.timedelta):
|
||||
return DateModifierNode([self, other], connector)
|
||||
return DateModifierNode(self, connector, other)
|
||||
|
||||
if not hasattr(other, 'resolve_expression'):
|
||||
# everything must be resolvable to an expression
|
||||
other = Value(other)
|
||||
|
||||
if reversed:
|
||||
obj = ExpressionNode([other], connector)
|
||||
obj.add(node or self, connector)
|
||||
else:
|
||||
obj = node or ExpressionNode([self], connector)
|
||||
obj.add(other, connector)
|
||||
return obj
|
||||
|
||||
def contains_aggregate(self, existing_aggregates):
|
||||
if self.children:
|
||||
return any(child.contains_aggregate(existing_aggregates)
|
||||
for child in self.children
|
||||
if hasattr(child, 'contains_aggregate'))
|
||||
else:
|
||||
return refs_aggregate(self.name.split(LOOKUP_SEP),
|
||||
existing_aggregates)
|
||||
|
||||
def prepare_database_save(self, unused):
|
||||
return self
|
||||
|
||||
###################
|
||||
# VISITOR METHODS #
|
||||
###################
|
||||
|
||||
def prepare(self, evaluator, query, allow_joins):
|
||||
return evaluator.prepare_node(self, query, allow_joins)
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_node(self, qn, connection)
|
||||
return Expression(other, connector, self)
|
||||
return Expression(self, connector, other)
|
||||
|
||||
#############
|
||||
# OPERATORS #
|
||||
|
@ -137,27 +116,240 @@ class ExpressionNode(tree.Node):
|
|||
)
|
||||
|
||||
|
||||
class F(ExpressionNode):
|
||||
class ExpressionNode(CombinableMixin):
|
||||
"""
|
||||
An expression representing the value of the given field.
|
||||
Base class for all query expressions.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
super(F, self).__init__(None, None, False)
|
||||
self.name = name
|
||||
|
||||
def __deepcopy__(self, memodict):
|
||||
obj = super(F, self).__deepcopy__(memodict)
|
||||
obj.name = self.name
|
||||
return obj
|
||||
# aggregate specific fields
|
||||
is_summary = False
|
||||
|
||||
def prepare(self, evaluator, query, allow_joins):
|
||||
return evaluator.prepare_leaf(self, query, allow_joins)
|
||||
def __init__(self, output_field=None):
|
||||
self._output_field = output_field
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_leaf(self, qn, connection)
|
||||
def get_source_expressions(self):
|
||||
return []
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
assert len(exprs) == 0
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
"""
|
||||
Responsible for returning a (sql, [params]) tuple to be included
|
||||
in the current query.
|
||||
|
||||
Different backends can provide their own implementation, by
|
||||
providing an `as_{vendor}` method and patching the Expression:
|
||||
|
||||
```
|
||||
def override_as_sql(self, compiler, connection):
|
||||
# custom logic
|
||||
return super(ExpressionNode, self).as_sql(compiler, connection)
|
||||
setattr(ExpressionNode, 'as_' + connection.vendor, override_as_sql)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
* compiler: the query compiler responsible for generating the query.
|
||||
Must have a compile method, returning a (sql, [params]) tuple.
|
||||
Calling compiler(value) will return a quoted `value`.
|
||||
|
||||
* connection: the database connection used for the current query.
|
||||
|
||||
Returns: (sql, params)
|
||||
Where `sql` is a string containing ordered sql parameters to be
|
||||
replaced with the elements of the list `params`.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement as_sql()")
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
for expr in self.get_source_expressions():
|
||||
if expr and expr.contains_aggregate:
|
||||
return True
|
||||
return False
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
"""
|
||||
Provides the chance to do any preprocessing or validation before being
|
||||
added to the query.
|
||||
|
||||
Arguments:
|
||||
* query: the backend query implementation
|
||||
* allow_joins: boolean allowing or denying use of joins
|
||||
in this query
|
||||
* reuse: a set of reusable joins for multijoins
|
||||
* summarize: a terminal aggregate clause
|
||||
|
||||
Returns: an ExpressionNode to be added to the query.
|
||||
"""
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
return c
|
||||
|
||||
def _prepare(self):
|
||||
"""
|
||||
Hook used by Field.get_prep_lookup() to do custom preparation.
|
||||
"""
|
||||
return self
|
||||
|
||||
@property
|
||||
def field(self):
|
||||
return self.output_field
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
"""
|
||||
Returns the output type of this expressions.
|
||||
"""
|
||||
if self._output_field_or_none is None:
|
||||
raise FieldError("Cannot resolve expression type, unknown output_field")
|
||||
return self._output_field_or_none
|
||||
|
||||
@cached_property
|
||||
def _output_field_or_none(self):
|
||||
"""
|
||||
Returns the output field of this expression, or None if no output type
|
||||
can be resolved. Note that the 'output_field' property will raise
|
||||
FieldError if no type can be resolved, but this attribute allows for
|
||||
None values.
|
||||
"""
|
||||
if self._output_field is None:
|
||||
self._resolve_output_field()
|
||||
return self._output_field
|
||||
|
||||
def _resolve_output_field(self):
|
||||
"""
|
||||
Attempts to infer the output type of the expression. If the output
|
||||
fields of all source fields match then we can simply infer the same
|
||||
type here.
|
||||
"""
|
||||
if self._output_field is None:
|
||||
sources = self.get_source_fields()
|
||||
num_sources = len(sources)
|
||||
if num_sources == 0:
|
||||
self._output_field = None
|
||||
else:
|
||||
self._output_field = sources[0]
|
||||
for source in sources:
|
||||
if source is not None and not isinstance(self._output_field, source.__class__):
|
||||
raise FieldError(
|
||||
"Expression contains mixed types. You must set output_field")
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
"""
|
||||
Expressions provide their own converters because users have the option
|
||||
of manually specifying the output_field which may be a different type
|
||||
from the one the database returns.
|
||||
"""
|
||||
field = self.output_field
|
||||
internal_type = field.get_internal_type()
|
||||
if value is None:
|
||||
return value
|
||||
elif internal_type == 'FloatField':
|
||||
return float(value)
|
||||
elif internal_type.endswith('IntegerField'):
|
||||
return int(value)
|
||||
elif internal_type == 'DecimalField':
|
||||
return backend_utils.typecast_decimal(field.format_number(value))
|
||||
return value
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.output_field.get_transform(name)
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[e.relabeled_clone(change_map) for e in self.get_source_expressions()])
|
||||
return clone
|
||||
|
||||
def copy(self):
|
||||
c = copy.copy(self)
|
||||
c.copied = True
|
||||
return c
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
"""
|
||||
Does this expression contain a reference to some of the
|
||||
existing aggregates? If so, returns the aggregate and also
|
||||
the lookup parts that *weren't* found. So, if
|
||||
exsiting_aggregates = {'max_id': Max('id')}
|
||||
self.name = 'max_id'
|
||||
queryset.filter(max_id__range=[10,100])
|
||||
then this method will return Max('id') and those parts of the
|
||||
name that weren't found. In this case `max_id` is found and the range
|
||||
portion is returned as ('range',).
|
||||
"""
|
||||
for node in self.get_source_expressions():
|
||||
agg, lookup = node.refs_aggregate(existing_aggregates)
|
||||
if agg:
|
||||
return agg, lookup
|
||||
return False, ()
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
"""
|
||||
Helper method for check_aggregate_support on backends
|
||||
"""
|
||||
return any(
|
||||
node.refs_field(aggregate_types, field_types)
|
||||
for node in self.get_source_expressions())
|
||||
|
||||
def prepare_database_save(self, field):
|
||||
return self
|
||||
|
||||
def get_group_by_cols(self):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def get_source_fields(self):
|
||||
"""
|
||||
Returns the underlying field types used by this
|
||||
aggregate.
|
||||
"""
|
||||
return [e._output_field_or_none for e in self.get_source_expressions()]
|
||||
|
||||
|
||||
class DateModifierNode(ExpressionNode):
|
||||
class Expression(ExpressionNode):
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super(Expression, self).__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.lhs, self.rhs]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.lhs, self.rhs = exprs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
expressions = []
|
||||
expression_params = []
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
sql, params = compiler.compile(self.rhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
# order of precedence
|
||||
expression_wrapper = '(%s)'
|
||||
sql = connection.ops.combine_expression(self.connector, expressions)
|
||||
return expression_wrapper % sql, expression_params
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
return c
|
||||
|
||||
|
||||
class DateModifierNode(Expression):
|
||||
"""
|
||||
Node that implements the following syntax:
|
||||
filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200))
|
||||
|
@ -183,14 +375,195 @@ class DateModifierNode(ExpressionNode):
|
|||
Only adding and subtracting timedeltas is supported, attempts to use other
|
||||
operations raise a TypeError.
|
||||
"""
|
||||
def __init__(self, children, connector, negated=False):
|
||||
if len(children) != 2:
|
||||
raise TypeError('Must specify a node and a timedelta.')
|
||||
if not isinstance(children[1], datetime.timedelta):
|
||||
raise TypeError('Second child must be a timedelta.')
|
||||
def __init__(self, lhs, connector, rhs):
|
||||
if not isinstance(rhs, datetime.timedelta):
|
||||
raise TypeError('rhs must be a timedelta.')
|
||||
if connector not in (self.ADD, self.SUB):
|
||||
raise TypeError('Connector must be + or -, not %s' % connector)
|
||||
super(DateModifierNode, self).__init__(children, connector, negated)
|
||||
super(DateModifierNode, self).__init__(lhs, connector, Value(rhs))
|
||||
|
||||
def evaluate(self, evaluator, qn, connection):
|
||||
return evaluator.evaluate_date_modifier_node(self, qn, connection)
|
||||
def as_sql(self, compiler, connection):
|
||||
timedelta = self.rhs.value
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
|
||||
return sql, params
|
||||
return connection.ops.date_interval_sql(sql, self.connector, timedelta), params
|
||||
|
||||
|
||||
class F(CombinableMixin):
|
||||
"""
|
||||
An object capable of resolving references to existing query objects.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
"""
|
||||
Arguments:
|
||||
* name: the name of the field this expression references
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates)
|
||||
|
||||
|
||||
class Func(ExpressionNode):
|
||||
"""
|
||||
A SQL function call.
|
||||
"""
|
||||
function = None
|
||||
template = '%(function)s(%(expressions)s)'
|
||||
arg_joiner = ', '
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
output_field = extra.pop('output_field', None)
|
||||
super(Func, self).__init__(output_field=output_field)
|
||||
self.source_expressions = self._parse_expressions(*expressions)
|
||||
self.extra = extra
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.source_expressions
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.source_expressions = exprs
|
||||
|
||||
def _parse_expressions(self, *expressions):
|
||||
return [
|
||||
arg if hasattr(arg, 'resolve_expression') else F(arg)
|
||||
for arg in expressions
|
||||
]
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
for pos, arg in enumerate(c.source_expressions):
|
||||
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
sql_parts = []
|
||||
params = []
|
||||
for arg in self.source_expressions:
|
||||
arg_sql, arg_params = compiler.compile(arg)
|
||||
sql_parts.append(arg_sql)
|
||||
params.extend(arg_params)
|
||||
if function is None:
|
||||
self.extra['function'] = self.extra.get('function', self.function)
|
||||
else:
|
||||
self.extra['function'] = function
|
||||
self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
|
||||
template = template or self.extra.get('template', self.template)
|
||||
return template % self.extra, params
|
||||
|
||||
def copy(self):
|
||||
copy = super(Func, self).copy()
|
||||
copy.source_expressions = self.source_expressions[:]
|
||||
copy.extra = self.extra.copy()
|
||||
return copy
|
||||
|
||||
|
||||
class Value(ExpressionNode):
|
||||
"""
|
||||
Represents a wrapped value as a node within an expression
|
||||
"""
|
||||
def __init__(self, value, output_field=None):
|
||||
"""
|
||||
Arguments:
|
||||
* value: the value this expression represents. The value will be
|
||||
added into the sql parameter list and properly quoted.
|
||||
|
||||
* output_field: an instance of the model field type that this
|
||||
expression will return, such as IntegerField() or CharField().
|
||||
"""
|
||||
super(Value, self).__init__(output_field=output_field)
|
||||
self.value = value
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return '%s', [self.value]
|
||||
|
||||
|
||||
class Col(ExpressionNode):
|
||||
def __init__(self, alias, target, source=None):
|
||||
if source is None:
|
||||
source = target
|
||||
super(Col, self).__init__(output_field=source)
|
||||
self.alias, self.target = alias, target
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [(self.alias, self.target.column)]
|
||||
|
||||
|
||||
class Ref(ExpressionNode):
|
||||
"""
|
||||
Reference to column alias of the query. For example, Ref('sum_cost') in
|
||||
qs.annotate(sum_cost=Sum('cost')) query.
|
||||
"""
|
||||
def __init__(self, refs, source):
|
||||
super(Ref, self).__init__()
|
||||
self.source = source
|
||||
self.refs = refs
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.source]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.source, = exprs
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "%s" % compiler(self.refs), []
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [(None, self.refs)]
|
||||
|
||||
|
||||
class Date(ExpressionNode):
|
||||
"""
|
||||
Add a date selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type):
|
||||
super(Date, self).__init__(output_field=fields.DateField())
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = self.exprs
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = self.col.as_sql(qn, connection)
|
||||
assert not(params)
|
||||
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
|
||||
|
||||
|
||||
class DateTime(ExpressionNode):
|
||||
"""
|
||||
Add a datetime selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type, tzname):
|
||||
super(DateTime, self).__init__(output_field=fields.DateTimeField())
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
self.tzname = tzname
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = self.col.as_sql(qn, connection)
|
||||
assert not(params)
|
||||
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue