This commit is contained in:
Gregory Mariani 2025-11-17 08:42:00 -05:00 committed by GitHub
commit ebd4335602
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 99 additions and 16 deletions

View file

@ -690,7 +690,7 @@ class BaseDatabaseOperations:
"""
return True
def combine_expression(self, connector, sub_expressions):
def combine_expression(self, connector, sub_expressions, output_field=None):
"""
Combine a list of subexpressions into a single expression, using
the provided connecting operator. This is required because operators

View file

@ -266,7 +266,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def pk_default_value(self):
return "NULL"
def combine_expression(self, connector, sub_expressions):
def combine_expression(self, connector, sub_expressions, output_field=None):
if connector == "^":
return "POW(%s)" % ",".join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
@ -277,7 +277,9 @@ class DatabaseOperations(BaseDatabaseOperations):
elif connector == ">>":
lhs, rhs = sub_expressions
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return super().combine_expression(connector, sub_expressions)
return super().combine_expression(
connector, sub_expressions, output_field=output_field
)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)

View file

@ -403,3 +403,15 @@ class DatabaseOperations(BaseDatabaseOperations):
rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr
def combine_expression(self, connector, sub_expressions, output_field=None):
if (
connector == "/"
and output_field
and output_field.get_internal_type() in ("FloatField", "DecimalField")
):
lhs, rhs = sub_expressions
return f"CAST({lhs} AS NUMERIC) / {rhs}"
return super().combine_expression(
connector, sub_expressions, output_field=output_field
)

View file

@ -351,14 +351,32 @@ class DatabaseOperations(BaseDatabaseOperations):
def convert_booleanfield_value(self, value, expression, connection):
return bool(value) if value in (1, 0) else value
def combine_expression(self, connector, sub_expressions):
def combine_expression(self, connector, sub_expressions, output_field=None):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
if connector == "^":
return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#":
return "BITXOR(%s)" % ",".join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
elif connector == "/":
lhs, rhs = sub_expressions
# SQLite performs floating-point division. To ensure results match the
# expected output_field type:
# - For FloatField/DecimalField, ensure REAL division.
# - For other types (e.g. IntegerField), perform REAL division,
# then ROUND and CAST to INTEGER to mimic integer division behavior.
if output_field and output_field.get_internal_type() in (
"FloatField",
"DecimalField",
):
return f"CAST({lhs} AS REAL) / CAST({rhs} AS REAL)"
else:
return (
f"CAST(ROUND(CAST({lhs} AS REAL) / CAST({rhs} AS REAL)) AS INTEGER)"
)
return super().combine_expression(
connector, sub_expressions, output_field=output_field
)
def combine_duration_expression(self, connector, sub_expressions):
if connector not in ["+", "-", "*", "/"]:

View file

@ -758,18 +758,29 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
return combined_type()
def as_sql(self, compiler, connection):
# We need output_field for specific combined operations
# AND we don't want to block the run in case of None
try:
output_field = self.output_field
except FieldError:
output_field = None
sql, params = self._compile_expressions(compiler)
sql = self._handle_operator(sql, connection, output_field=output_field)
return f"({sql})", params
def _compile_expressions(self, compiler):
expressions = []
expression_params = []
sql, params = compiler.compile(self.lhs)
expressions.append(sql)
expression_params.extend(params)
sql, params = compiler.compile(self.rhs)
expressions.append(sql)
expression_params.extend(params)
# order of precedence
expression_wrapper = "(%s)"
sql = connection.ops.combine_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
params = []
for expr in [self.lhs, self.rhs]:
sql, param = compiler.compile(expr)
expressions.append(sql)
params.extend(param)
return expressions, params
def _handle_operator(self, expressions, connection, output_field=None):
return connection.ops.combine_expression(
self.connector, expressions, output_field=output_field
)
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False

View file

@ -1653,6 +1653,46 @@ class ExpressionsNumericTests(TestCase):
with self.assertNumQueries(expected_num_queries):
self.assertEqual(n.decimal_value, Decimal("0.1"))
def test_decimal_division_precision(self):
"""Test that division with Decimal preserves precision"""
obj = Number.objects.create(integer=2)
qs = Number.objects.annotate(
ratio=ExpressionWrapper(
F("integer") / Value(3.0),
output_field=DecimalField(max_digits=10, decimal_places=4),
)
).filter(pk=obj.pk)
self.assertAlmostEqual(
float(qs.get().ratio),
float(Decimal("2") / Decimal("3")),
places=4,
msg="Division should preserve decimal precision",
)
def test_decimal_division_types(self):
"""Test that division with Decimal preserves numeric type"""
Number.objects.all().delete()
for num, den, expected in [
(2, Decimal("3"), "0.6667"),
(5, Decimal("2"), "2.5000"),
(1, Decimal("3"), "0.3333"),
]:
with self.subTest(num=num, den=den):
number = Number.objects.create(integer=num)
qs = Number.objects.filter(pk=number.pk).annotate(
ratio=ExpressionWrapper(
F("integer") / Value(den, output_field=DecimalField()),
output_field=DecimalField(max_digits=10, decimal_places=4),
)
)
result = qs.get()
self.assertAlmostEqual(
float(result.ratio),
float(expected),
places=4,
msg=f"Divide {num} by {den} should give result: {expected}",
)
class ExpressionOperatorTests(TestCase):
@classmethod