mirror of
https://github.com/django/django.git
synced 2025-08-26 05:24:27 +00:00
Fixed #30446 -- Resolved Value.output_field for stdlib types.
This required implementing a limited form of dynamic dispatch to combine expressions with numerical output. Refs #26355 should eventually provide a better interface for that.
This commit is contained in:
parent
d08e6f55e3
commit
1e38f1191d
10 changed files with 122 additions and 39 deletions
|
@ -1,7 +1,9 @@
|
|||
import copy
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FieldError
|
||||
from django.db import NotSupportedError, connection
|
||||
|
@ -56,12 +58,7 @@ class Combinable:
|
|||
def _combine(self, other, connector, reversed):
|
||||
if not hasattr(other, 'resolve_expression'):
|
||||
# everything must be resolvable to an expression
|
||||
output_field = (
|
||||
fields.DurationField()
|
||||
if isinstance(other, datetime.timedelta) else
|
||||
None
|
||||
)
|
||||
other = Value(other, output_field=output_field)
|
||||
other = Value(other)
|
||||
|
||||
if reversed:
|
||||
return CombinedExpression(other, connector, self)
|
||||
|
@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
|
|||
pass
|
||||
|
||||
|
||||
_connector_combinators = {
|
||||
connector: [
|
||||
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
|
||||
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
|
||||
(fields.IntegerField, fields.FloatField, fields.FloatField),
|
||||
(fields.FloatField, fields.IntegerField, fields.FloatField),
|
||||
]
|
||||
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _resolve_combined_type(connector, lhs_type, rhs_type):
|
||||
combinators = _connector_combinators.get(connector, ())
|
||||
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
|
||||
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
|
||||
return combined_type
|
||||
|
||||
|
||||
class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
|
@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
|||
def set_source_expressions(self, exprs):
|
||||
self.lhs, self.rhs = exprs
|
||||
|
||||
def _resolve_output_field(self):
|
||||
try:
|
||||
return super()._resolve_output_field()
|
||||
except FieldError:
|
||||
combined_type = _resolve_combined_type(
|
||||
self.connector,
|
||||
type(self.lhs.output_field),
|
||||
type(self.rhs.output_field),
|
||||
)
|
||||
if combined_type is None:
|
||||
raise
|
||||
return combined_type()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
expressions = []
|
||||
expression_params = []
|
||||
|
@ -721,6 +750,30 @@ class Value(Expression):
|
|||
def get_group_by_cols(self, alias=None):
|
||||
return []
|
||||
|
||||
def _resolve_output_field(self):
|
||||
if isinstance(self.value, str):
|
||||
return fields.CharField()
|
||||
if isinstance(self.value, bool):
|
||||
return fields.BooleanField()
|
||||
if isinstance(self.value, int):
|
||||
return fields.IntegerField()
|
||||
if isinstance(self.value, float):
|
||||
return fields.FloatField()
|
||||
if isinstance(self.value, datetime.datetime):
|
||||
return fields.DateTimeField()
|
||||
if isinstance(self.value, datetime.date):
|
||||
return fields.DateField()
|
||||
if isinstance(self.value, datetime.time):
|
||||
return fields.TimeField()
|
||||
if isinstance(self.value, datetime.timedelta):
|
||||
return fields.DurationField()
|
||||
if isinstance(self.value, Decimal):
|
||||
return fields.DecimalField()
|
||||
if isinstance(self.value, bytes):
|
||||
return fields.BinaryField()
|
||||
if isinstance(self.value, UUID):
|
||||
return fields.UUIDField()
|
||||
|
||||
|
||||
class RawSQL(Expression):
|
||||
def __init__(self, sql, params, output_field=None):
|
||||
|
@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
|
|||
copy.expression = Case(
|
||||
When(self.expression, then=True),
|
||||
default=False,
|
||||
output_field=fields.BooleanField(),
|
||||
)
|
||||
return copy.as_sql(compiler, connection)
|
||||
return self.as_sql(compiler, connection)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue