mirror of
https://github.com/django/django.git
synced 2025-11-18 02:56:45 +00:00
Fixed #36030 - Divide an integer field by a constant decimal.Decimal returns inconsistent decimal (sqlite3 & postgre only)
This commit is contained in:
parent
098c8bc99c
commit
d113964672
6 changed files with 99 additions and 16 deletions
|
|
@ -658,7 +658,7 @@ class BaseDatabaseOperations:
|
||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def combine_expression(self, connector, sub_expressions):
|
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||||
"""
|
"""
|
||||||
Combine a list of subexpressions into a single expression, using
|
Combine a list of subexpressions into a single expression, using
|
||||||
the provided connecting operator. This is required because operators
|
the provided connecting operator. This is required because operators
|
||||||
|
|
|
||||||
|
|
@ -287,7 +287,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
def pk_default_value(self):
|
def pk_default_value(self):
|
||||||
return "NULL"
|
return "NULL"
|
||||||
|
|
||||||
def combine_expression(self, connector, sub_expressions):
|
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||||
if connector == "^":
|
if connector == "^":
|
||||||
return "POW(%s)" % ",".join(sub_expressions)
|
return "POW(%s)" % ",".join(sub_expressions)
|
||||||
# Convert the result to a signed integer since MySQL's binary operators
|
# Convert the result to a signed integer since MySQL's binary operators
|
||||||
|
|
@ -298,7 +298,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
elif connector == ">>":
|
elif connector == ">>":
|
||||||
lhs, rhs = sub_expressions
|
lhs, rhs = sub_expressions
|
||||||
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||||
return super().combine_expression(connector, sub_expressions)
|
return super().combine_expression(
|
||||||
|
connector, sub_expressions, output_field=output_field
|
||||||
|
)
|
||||||
|
|
||||||
def get_db_converters(self, expression):
|
def get_db_converters(self, expression):
|
||||||
converters = super().get_db_converters(expression)
|
converters = super().get_db_converters(expression)
|
||||||
|
|
|
||||||
|
|
@ -420,3 +420,15 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
rhs_expr = Cast(rhs_expr, lhs_field)
|
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||||
|
|
||||||
return lhs_expr, rhs_expr
|
return lhs_expr, rhs_expr
|
||||||
|
|
||||||
|
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||||
|
if (
|
||||||
|
connector == "/"
|
||||||
|
and output_field
|
||||||
|
and output_field.get_internal_type() in ("FloatField", "DecimalField")
|
||||||
|
):
|
||||||
|
lhs, rhs = sub_expressions
|
||||||
|
return f"CAST({lhs} AS NUMERIC) / {rhs}"
|
||||||
|
return super().combine_expression(
|
||||||
|
connector, sub_expressions, output_field=output_field
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -363,14 +363,32 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
def convert_booleanfield_value(self, value, expression, connection):
|
def convert_booleanfield_value(self, value, expression, connection):
|
||||||
return bool(value) if value in (1, 0) else value
|
return bool(value) if value in (1, 0) else value
|
||||||
|
|
||||||
def combine_expression(self, connector, sub_expressions):
|
def combine_expression(self, connector, sub_expressions, output_field=None):
|
||||||
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
||||||
# function that's registered in connect().
|
# function that's registered in connect().
|
||||||
if connector == "^":
|
if connector == "^":
|
||||||
return "POWER(%s)" % ",".join(sub_expressions)
|
return "POWER(%s)" % ",".join(sub_expressions)
|
||||||
elif connector == "#":
|
elif connector == "#":
|
||||||
return "BITXOR(%s)" % ",".join(sub_expressions)
|
return "BITXOR(%s)" % ",".join(sub_expressions)
|
||||||
return super().combine_expression(connector, sub_expressions)
|
elif connector == "/":
|
||||||
|
lhs, rhs = sub_expressions
|
||||||
|
# SQLite performs floating-point division. To ensure results match the
|
||||||
|
# expected output_field type:
|
||||||
|
# - For FloatField/DecimalField, ensure REAL division.
|
||||||
|
# - For other types (e.g. IntegerField), perform REAL division,
|
||||||
|
# then ROUND and CAST to INTEGER to mimic integer division behavior.
|
||||||
|
if output_field and output_field.get_internal_type() in (
|
||||||
|
"FloatField",
|
||||||
|
"DecimalField",
|
||||||
|
):
|
||||||
|
return f"CAST({lhs} AS REAL) / CAST({rhs} AS REAL)"
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
f"CAST(ROUND(CAST({lhs} AS REAL) / CAST({rhs} AS REAL)) AS INTEGER)"
|
||||||
|
)
|
||||||
|
return super().combine_expression(
|
||||||
|
connector, sub_expressions, output_field=output_field
|
||||||
|
)
|
||||||
|
|
||||||
def combine_duration_expression(self, connector, sub_expressions):
|
def combine_duration_expression(self, connector, sub_expressions):
|
||||||
if connector not in ["+", "-", "*", "/"]:
|
if connector not in ["+", "-", "*", "/"]:
|
||||||
|
|
|
||||||
|
|
@ -756,18 +756,29 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||||
return combined_type()
|
return combined_type()
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
# We need output_field for specific combined operations
|
||||||
|
# AND we don't want to block the run in case of None
|
||||||
|
try:
|
||||||
|
output_field = self.output_field
|
||||||
|
except FieldError:
|
||||||
|
output_field = None
|
||||||
|
sql, params = self._compile_expressions(compiler)
|
||||||
|
sql = self._handle_operator(sql, connection, output_field=output_field)
|
||||||
|
return f"({sql})", params
|
||||||
|
|
||||||
|
def _compile_expressions(self, compiler):
|
||||||
expressions = []
|
expressions = []
|
||||||
expression_params = []
|
params = []
|
||||||
sql, params = compiler.compile(self.lhs)
|
for expr in [self.lhs, self.rhs]:
|
||||||
expressions.append(sql)
|
sql, param = compiler.compile(expr)
|
||||||
expression_params.extend(params)
|
expressions.append(sql)
|
||||||
sql, params = compiler.compile(self.rhs)
|
params.extend(param)
|
||||||
expressions.append(sql)
|
return expressions, params
|
||||||
expression_params.extend(params)
|
|
||||||
# order of precedence
|
def _handle_operator(self, expressions, connection, output_field=None):
|
||||||
expression_wrapper = "(%s)"
|
return connection.ops.combine_expression(
|
||||||
sql = connection.ops.combine_expression(self.connector, expressions)
|
self.connector, expressions, output_field=output_field
|
||||||
return expression_wrapper % sql, expression_params
|
)
|
||||||
|
|
||||||
def resolve_expression(
|
def resolve_expression(
|
||||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||||
|
|
|
||||||
|
|
@ -1599,6 +1599,46 @@ class ExpressionsNumericTests(TestCase):
|
||||||
n.refresh_from_db()
|
n.refresh_from_db()
|
||||||
self.assertEqual(n.decimal_value, Decimal("0.1"))
|
self.assertEqual(n.decimal_value, Decimal("0.1"))
|
||||||
|
|
||||||
|
def test_decimal_division_precision(self):
|
||||||
|
"""Test that division with Decimal preserves precision"""
|
||||||
|
obj = Number.objects.create(integer=2)
|
||||||
|
qs = Number.objects.annotate(
|
||||||
|
ratio=ExpressionWrapper(
|
||||||
|
F("integer") / Value(3.0),
|
||||||
|
output_field=DecimalField(max_digits=10, decimal_places=4),
|
||||||
|
)
|
||||||
|
).filter(pk=obj.pk)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
float(qs.get().ratio),
|
||||||
|
float(Decimal("2") / Decimal("3")),
|
||||||
|
places=4,
|
||||||
|
msg="Division should preserve decimal precision",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decimal_division_types(self):
|
||||||
|
"""Test that division with Decimal preserves numeric type"""
|
||||||
|
Number.objects.all().delete()
|
||||||
|
for num, den, expected in [
|
||||||
|
(2, Decimal("3"), "0.6667"),
|
||||||
|
(5, Decimal("2"), "2.5000"),
|
||||||
|
(1, Decimal("3"), "0.3333"),
|
||||||
|
]:
|
||||||
|
with self.subTest(num=num, den=den):
|
||||||
|
number = Number.objects.create(integer=num)
|
||||||
|
qs = Number.objects.filter(pk=number.pk).annotate(
|
||||||
|
ratio=ExpressionWrapper(
|
||||||
|
F("integer") / Value(den, output_field=DecimalField()),
|
||||||
|
output_field=DecimalField(max_digits=10, decimal_places=4),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = qs.get()
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
float(result.ratio),
|
||||||
|
float(expected),
|
||||||
|
places=4,
|
||||||
|
msg=f"Divide {num} by {den} should give result: {expected}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExpressionOperatorTests(TestCase):
|
class ExpressionOperatorTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue