mirror of
https://github.com/django/django.git
synced 2025-11-19 19:24:46 +00:00
Merge 997226c957 into 1ce6e78dd4
This commit is contained in:
commit
62d7b5c3d2
7 changed files with 125 additions and 4 deletions
|
|
@ -100,6 +100,19 @@ class DecimalSerializer(BaseSerializer):
|
|||
return repr(self.value), {"from decimal import Decimal"}
|
||||
|
||||
|
||||
class DecimalContextSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
decimal_imports = ["Context"]
|
||||
for trap, is_used in self.value.traps.items():
|
||||
# Decimal exceptions defined in the decimal module.
|
||||
if is_used and trap.__module__ == "decimal":
|
||||
decimal_imports.append(trap.__name__)
|
||||
decimal_imports = ", ".join(sorted(decimal_imports))
|
||||
return repr(self.value), {
|
||||
f"from decimal import {self.value.rounding}, {decimal_imports}"
|
||||
}
|
||||
|
||||
|
||||
class DeconstructibleSerializer(BaseSerializer):
|
||||
@staticmethod
|
||||
def serialize_deconstructed(path, args, kwargs):
|
||||
|
|
@ -372,6 +385,7 @@ class Serializer:
|
|||
SettingsReference: SettingsReferenceSerializer,
|
||||
float: FloatSerializer,
|
||||
(bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
|
||||
decimal.Context: DecimalContextSerializer,
|
||||
decimal.Decimal: DecimalSerializer,
|
||||
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
|
||||
FUNCTION_TYPES: FunctionTypeSerializer,
|
||||
|
|
|
|||
|
|
@ -1706,11 +1706,19 @@ class DecimalField(Field):
|
|||
name=None,
|
||||
max_digits=None,
|
||||
decimal_places=None,
|
||||
context=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.max_digits, self.decimal_places = max_digits, decimal_places
|
||||
if context is not None and not isinstance(context, decimal.Context):
|
||||
raise ValueError("DecimalField.context must be a decimal.Context instance.")
|
||||
self._context = context
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
@property
|
||||
def non_db_attrs(self):
|
||||
return (*super().non_db_attrs, "context")
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
|
||||
|
|
@ -1822,6 +1830,8 @@ class DecimalField(Field):
|
|||
|
||||
@cached_property
|
||||
def context(self):
|
||||
if self._context is not None:
|
||||
return self._context
|
||||
return decimal.Context(prec=self.max_digits)
|
||||
|
||||
def deconstruct(self):
|
||||
|
|
@ -1830,6 +1840,8 @@ class DecimalField(Field):
|
|||
kwargs["max_digits"] = self.max_digits
|
||||
if self.decimal_places is not None:
|
||||
kwargs["decimal_places"] = self.decimal_places
|
||||
if self._context is not None:
|
||||
kwargs["context"] = self.context
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
|
|
|
|||
|
|
@ -856,7 +856,7 @@ The default form widget for this field is a single
|
|||
``DecimalField``
|
||||
----------------
|
||||
|
||||
.. class:: DecimalField(max_digits=None, decimal_places=None, **options)
|
||||
.. class:: DecimalField(max_digits=None, decimal_places=None, context=None, **options)
|
||||
|
||||
A fixed-precision decimal number, represented in Python by a
|
||||
:class:`~decimal.Decimal` instance. It validates the input using
|
||||
|
|
@ -879,6 +879,15 @@ Has the following arguments:
|
|||
precision. It's also required for all database backends when
|
||||
:attr:`~DecimalField.max_digits` is provided.
|
||||
|
||||
.. attribute:: DecimalField.context
|
||||
|
||||
.. versionadded:: 6.1
|
||||
|
||||
Optional. A :class:`decimal.Context` used to create ``Decimal`` instances
|
||||
from floats. If :attr:`~DecimalField.context` is not specified, Django will
|
||||
use a default context based on :attr:`~DecimalField.max_digits`, i.e.
|
||||
``decimal.Context(prec=max_digits)``.
|
||||
|
||||
For example, to store numbers up to ``999.99`` with a resolution of 2 decimal
|
||||
places, you'd use::
|
||||
|
||||
|
|
@ -906,6 +915,10 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
|
|||
Support for ``DecimalField`` with no precision was added on Oracle,
|
||||
PostgreSQL, and SQLite.
|
||||
|
||||
.. versionchanged:: 6.1
|
||||
|
||||
The ``context`` argument was added.
|
||||
|
||||
``DurationField``
|
||||
-----------------
|
||||
|
||||
|
|
|
|||
|
|
@ -252,6 +252,10 @@ Models
|
|||
<django.db.models.DecimalField.decimal_places>` are no longer required to be
|
||||
set on Oracle, PostgreSQL, and SQLite.
|
||||
|
||||
* The new :attr:`DecimalField.context <django.db.models.DecimalField.context>`
|
||||
attribute allows customizing the creation of ``Decimal`` instances from
|
||||
floats.
|
||||
|
||||
Pagination
|
||||
~~~~~~~~~~
|
||||
|
||||
|
|
|
|||
|
|
@ -631,6 +631,23 @@ class WriterTests(SimpleTestCase):
|
|||
string = MigrationWriter.serialize(field)[0]
|
||||
self.assertEqual(string, "models.FilePathField(path=%r)" % path_like.path)
|
||||
|
||||
def test_serialize_decimal_context(self):
|
||||
decimal_context = decimal.Context(
|
||||
prec=7,
|
||||
rounding=decimal.ROUND_FLOOR,
|
||||
traps=[decimal.InvalidOperation, decimal.Overflow],
|
||||
)
|
||||
self.assertSerializedResultEqual(
|
||||
decimal_context,
|
||||
(
|
||||
repr(decimal_context),
|
||||
{
|
||||
"from decimal import ROUND_FLOOR, Context, InvalidOperation, "
|
||||
"Overflow"
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def test_serialize_zoneinfo(self):
|
||||
self.assertSerializedEqual(zoneinfo.ZoneInfo("Asia/Kolkata"))
|
||||
self.assertSerializedResultEqual(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import math
|
||||
from decimal import Decimal
|
||||
from decimal import ROUND_DOWN, Context, Decimal
|
||||
from unittest import mock
|
||||
|
||||
from django.core import validators
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import connection, models
|
||||
from django.test import TestCase
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
|
||||
from .models import BigD, Foo
|
||||
|
||||
|
|
@ -22,6 +22,26 @@ class DecimalFieldTests(TestCase):
|
|||
self.assertEqual(f.to_python(2.0625), Decimal("2.062"))
|
||||
self.assertEqual(f.to_python(2.1875), Decimal("2.188"))
|
||||
|
||||
def test_to_python_custom_context(self):
|
||||
f = models.DecimalField(
|
||||
max_digits=4,
|
||||
decimal_places=2,
|
||||
context=Context(prec=5, rounding=ROUND_DOWN),
|
||||
)
|
||||
self.assertEqual(f.to_python(3), Decimal("3"))
|
||||
self.assertEqual(f.to_python("3.14"), Decimal("3.14"))
|
||||
# to_python() converts floats and honors the custom context.
|
||||
self.assertEqual(f.to_python(3.1415926535897), Decimal("3.1415"))
|
||||
self.assertEqual(f.to_python(2.41), Decimal("2.4100"))
|
||||
# Uses custom rounding of ROUND_DOWN.
|
||||
self.assertEqual(f.to_python(2.06245), Decimal("2.0624"))
|
||||
self.assertEqual(f.to_python(2.18775), Decimal("2.1877"))
|
||||
|
||||
def test_invalid_context(self):
|
||||
msg = "DecimalField.context must be a decimal.Context instance."
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
models.DecimalField(context=ROUND_DOWN)
|
||||
|
||||
def test_invalid_value(self):
|
||||
field = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
msg = "“%s” value must be a decimal number."
|
||||
|
|
@ -140,3 +160,20 @@ class DecimalFieldTests(TestCase):
|
|||
obj = Foo.objects.create(a="bar", d=Decimal("8.320"))
|
||||
obj.refresh_from_db()
|
||||
self.assertEqual(obj.d.compare_total(Decimal("8.320")), Decimal("0"))
|
||||
|
||||
|
||||
class TestMethods(SimpleTestCase):
|
||||
def test_deconstruct(self):
|
||||
field = models.DecimalField()
|
||||
*_, kwargs = field.deconstruct()
|
||||
self.assertEqual(kwargs, {})
|
||||
custom_context = Context(prec=8, rounding=ROUND_DOWN)
|
||||
field = models.DecimalField(
|
||||
decimal_places=4,
|
||||
max_digits=8,
|
||||
context=custom_context,
|
||||
)
|
||||
*_, kwargs = field.deconstruct()
|
||||
self.assertEqual(
|
||||
kwargs, {"decimal_places": 4, "max_digits": 8, "context": custom_context}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import datetime
|
|||
import itertools
|
||||
import unittest
|
||||
from copy import copy
|
||||
from decimal import Decimal
|
||||
from decimal import ROUND_HALF_UP, Context, Decimal
|
||||
from unittest import mock
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
|
|
@ -4904,6 +4904,30 @@ class SchemaTests(TransactionTestCase):
|
|||
with connection.schema_editor() as editor, self.assertNumQueries(0):
|
||||
editor.alter_field(Author, new_field, old_field, strict=True)
|
||||
|
||||
@isolate_apps("schema")
|
||||
def test_alter_decimalfield_context_noop(self):
|
||||
class ModelWithDecimalField(Model):
|
||||
field = DecimalField(max_digits=5, decimal_places=2)
|
||||
|
||||
class Meta:
|
||||
app_label = "schema"
|
||||
|
||||
with connection.schema_editor() as editor:
|
||||
editor.create_model(ModelWithDecimalField)
|
||||
self.isolated_local_models = [ModelWithDecimalField]
|
||||
|
||||
old_field = ModelWithDecimalField._meta.get_field("field")
|
||||
new_field = DecimalField(
|
||||
max_digits=5,
|
||||
decimal_places=2,
|
||||
context=Context(prec=13, rounding=ROUND_HALF_UP),
|
||||
)
|
||||
new_field.set_attributes_from_name("field")
|
||||
with connection.schema_editor() as editor, self.assertNumQueries(0):
|
||||
editor.alter_field(Author, old_field, new_field, strict=True)
|
||||
with connection.schema_editor() as editor, self.assertNumQueries(0):
|
||||
editor.alter_field(Author, new_field, old_field, strict=True)
|
||||
|
||||
def test_add_textfield_unhashable_default(self):
|
||||
# Create the table
|
||||
with connection.schema_editor() as editor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue