Fixed #24020 -- Refactored SQL compiler to use expressions

Refactored compiler SELECT, GROUP BY and ORDER BY generation.
While there, also refactored the select_related() implementation
(get_cached_row() and get_klass_info() are now gone!).

Made the get_db_converters() method work on expressions instead of
internal_type. This allows the backend converters to target
specific expressions if need be.

Added query.context, which can be used to set per-query state.

Also changed the signature of database converters. They now accept
context as an argument.
This commit is contained in:
Anssi Kääriäinen 2014-12-01 09:28:01 +02:00 committed by Tim Graham
parent b8abfe141b
commit 0c7633178f
41 changed files with 970 additions and 1416 deletions

View file

@ -21,7 +21,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.query_utils import PathInfo, Q, refs_aggregate
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, SelectInfo, INNER, LOUTER)
ORDER_PATTERN, INNER, LOUTER)
from django.db.models.sql.datastructures import (
EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@ -46,7 +46,7 @@ class RawQuery(object):
A single raw SQL query
"""
def __init__(self, sql, using, params=None):
def __init__(self, sql, using, params=None, context=None):
self.params = params or ()
self.sql = sql
self.using = using
@ -57,9 +57,10 @@ class RawQuery(object):
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.extra_select = {}
self.annotation_select = {}
self.context = context or {}
def clone(self, using):
return RawQuery(self.sql, using, params=self.params)
return RawQuery(self.sql, using, params=self.params, context=self.context.copy())
def get_columns(self):
if self.cursor is None:
@ -122,20 +123,23 @@ class Query(object):
self.standard_ordering = True
self.used_aliases = set()
self.filter_is_sticky = False
self.included_inherited_models = {}
# SQL-related attributes
# Select and related select clauses as SelectInfo instances.
# Select and related select clauses are expressions to use in the
# SELECT clause of the query.
# The select is used for cases where we want to set up the select
# clause to contain other than default fields (values(), annotate(),
# subqueries...)
# clause to contain other than default fields (values(), subqueries...)
# Note that annotations go to annotations dictionary.
self.select = []
# The related_select_cols is used for columns needed for
# select_related - this is populated in the compile stage.
self.related_select_cols = []
self.tables = [] # Aliases in the order they are created.
self.where = where()
self.where_class = where
# The group_by attribute can have one of the following forms:
# - None: no group by at all in the query
# - A list of expressions: group by (at least) those expressions.
# String refs are also allowed for now.
# - True: group by all select fields of the model
# See compiler.get_group_by() for details.
self.group_by = None
self.having = where()
self.order_by = []
@ -174,6 +178,8 @@ class Query(object):
# load.
self.deferred_loading = (set(), True)
self.context = {}
@property
def extra(self):
if self._extra is None:
@ -254,14 +260,14 @@ class Query(object):
obj.default_cols = self.default_cols
obj.default_ordering = self.default_ordering
obj.standard_ordering = self.standard_ordering
obj.included_inherited_models = self.included_inherited_models.copy()
obj.select = self.select[:]
obj.related_select_cols = []
obj.tables = self.tables[:]
obj.where = self.where.clone()
obj.where_class = self.where_class
if self.group_by is None:
obj.group_by = None
elif self.group_by is True:
obj.group_by = True
else:
obj.group_by = self.group_by[:]
obj.having = self.having.clone()
@ -272,7 +278,6 @@ class Query(object):
obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related
obj.related_select_cols = []
obj._annotations = self._annotations.copy() if self._annotations is not None else None
if self.annotation_select_mask is None:
obj.annotation_select_mask = None
@ -310,8 +315,15 @@ class Query(object):
obj.__dict__.update(kwargs)
if hasattr(obj, '_setup_query'):
obj._setup_query()
obj.context = self.context.copy()
return obj
def add_context(self, key, value):
# Store ``value`` under ``key`` in this query's ``context`` mapping,
# replacing any value previously stored for the same key.
self.context[key] = value
def get_context(self, key, default=None):
# Look up ``key`` in this query's ``context`` mapping. When the key is
# absent, ``default`` is returned instead (None unless the caller
# supplies another value).
return self.context.get(key, default)
def relabeled_clone(self, change_map):
clone = self.clone()
clone.change_aliases(change_map)
@ -375,7 +387,8 @@ class Query(object):
# done in a subquery so that we are aggregating on the limit and/or
# distinct results instead of applying the distinct and limit after the
# aggregation.
if (self.group_by or has_limit or has_existing_annotations or self.distinct):
if (isinstance(self.group_by, list) or has_limit or has_existing_annotations or
self.distinct):
from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model)
inner_query = self.clone()
@ -383,7 +396,6 @@ class Query(object):
inner_query.clear_ordering(True)
inner_query.select_for_update = False
inner_query.select_related = False
inner_query.related_select_cols = []
relabels = {t: 'subquery' for t in inner_query.tables}
relabels[None] = 'subquery'
@ -407,26 +419,17 @@ class Query(object):
self.select = []
self.default_cols = False
self._extra = {}
self.remove_inherited_models()
outer_query.clear_ordering(True)
outer_query.clear_limits()
outer_query.select_for_update = False
outer_query.select_related = False
outer_query.related_select_cols = []
compiler = outer_query.get_compiler(using)
result = compiler.execute_sql(SINGLE)
if result is None:
result = [None for q in outer_query.annotation_select.items()]
fields = [annotation.output_field
for alias, annotation in outer_query.annotation_select.items()]
converters = compiler.get_converters(fields)
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
if position in converters:
converters[position][1].insert(0, annotation.convert_value)
else:
converters[position] = ([], [annotation.convert_value], annotation.output_field)
converters = compiler.get_converters(outer_query.annotation_select.values())
result = compiler.apply_converters(result, converters)
return {
@ -476,7 +479,6 @@ class Query(object):
assert self.distinct_fields == rhs.distinct_fields, \
"Cannot combine queries with different distinct fields."
self.remove_inherited_models()
# Work out how to relabel the rhs aliases, if necessary.
change_map = {}
conjunction = (connector == AND)
@ -545,13 +547,8 @@ class Query(object):
# Selection columns and extra extensions are those provided by 'rhs'.
self.select = []
for col, field in rhs.select:
if isinstance(col, (list, tuple)):
new_col = change_map.get(col[0], col[0]), col[1]
self.select.append(SelectInfo(new_col, field))
else:
new_col = col.relabeled_clone(change_map)
self.select.append(SelectInfo(new_col, field))
for col in rhs.select:
self.add_select(col.relabeled_clone(change_map))
if connector == OR:
# It would be nice to be able to handle this, but the queries don't
@ -661,17 +658,6 @@ class Query(object):
for model, values in six.iteritems(seen):
callback(target, model, values)
def deferred_to_columns_cb(self, target, model, fields):
"""
Callback used by deferred_to_columns(). The "target" parameter should
be a set instance.
"""
# NOTE(review): the docstring above says "set", but the code below indexes
# ``target`` by table name and stores a set per key, i.e. it is used as a
# mapping of db_table -> set of column names. Confirm and fix the wording.
table = model._meta.db_table
# Create the per-table set the first time this table is seen.
if table not in target:
target[table] = set()
# Record the column name of every field passed in for this model.
for field in fields:
target[table].add(field.column)
def table_alias(self, table_name, create=False):
"""
Returns a table alias for the given table_name and whether this is a
@ -788,10 +774,9 @@ class Query(object):
# "group by", "where" and "having".
self.where.relabel_aliases(change_map)
self.having.relabel_aliases(change_map)
if self.group_by:
if isinstance(self.group_by, list):
self.group_by = [relabel_column(col) for col in self.group_by]
self.select = [SelectInfo(relabel_column(s.col), s.field)
for s in self.select]
self.select = [col.relabeled_clone(change_map) for col in self.select]
if self._annotations:
self._annotations = OrderedDict(
(key, relabel_column(col)) for key, col in self._annotations.items())
@ -815,9 +800,6 @@ class Query(object):
if alias == old_alias:
self.tables[pos] = new_alias
break
for key, alias in self.included_inherited_models.items():
if alias in change_map:
self.included_inherited_models[key] = change_map[alias]
self.external_aliases = {change_map.get(alias, alias)
for alias in self.external_aliases}
@ -930,28 +912,6 @@ class Query(object):
self.alias_map[alias] = join
return alias
def setup_inherited_models(self):
"""
If the model that is the basis for this QuerySet inherits other models,
we need to ensure that those other models have their tables included in
the query.
We do this as a separate step so that subclasses know which
tables are going to be active in the query, without needing to compute
all the select columns (this method is called from pre_sql_setup(),
whereas column determination is a later part, and side-effect, of
as_sql()).
"""
opts = self.get_meta()
# The first entry of self.tables is taken as the alias of the root table.
root_alias = self.tables[0]
# ``seen`` maps a model to the alias used for it; the None key holds the
# root alias.
seen = {None: root_alias}
for field in opts.fields:
# Compare on the concrete model of the class that defines the field.
model = field.model._meta.concrete_model
# Only models other than the query's own model, and not yet recorded in
# ``seen``, need joining. ``seen`` is passed along to
# join_parent_model() -- presumably it records the new alias there;
# verify against that method.
if model is not opts.model and model not in seen:
self.join_parent_model(opts, model, root_alias, seen)
# Keep the model -> alias mapping on the query for later use.
self.included_inherited_models = seen
def join_parent_model(self, opts, model, alias, seen):
"""
Makes sure the given 'model' is joined in the query. If 'model' isn't
@ -969,7 +929,9 @@ class Query(object):
curr_opts = opts
for int_model in chain:
if int_model in seen:
return seen[int_model]
curr_opts = int_model._meta
alias = seen[int_model]
continue
# Proxy models have elements in the base chain
# with no parents; assign the new options
# object and skip to the next base in that
@ -984,23 +946,13 @@ class Query(object):
alias = seen[int_model] = joins[-1]
return alias or seen[None]
def remove_inherited_models(self):
"""
Undoes the effects of setup_inherited_models(). Should be called
whenever select columns (self.select) are set explicitly.
"""
for key, alias in self.included_inherited_models.items():
# Entries with a falsy key (the None entry that
# setup_inherited_models() uses for the root alias) are skipped; only
# aliases recorded for actual models are unreferenced.
if key:
self.unref_alias(alias)
# Reset the mapping so that no inherited models are recorded any more.
self.included_inherited_models = {}
def add_aggregate(self, aggregate, model, alias, is_summary):
# Deprecated entry point: emits a RemovedInDjango20Warning and then
# forwards to add_annotation(). The ``model`` argument is accepted but is
# not used in this body.
warnings.warn(
"add_aggregate() is deprecated. Use add_annotation() instead.",
RemovedInDjango20Warning, stacklevel=2)
# ``aggregate`` is passed through as the annotation expression.
self.add_annotation(aggregate, alias, is_summary)
def add_annotation(self, annotation, alias, is_summary):
def add_annotation(self, annotation, alias, is_summary=False):
"""
Adds a single annotation expression to the Query
"""
@ -1011,6 +963,7 @@ class Query(object):
def prepare_lookup_value(self, value, lookups, can_reuse):
# Default lookup if none given is exact.
used_joins = []
if len(lookups) == 0:
lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
@ -1026,7 +979,9 @@ class Query(object):
RemovedInDjango19Warning, stacklevel=2)
value = value()
elif hasattr(value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse)
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
# Subqueries need to use a different set of aliases than the
# outer query. Call bump_prefix to change aliases of the inner
# query (the value).
@ -1044,7 +999,7 @@ class Query(object):
lookups[-1] == 'exact' and value == ''):
value = True
lookups[-1] = 'isnull'
return value, lookups
return value, lookups, used_joins
def solve_lookup_type(self, lookup):
"""
@ -1173,8 +1128,7 @@ class Query(object):
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
used_joins = getattr(value, '_used_joins', [])
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse)
clause = self.where_class()
if reffed_aggregate:
@ -1223,7 +1177,7 @@ class Query(object):
# handle Expressions as annotations
col = targets[0]
else:
col = Col(alias, targets[0], field)
col = targets[0].get_col(alias, field)
condition = self.build_lookup(lookups, col, value)
if not condition:
# Backwards compat for custom lookups
@ -1258,7 +1212,7 @@ class Query(object):
# <=>
# NOT (col IS NOT NULL AND col = someval).
lookup_class = targets[0].get_lookup('isnull')
clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND)
clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND)
return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause):
@ -1535,7 +1489,7 @@ class Query(object):
self.unref_alias(joins.pop())
return targets, joins[-1], joins
def resolve_ref(self, name, allow_joins, reuse, summarize):
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query")
if name in self.annotations:
@ -1558,8 +1512,7 @@ class Query(object):
"isn't supported")
if reuse is not None:
reuse.update(join_list)
col = Col(join_list[-1], targets[0], sources[0])
col._used_joins = join_list
col = targets[0].get_col(join_list[-1], sources[0])
return col
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
@ -1588,26 +1541,28 @@ class Query(object):
# Try to have as simple as possible subquery -> trim leading joins from
# the subquery.
trimmed_prefix, contains_louter = query.trim_start(names_with_path)
query.remove_inherited_models()
# Add extra check to make sure the selected field will not be null
# since we are adding an IN <subquery> clause. This prevents the
# database from tripping over IN (...,NULL,...) selects and returning
# nothing
alias, col = query.select[0].col
if self.is_nullable(query.select[0].field):
lookup_class = query.select[0].field.get_lookup('isnull')
lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False)
col = query.select[0]
select_field = col.field
alias = col.alias
if self.is_nullable(select_field):
lookup_class = select_field.get_lookup('isnull')
lookup = lookup_class(select_field.get_col(alias), False)
query.where.add(lookup, AND)
if alias in can_reuse:
select_field = query.select[0].field
pk = select_field.model._meta.pk
# Need to add a restriction so that outer query's filters are in effect for
# the subquery, too.
query.bump_prefix(self)
lookup_class = select_field.get_lookup('exact')
lookup = lookup_class(Col(query.select[0].col[0], pk, pk),
Col(alias, pk, pk))
# Note that the query.select[0].alias is different from alias
# due to bump_prefix above.
lookup = lookup_class(pk.get_col(query.select[0].alias),
pk.get_col(alias))
query.where.add(lookup, AND)
query.external_aliases.add(alias)
@ -1687,6 +1642,14 @@ class Query(object):
"""
self.select = []
def add_select(self, col):
# Append ``col`` to the explicit select list (self.select). The
# default_cols flag is cleared on every call, before the append.
self.default_cols = False
self.select.append(col)
def set_select(self, cols):
# Replace the explicit select list (self.select) with ``cols`` and clear
# the default_cols flag. ``cols`` is stored as given, without copying.
self.default_cols = False
self.select = cols
def add_distinct_fields(self, *field_names):
"""
Adds and resolves the given fields to the query's "distinct on" clause.
@ -1710,7 +1673,7 @@ class Query(object):
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
targets, final_alias, joins = self.trim_joins(targets, joins, path)
for target in targets:
self.select.append(SelectInfo((final_alias, target.column), target))
self.add_select(target.get_col(final_alias))
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
@ -1723,7 +1686,6 @@ class Query(object):
+ list(self.annotation_select))
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
self.remove_inherited_models()
def add_ordering(self, *ordering):
"""
@ -1766,7 +1728,7 @@ class Query(object):
"""
self.group_by = []
for col, _ in self.select:
for col in self.select:
self.group_by.append(col)
if self._annotations:
@ -1789,7 +1751,6 @@ class Query(object):
for part in field.split(LOOKUP_SEP):
d = d.setdefault(part, {})
self.select_related = field_dict
self.related_select_cols = []
def add_extra(self, select, select_params, where, params, tables, order_by):
"""
@ -1897,7 +1858,7 @@ class Query(object):
"""
Callback used by get_deferred_field_names().
"""
target[model] = set(f.name for f in fields)
target[model] = {f.attname for f in fields}
def set_aggregate_mask(self, names):
warnings.warn(
@ -2041,7 +2002,7 @@ class Query(object):
if self.alias_refcount[table] > 0:
self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
break
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
self.set_select([f.get_col(select_alias) for f in select_fields])
return trimmed_prefix, contains_louter
def is_nullable(self, field):