Fixed #30446 -- Resolved Value.output_field for stdlib types.

This required implementing a limited form of dynamic dispatch to combine
expressions with numerical output. Refs #26355 should eventually provide
a better interface for that.
This commit is contained in:
Simon Charette 2019-05-12 17:17:47 -04:00 committed by Mariusz Felisiak
parent d08e6f55e3
commit 1e38f1191d
10 changed files with 122 additions and 39 deletions

View file

@ -1,7 +1,9 @@
import copy
import datetime
import functools
import inspect
from decimal import Decimal
from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection
@ -56,12 +58,7 @@ class Combinable:
def _combine(self, other, connector, reversed):
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
output_field = (
fields.DurationField()
if isinstance(other, datetime.timedelta) else
None
)
other = Value(other, output_field=output_field)
other = Value(other)
if reversed:
return CombinedExpression(other, connector, self)
@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
pass
_connector_combinators = {
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
}
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
combinators = _connector_combinators.get(connector, ())
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
return combined_type
class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(self, lhs, connector, rhs, output_field=None):
@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs
def _resolve_output_field(self):
try:
return super()._resolve_output_field()
except FieldError:
combined_type = _resolve_combined_type(
self.connector,
type(self.lhs.output_field),
type(self.rhs.output_field),
)
if combined_type is None:
raise
return combined_type()
def as_sql(self, compiler, connection):
expressions = []
expression_params = []
@ -721,6 +750,30 @@ class Value(Expression):
def get_group_by_cols(self, alias=None):
return []
def _resolve_output_field(self):
if isinstance(self.value, str):
return fields.CharField()
if isinstance(self.value, bool):
return fields.BooleanField()
if isinstance(self.value, int):
return fields.IntegerField()
if isinstance(self.value, float):
return fields.FloatField()
if isinstance(self.value, datetime.datetime):
return fields.DateTimeField()
if isinstance(self.value, datetime.date):
return fields.DateField()
if isinstance(self.value, datetime.time):
return fields.TimeField()
if isinstance(self.value, datetime.timedelta):
return fields.DurationField()
if isinstance(self.value, Decimal):
return fields.DecimalField()
if isinstance(self.value, bytes):
return fields.BinaryField()
if isinstance(self.value, UUID):
return fields.UUIDField()
class RawSQL(Expression):
def __init__(self, sql, params, output_field=None):
@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
copy.expression = Case(
When(self.expression, then=True),
default=False,
output_field=fields.BooleanField(),
)
return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection)