Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).

This commit is contained in:
Matthew Wilkes 2017-06-18 16:53:40 +01:00 committed by Tim Graham
parent bf26f66029
commit 2162f0983d
15 changed files with 260 additions and 38 deletions

View file

@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
databases). The abstraction barrier only works one way: this module has to know
all about the internals of models in order to get the information it needs.
"""
import functools
from collections import Counter, OrderedDict, namedtuple
from collections.abc import Iterator, Mapping
from itertools import chain, count, product
@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.fields import Field
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.lookups import Lookup
from django.db.models.query_utils import (
@ -56,7 +58,7 @@ def get_children_from_q(q):
JoinInfo = namedtuple(
'JoinInfo',
('final_field', 'targets', 'opts', 'joins', 'path')
('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
)
@ -1429,8 +1431,11 @@ class Query:
generate a MultiJoin exception.
Return the final field involved in the joins, the target field (used
for any 'where' constraint), the final 'opts' value, the joins and the
field path travelled to generate the joins.
for any 'where' constraint), the final 'opts' value, the joins, the
field path traveled to generate the joins, and a transform function
that takes a field and alias and is equivalent to `field.get_col(alias)`
in the simple case but wraps field transforms if they were included in
names.
The target field is the field containing the concrete value. Final
field can be something different, for example foreign key pointing to
@ -1439,10 +1444,46 @@ class Query:
key field for example).
"""
joins = [alias]
# First, generate the path for the names
path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many, fail_on_missing=True)
# The transform can't be applied yet, as joins must be trimmed later.
# To avoid making every caller of this method look up transforms
# directly, compute transforms here and and create a partial that
# converts fields to the appropriate wrapped version.
def final_transformer(field, alias):
return field.get_col(alias)
# Try resolving all the names as fields first. If there's an error,
# treat trailing names as lookups until a field can be resolved.
last_field_exception = None
for pivot in range(len(names), 0, -1):
try:
path, final_field, targets, rest = self.names_to_path(
names[:pivot], opts, allow_many, fail_on_missing=True,
)
except FieldError as exc:
if pivot == 1:
# The first item cannot be a lookup, so it's safe
# to raise the field error here.
raise
else:
last_field_exception = exc
else:
# The transforms are the remaining items that couldn't be
# resolved into fields.
transforms = names[pivot:]
break
for name in transforms:
def transform(field, alias, *, name, previous):
try:
wrapped = previous(field, alias)
return self.try_transform(wrapped, name)
except FieldError:
# FieldError is raised if the transform doesn't exist.
if isinstance(final_field, Field) and last_field_exception:
raise last_field_exception
else:
raise
final_transformer = functools.partial(transform, name=name, previous=final_transformer)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@ -1470,7 +1511,7 @@ class Query:
joins.append(alias)
if filtered_relation:
filtered_relation.path = joins[:]
return JoinInfo(final_field, targets, opts, joins, path)
return JoinInfo(final_field, targets, opts, joins, path, final_transformer)
def trim_joins(self, targets, joins, path):
"""
@ -1683,7 +1724,7 @@ class Query:
join_info.path,
)
for target in targets:
cols.append(target.get_col(final_alias))
cols.append(join_info.transform_function(target, final_alias))
if cols:
self.set_select(cols)
except MultiJoin: