Fixed #36030 - Divide an integer field by a constant decimal.Decimal returns inconsistent decimal (sqlite3 & postgre only)

This commit is contained in:
greg 2025-01-16 09:26:57 +01:00
parent 098c8bc99c
commit d113964672
6 changed files with 99 additions and 16 deletions

View file

@ -658,7 +658,7 @@ class BaseDatabaseOperations:
""" """
return True return True
def combine_expression(self, connector, sub_expressions): def combine_expression(self, connector, sub_expressions, output_field=None):
""" """
Combine a list of subexpressions into a single expression, using Combine a list of subexpressions into a single expression, using
the provided connecting operator. This is required because operators the provided connecting operator. This is required because operators

View file

@ -287,7 +287,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def pk_default_value(self): def pk_default_value(self):
return "NULL" return "NULL"
def combine_expression(self, connector, sub_expressions): def combine_expression(self, connector, sub_expressions, output_field=None):
if connector == "^": if connector == "^":
return "POW(%s)" % ",".join(sub_expressions) return "POW(%s)" % ",".join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators # Convert the result to a signed integer since MySQL's binary operators
@ -298,7 +298,9 @@ class DatabaseOperations(BaseDatabaseOperations):
elif connector == ">>": elif connector == ">>":
lhs, rhs = sub_expressions lhs, rhs = sub_expressions
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return super().combine_expression(connector, sub_expressions) return super().combine_expression(
connector, sub_expressions, output_field=output_field
)
def get_db_converters(self, expression): def get_db_converters(self, expression):
converters = super().get_db_converters(expression) converters = super().get_db_converters(expression)

View file

@ -420,3 +420,15 @@ class DatabaseOperations(BaseDatabaseOperations):
rhs_expr = Cast(rhs_expr, lhs_field) rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr return lhs_expr, rhs_expr
def combine_expression(self, connector, sub_expressions, output_field=None):
if (
connector == "/"
and output_field
and output_field.get_internal_type() in ("FloatField", "DecimalField")
):
lhs, rhs = sub_expressions
return f"CAST({lhs} AS NUMERIC) / {rhs}"
return super().combine_expression(
connector, sub_expressions, output_field=output_field
)

View file

@ -363,14 +363,32 @@ class DatabaseOperations(BaseDatabaseOperations):
def convert_booleanfield_value(self, value, expression, connection): def convert_booleanfield_value(self, value, expression, connection):
return bool(value) if value in (1, 0) else value return bool(value) if value in (1, 0) else value
def combine_expression(self, connector, sub_expressions): def combine_expression(self, connector, sub_expressions, output_field=None):
# SQLite doesn't have a ^ operator, so use the user-defined POWER # SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect(). # function that's registered in connect().
if connector == "^": if connector == "^":
return "POWER(%s)" % ",".join(sub_expressions) return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#": elif connector == "#":
return "BITXOR(%s)" % ",".join(sub_expressions) return "BITXOR(%s)" % ",".join(sub_expressions)
return super().combine_expression(connector, sub_expressions) elif connector == "/":
lhs, rhs = sub_expressions
# SQLite performs floating-point division. To ensure results match the
# expected output_field type:
# - For FloatField/DecimalField, ensure REAL division.
# - For other types (e.g. IntegerField), perform REAL division,
# then ROUND and CAST to INTEGER to mimic integer division behavior.
if output_field and output_field.get_internal_type() in (
"FloatField",
"DecimalField",
):
return f"CAST({lhs} AS REAL) / CAST({rhs} AS REAL)"
else:
return (
f"CAST(ROUND(CAST({lhs} AS REAL) / CAST({rhs} AS REAL)) AS INTEGER)"
)
return super().combine_expression(
connector, sub_expressions, output_field=output_field
)
def combine_duration_expression(self, connector, sub_expressions): def combine_duration_expression(self, connector, sub_expressions):
if connector not in ["+", "-", "*", "/"]: if connector not in ["+", "-", "*", "/"]:

View file

@ -756,18 +756,29 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
return combined_type() return combined_type()
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# We need output_field for specific combined operations
# AND we don't want to block the run in case of None
try:
output_field = self.output_field
except FieldError:
output_field = None
sql, params = self._compile_expressions(compiler)
sql = self._handle_operator(sql, connection, output_field=output_field)
return f"({sql})", params
def _compile_expressions(self, compiler):
expressions = [] expressions = []
expression_params = [] params = []
sql, params = compiler.compile(self.lhs) for expr in [self.lhs, self.rhs]:
expressions.append(sql) sql, param = compiler.compile(expr)
expression_params.extend(params) expressions.append(sql)
sql, params = compiler.compile(self.rhs) params.extend(param)
expressions.append(sql) return expressions, params
expression_params.extend(params)
# order of precedence def _handle_operator(self, expressions, connection, output_field=None):
expression_wrapper = "(%s)" return connection.ops.combine_expression(
sql = connection.ops.combine_expression(self.connector, expressions) self.connector, expressions, output_field=output_field
return expression_wrapper % sql, expression_params )
def resolve_expression( def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False

View file

@ -1599,6 +1599,46 @@ class ExpressionsNumericTests(TestCase):
n.refresh_from_db() n.refresh_from_db()
self.assertEqual(n.decimal_value, Decimal("0.1")) self.assertEqual(n.decimal_value, Decimal("0.1"))
def test_decimal_division_precision(self):
"""Test that division with Decimal preserves precision"""
obj = Number.objects.create(integer=2)
qs = Number.objects.annotate(
ratio=ExpressionWrapper(
F("integer") / Value(3.0),
output_field=DecimalField(max_digits=10, decimal_places=4),
)
).filter(pk=obj.pk)
self.assertAlmostEqual(
float(qs.get().ratio),
float(Decimal("2") / Decimal("3")),
places=4,
msg="Division should preserve decimal precision",
)
def test_decimal_division_types(self):
"""Test that division with Decimal preserves numeric type"""
Number.objects.all().delete()
for num, den, expected in [
(2, Decimal("3"), "0.6667"),
(5, Decimal("2"), "2.5000"),
(1, Decimal("3"), "0.3333"),
]:
with self.subTest(num=num, den=den):
number = Number.objects.create(integer=num)
qs = Number.objects.filter(pk=number.pk).annotate(
ratio=ExpressionWrapper(
F("integer") / Value(den, output_field=DecimalField()),
output_field=DecimalField(max_digits=10, decimal_places=4),
)
)
result = qs.get()
self.assertAlmostEqual(
float(result.ratio),
float(expected),
places=4,
msg=f"Divide {num} by {den} should give result: {expected}",
)
class ExpressionOperatorTests(TestCase): class ExpressionOperatorTests(TestCase):
@classmethod @classmethod