From d1139646723a9756a8d31e2d8bf337a80437effc Mon Sep 17 00:00:00 2001 From: greg Date: Thu, 16 Jan 2025 09:26:57 +0100 Subject: [PATCH] Fixed #36030 - Divide an integer field by a constant decimal.Decimal returns inconsistent decimal (sqlite3 & postgre only) --- django/db/backends/base/operations.py | 2 +- django/db/backends/mysql/operations.py | 6 ++-- django/db/backends/postgresql/operations.py | 12 +++++++ django/db/backends/sqlite3/operations.py | 22 ++++++++++-- django/db/models/expressions.py | 33 +++++++++++------ tests/expressions/tests.py | 40 +++++++++++++++++++++ 6 files changed, 99 insertions(+), 16 deletions(-) diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index fea73bc1e4..cf3d55d2d8 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -658,7 +658,7 @@ class BaseDatabaseOperations: """ return True - def combine_expression(self, connector, sub_expressions): + def combine_expression(self, connector, sub_expressions, output_field=None): """ Combine a list of subexpressions into a single expression, using the provided connecting operator. This is required because operators diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 9806303539..e3bb2a749d 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -287,7 +287,7 @@ class DatabaseOperations(BaseDatabaseOperations): def pk_default_value(self): return "NULL" - def combine_expression(self, connector, sub_expressions): + def combine_expression(self, connector, sub_expressions, output_field=None): if connector == "^": return "POW(%s)" % ",".join(sub_expressions) # Convert the result to a signed integer since MySQL's binary operators @@ -298,7 +298,9 @@ class DatabaseOperations(BaseDatabaseOperations): elif connector == ">>": lhs, rhs = sub_expressions return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs} - return super().combine_expression(connector, sub_expressions) + return super().combine_expression( + connector, sub_expressions, output_field=output_field + ) def get_db_converters(self, expression): converters = super().get_db_converters(expression) diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 9db755bb89..e5c500a2b7 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -420,3 +420,15 @@ class DatabaseOperations(BaseDatabaseOperations): rhs_expr = Cast(rhs_expr, lhs_field) return lhs_expr, rhs_expr + + def combine_expression(self, connector, sub_expressions, output_field=None): + if ( + connector == "/" + and output_field + and output_field.get_internal_type() in ("FloatField", "DecimalField") + ): + lhs, rhs = sub_expressions + return f"CAST({lhs} AS NUMERIC) / {rhs}" + return super().combine_expression( + connector, sub_expressions, output_field=output_field + ) diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 0ab853f766..fb854e8dd2 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -363,14 +363,32 @@ class DatabaseOperations(BaseDatabaseOperations): def convert_booleanfield_value(self, value, expression, connection): return bool(value) if value in (1, 0) else value - def combine_expression(self, connector, sub_expressions): + def combine_expression(self, connector, sub_expressions, output_field=None): # SQLite doesn't have a ^ operator, so use the user-defined POWER # function that's registered in connect(). if connector == "^": return "POWER(%s)" % ",".join(sub_expressions) elif connector == "#": return "BITXOR(%s)" % ",".join(sub_expressions) - return super().combine_expression(connector, sub_expressions) + elif connector == "/": + lhs, rhs = sub_expressions + # SQLite performs floating-point division. To ensure results match the + # expected output_field type: + # - For FloatField/DecimalField, ensure REAL division. + # - For other types (e.g. IntegerField), perform REAL division, + # then ROUND and CAST to INTEGER to mimic integer division behavior. + if output_field and output_field.get_internal_type() in ( + "FloatField", + "DecimalField", + ): + return f"CAST({lhs} AS REAL) / CAST({rhs} AS REAL)" + else: + return ( + f"CAST(ROUND(CAST({lhs} AS REAL) / CAST({rhs} AS REAL)) AS INTEGER)" + ) + return super().combine_expression( + connector, sub_expressions, output_field=output_field + ) def combine_duration_expression(self, connector, sub_expressions): if connector not in ["+", "-", "*", "/"]: diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 46c3b63a91..cff42c6de8 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -756,18 +756,29 @@ class CombinedExpression(SQLiteNumericMixin, Expression): return combined_type() def as_sql(self, compiler, connection): + # We need output_field for specific combined operations + # AND we don't want to block the run in case of None + try: + output_field = self.output_field + except FieldError: + output_field = None + sql, params = self._compile_expressions(compiler) + sql = self._handle_operator(sql, connection, output_field=output_field) + return f"({sql})", params + + def _compile_expressions(self, compiler): expressions = [] - expression_params = [] - sql, params = compiler.compile(self.lhs) - expressions.append(sql) - expression_params.extend(params) - sql, params = compiler.compile(self.rhs) - expressions.append(sql) - expression_params.extend(params) - # order of precedence - expression_wrapper = "(%s)" - sql = connection.ops.combine_expression(self.connector, expressions) - return expression_wrapper % sql, expression_params + params = [] + for expr in [self.lhs, self.rhs]: + sql, param = compiler.compile(expr) + expressions.append(sql) + params.extend(param) + return expressions, params + + def _handle_operator(self, expressions, connection, output_field=None): + return connection.ops.combine_expression( + self.connector, expressions, output_field=output_field + ) def resolve_expression( self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 1fb4e2f34d..e9f7330598 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1599,6 +1599,46 @@ class ExpressionsNumericTests(TestCase): n.refresh_from_db() self.assertEqual(n.decimal_value, Decimal("0.1")) + def test_decimal_division_precision(self): + """Test that division with Decimal preserves precision""" + obj = Number.objects.create(integer=2) + qs = Number.objects.annotate( + ratio=ExpressionWrapper( + F("integer") / Value(3.0), + output_field=DecimalField(max_digits=10, decimal_places=4), + ) + ).filter(pk=obj.pk) + self.assertAlmostEqual( + float(qs.get().ratio), + float(Decimal("2") / Decimal("3")), + places=4, + msg="Division should preserve decimal precision", + ) + + def test_decimal_division_types(self): + """Test that division with Decimal preserves numeric type""" + Number.objects.all().delete() + for num, den, expected in [ + (2, Decimal("3"), "0.6667"), + (5, Decimal("2"), "2.5000"), + (1, Decimal("3"), "0.3333"), + ]: + with self.subTest(num=num, den=den): + number = Number.objects.create(integer=num) + qs = Number.objects.filter(pk=number.pk).annotate( + ratio=ExpressionWrapper( + F("integer") / Value(den, output_field=DecimalField()), + output_field=DecimalField(max_digits=10, decimal_places=4), + ) + ) + result = qs.get() + self.assertAlmostEqual( + float(result.ratio), + float(expected), + places=4, + msg=f"Divide {num} by {den} should give result: {expected}", + ) + class ExpressionOperatorTests(TestCase): @classmethod