This commit is contained in:
Mariusz Felisiak 2025-11-17 18:00:46 +00:00 committed by GitHub
commit 62d7b5c3d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 125 additions and 4 deletions

View file

@ -100,6 +100,19 @@ class DecimalSerializer(BaseSerializer):
return repr(self.value), {"from decimal import Decimal"}
class DecimalContextSerializer(BaseSerializer):
def serialize(self):
decimal_imports = ["Context"]
for trap, is_used in self.value.traps.items():
# Decimal exceptions defined in the decimal module.
if is_used and trap.__module__ == "decimal":
decimal_imports.append(trap.__name__)
decimal_imports = ", ".join(sorted(decimal_imports))
return repr(self.value), {
f"from decimal import {self.value.rounding}, {decimal_imports}"
}
class DeconstructibleSerializer(BaseSerializer):
@staticmethod
def serialize_deconstructed(path, args, kwargs):
@ -372,6 +385,7 @@ class Serializer:
SettingsReference: SettingsReferenceSerializer,
float: FloatSerializer,
(bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
decimal.Context: DecimalContextSerializer,
decimal.Decimal: DecimalSerializer,
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
FUNCTION_TYPES: FunctionTypeSerializer,

View file

@ -1706,11 +1706,19 @@ class DecimalField(Field):
name=None,
max_digits=None,
decimal_places=None,
context=None,
**kwargs,
):
self.max_digits, self.decimal_places = max_digits, decimal_places
if context is not None and not isinstance(context, decimal.Context):
raise ValueError("DecimalField.context must be a decimal.Context instance.")
self._context = context
super().__init__(verbose_name, name, **kwargs)
@property
def non_db_attrs(self):
return (*super().non_db_attrs, "context")
def check(self, **kwargs):
errors = super().check(**kwargs)
@ -1822,6 +1830,8 @@ class DecimalField(Field):
@cached_property
def context(self):
if self._context is not None:
return self._context
return decimal.Context(prec=self.max_digits)
def deconstruct(self):
@ -1830,6 +1840,8 @@ class DecimalField(Field):
kwargs["max_digits"] = self.max_digits
if self.decimal_places is not None:
kwargs["decimal_places"] = self.decimal_places
if self._context is not None:
kwargs["context"] = self.context
return name, path, args, kwargs
def get_internal_type(self):

View file

@ -856,7 +856,7 @@ The default form widget for this field is a single
``DecimalField``
----------------
.. class:: DecimalField(max_digits=None, decimal_places=None, **options)
.. class:: DecimalField(max_digits=None, decimal_places=None, context=None, **options)
A fixed-precision decimal number, represented in Python by a
:class:`~decimal.Decimal` instance. It validates the input using
@ -879,6 +879,15 @@ Has the following arguments:
precision. It's also required for all database backends when
:attr:`~DecimalField.max_digits` is provided.
.. attribute:: DecimalField.context
.. versionadded:: 6.1
Optional. A :class:`decimal.Context` used to create ``Decimal`` instances
from floats. If :attr:`~DecimalField.context` is not specified, Django will
use a default context based on :attr:`~DecimalField.max_digits`, i.e.
``decimal.Context(prec=max_digits)``.
For example, to store numbers up to ``999.99`` with a resolution of 2 decimal
places, you'd use::
@ -906,6 +915,10 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
Support for ``DecimalField`` with no precision was added on Oracle,
PostgreSQL, and SQLite.
.. versionchanged:: 6.1
The ``context`` argument was added.
``DurationField``
-----------------

View file

@ -252,6 +252,10 @@ Models
<django.db.models.DecimalField.decimal_places>` are no longer required to be
set on Oracle, PostgreSQL, and SQLite.
* The new :attr:`DecimalField.context <django.db.models.DecimalField.context>`
attribute allows customizing the creation of ``Decimal`` instances from
floats.
Pagination
~~~~~~~~~~

View file

@ -631,6 +631,23 @@ class WriterTests(SimpleTestCase):
string = MigrationWriter.serialize(field)[0]
self.assertEqual(string, "models.FilePathField(path=%r)" % path_like.path)
def test_serialize_decimal_context(self):
decimal_context = decimal.Context(
prec=7,
rounding=decimal.ROUND_FLOOR,
traps=[decimal.InvalidOperation, decimal.Overflow],
)
self.assertSerializedResultEqual(
decimal_context,
(
repr(decimal_context),
{
"from decimal import ROUND_FLOOR, Context, InvalidOperation, "
"Overflow"
},
),
)
def test_serialize_zoneinfo(self):
self.assertSerializedEqual(zoneinfo.ZoneInfo("Asia/Kolkata"))
self.assertSerializedResultEqual(

View file

@ -1,11 +1,11 @@
import math
from decimal import Decimal
from decimal import ROUND_DOWN, Context, Decimal
from unittest import mock
from django.core import validators
from django.core.exceptions import ValidationError
from django.db import connection, models
from django.test import TestCase
from django.test import SimpleTestCase, TestCase
from .models import BigD, Foo
@ -22,6 +22,26 @@ class DecimalFieldTests(TestCase):
self.assertEqual(f.to_python(2.0625), Decimal("2.062"))
self.assertEqual(f.to_python(2.1875), Decimal("2.188"))
def test_to_python_custom_context(self):
f = models.DecimalField(
max_digits=4,
decimal_places=2,
context=Context(prec=5, rounding=ROUND_DOWN),
)
self.assertEqual(f.to_python(3), Decimal("3"))
self.assertEqual(f.to_python("3.14"), Decimal("3.14"))
# to_python() converts floats and honors the custom context.
self.assertEqual(f.to_python(3.1415926535897), Decimal("3.1415"))
self.assertEqual(f.to_python(2.41), Decimal("2.4100"))
# Uses custom rounding of ROUND_DOWN.
self.assertEqual(f.to_python(2.06245), Decimal("2.0624"))
self.assertEqual(f.to_python(2.18775), Decimal("2.1877"))
def test_invalid_context(self):
msg = "DecimalField.context must be a decimal.Context instance."
with self.assertRaisesMessage(ValueError, msg):
models.DecimalField(context=ROUND_DOWN)
def test_invalid_value(self):
field = models.DecimalField(max_digits=4, decimal_places=2)
msg = "%s” value must be a decimal number."
@ -140,3 +160,20 @@ class DecimalFieldTests(TestCase):
obj = Foo.objects.create(a="bar", d=Decimal("8.320"))
obj.refresh_from_db()
self.assertEqual(obj.d.compare_total(Decimal("8.320")), Decimal("0"))
class TestMethods(SimpleTestCase):
def test_deconstruct(self):
field = models.DecimalField()
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {})
custom_context = Context(prec=8, rounding=ROUND_DOWN)
field = models.DecimalField(
decimal_places=4,
max_digits=8,
context=custom_context,
)
*_, kwargs = field.deconstruct()
self.assertEqual(
kwargs, {"decimal_places": 4, "max_digits": 8, "context": custom_context}
)

View file

@ -2,7 +2,7 @@ import datetime
import itertools
import unittest
from copy import copy
from decimal import Decimal
from decimal import ROUND_HALF_UP, Context, Decimal
from unittest import mock
from django.core.exceptions import FieldError
@ -4904,6 +4904,30 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor, self.assertNumQueries(0):
editor.alter_field(Author, new_field, old_field, strict=True)
@isolate_apps("schema")
def test_alter_decimalfield_context_noop(self):
class ModelWithDecimalField(Model):
field = DecimalField(max_digits=5, decimal_places=2)
class Meta:
app_label = "schema"
with connection.schema_editor() as editor:
editor.create_model(ModelWithDecimalField)
self.isolated_local_models = [ModelWithDecimalField]
old_field = ModelWithDecimalField._meta.get_field("field")
new_field = DecimalField(
max_digits=5,
decimal_places=2,
context=Context(prec=13, rounding=ROUND_HALF_UP),
)
new_field.set_attributes_from_name("field")
with connection.schema_editor() as editor, self.assertNumQueries(0):
editor.alter_field(Author, old_field, new_field, strict=True)
with connection.schema_editor() as editor, self.assertNumQueries(0):
editor.alter_field(Author, new_field, old_field, strict=True)
def test_add_textfield_unhashable_default(self):
# Create the table
with connection.schema_editor() as editor: