mirror of
https://github.com/django/django.git
synced 2025-11-18 02:56:45 +00:00
Merge d113964672 into 1ce6e78dd4
This commit is contained in:
commit
ebd4335602
6 changed files with 99 additions and 16 deletions
|
|
@ -690,7 +690,7 @@ class BaseDatabaseOperations:
|
|||
"""
|
||||
return True
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||
"""
|
||||
Combine a list of subexpressions into a single expression, using
|
||||
the provided connecting operator. This is required because operators
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||
if connector == "^":
|
||||
return "POW(%s)" % ",".join(sub_expressions)
|
||||
# Convert the result to a signed integer since MySQL's binary operators
|
||||
|
|
@ -277,7 +277,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
elif connector == ">>":
|
||||
lhs, rhs = sub_expressions
|
||||
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
return super().combine_expression(
|
||||
connector, sub_expressions, output_field=output_field
|
||||
)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
|
|
|
|||
|
|
@ -403,3 +403,15 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||
|
||||
return lhs_expr, rhs_expr
|
||||
|
||||
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||
if (
|
||||
connector == "/"
|
||||
and output_field
|
||||
and output_field.get_internal_type() in ("FloatField", "DecimalField")
|
||||
):
|
||||
lhs, rhs = sub_expressions
|
||||
return f"CAST({lhs} AS NUMERIC) / {rhs}"
|
||||
return super().combine_expression(
|
||||
connector, sub_expressions, output_field=output_field
|
||||
)
|
||||
|
|
|
|||
|
|
@ -351,14 +351,32 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
return bool(value) if value in (1, 0) else value
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
||||
# function that's registered in connect().
|
||||
if connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
return "BITXOR(%s)" % ",".join(sub_expressions)
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
elif connector == "/":
|
||||
lhs, rhs = sub_expressions
|
||||
# SQLite performs floating-point division. To ensure results match the
|
||||
# expected output_field type:
|
||||
# - For FloatField/DecimalField, ensure REAL division.
|
||||
# - For other types (e.g. IntegerField), perform REAL division,
|
||||
# then ROUND and CAST to INTEGER to mimic integer division behavior.
|
||||
if output_field and output_field.get_internal_type() in (
|
||||
"FloatField",
|
||||
"DecimalField",
|
||||
):
|
||||
return f"CAST({lhs} AS REAL) / CAST({rhs} AS REAL)"
|
||||
else:
|
||||
return (
|
||||
f"CAST(ROUND(CAST({lhs} AS REAL) / CAST({rhs} AS REAL)) AS INTEGER)"
|
||||
)
|
||||
return super().combine_expression(
|
||||
connector, sub_expressions, output_field=output_field
|
||||
)
|
||||
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
if connector not in ["+", "-", "*", "/"]:
|
||||
|
|
|
|||
|
|
@ -758,18 +758,29 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
|||
return combined_type()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# We need output_field for specific combined operations
|
||||
# AND we don't want to block the run in case of None
|
||||
try:
|
||||
output_field = self.output_field
|
||||
except FieldError:
|
||||
output_field = None
|
||||
sql, params = self._compile_expressions(compiler)
|
||||
sql = self._handle_operator(sql, connection, output_field=output_field)
|
||||
return f"({sql})", params
|
||||
|
||||
def _compile_expressions(self, compiler):
|
||||
expressions = []
|
||||
expression_params = []
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
sql, params = compiler.compile(self.rhs)
|
||||
expressions.append(sql)
|
||||
expression_params.extend(params)
|
||||
# order of precedence
|
||||
expression_wrapper = "(%s)"
|
||||
sql = connection.ops.combine_expression(self.connector, expressions)
|
||||
return expression_wrapper % sql, expression_params
|
||||
params = []
|
||||
for expr in [self.lhs, self.rhs]:
|
||||
sql, param = compiler.compile(expr)
|
||||
expressions.append(sql)
|
||||
params.extend(param)
|
||||
return expressions, params
|
||||
|
||||
def _handle_operator(self, expressions, connection, output_field=None):
|
||||
return connection.ops.combine_expression(
|
||||
self.connector, expressions, output_field=output_field
|
||||
)
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
|
|
|
|||
|
|
@ -1653,6 +1653,46 @@ class ExpressionsNumericTests(TestCase):
|
|||
with self.assertNumQueries(expected_num_queries):
|
||||
self.assertEqual(n.decimal_value, Decimal("0.1"))
|
||||
|
||||
def test_decimal_division_precision(self):
|
||||
"""Test that division with Decimal preserves precision"""
|
||||
obj = Number.objects.create(integer=2)
|
||||
qs = Number.objects.annotate(
|
||||
ratio=ExpressionWrapper(
|
||||
F("integer") / Value(3.0),
|
||||
output_field=DecimalField(max_digits=10, decimal_places=4),
|
||||
)
|
||||
).filter(pk=obj.pk)
|
||||
self.assertAlmostEqual(
|
||||
float(qs.get().ratio),
|
||||
float(Decimal("2") / Decimal("3")),
|
||||
places=4,
|
||||
msg="Division should preserve decimal precision",
|
||||
)
|
||||
|
||||
def test_decimal_division_types(self):
|
||||
"""Test that division with Decimal preserves numeric type"""
|
||||
Number.objects.all().delete()
|
||||
for num, den, expected in [
|
||||
(2, Decimal("3"), "0.6667"),
|
||||
(5, Decimal("2"), "2.5000"),
|
||||
(1, Decimal("3"), "0.3333"),
|
||||
]:
|
||||
with self.subTest(num=num, den=den):
|
||||
number = Number.objects.create(integer=num)
|
||||
qs = Number.objects.filter(pk=number.pk).annotate(
|
||||
ratio=ExpressionWrapper(
|
||||
F("integer") / Value(den, output_field=DecimalField()),
|
||||
output_field=DecimalField(max_digits=10, decimal_places=4),
|
||||
)
|
||||
)
|
||||
result = qs.get()
|
||||
self.assertAlmostEqual(
|
||||
float(result.ratio),
|
||||
float(expected),
|
||||
places=4,
|
||||
msg=f"Divide {num} by {den} should give result: {expected}",
|
||||
)
|
||||
|
||||
|
||||
class ExpressionOperatorTests(TestCase):
|
||||
@classmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue