Another approach

This commit is contained in:
Baptiste Mispelon 2025-04-25 13:45:46 +02:00 committed by Baptiste Mispelon
parent d5c2e925b4
commit a74954a5a7
4 changed files with 17 additions and 21 deletions

View file

@ -923,11 +923,8 @@ class BaseDatabaseSchemaEditor:
def _field_db_check(self, field, field_db_params):
# Always check constraints with the same mocked column name to avoid
# recreating constrains when the column is renamed.
if (constraint := field.db_check(self.connection)) is None:
return None
data = field.db_type_parameters(self.connection)
data["column"] = "__column_name__"
return constraint % data
overrides = {"column": "__column_name__"}
return field.db_check(self.connection, **overrides)
def _alter_field(
self,

View file

@ -840,13 +840,15 @@ class Field(RegisterLookupMixin):
def db_type_parameters(self, connection):
return DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
def db_check(self, connection):
def db_check(self, connection, **overrides):
"""
Return the database column check constraint for this field, for the
provided connection. Works the same way as db_type() for the case that
get_internal_type() does not map to a preexisting model field.
Any keyword arguments provided will override the ones received from
db_type_parameters() and used for formatting the constraint's SQL string.
"""
data = self.db_type_parameters(connection)
data = self.db_type_parameters(connection) | overrides
try:
return (
connection.data_type_check_constraints[self.get_internal_type()] % data

View file

@ -1215,7 +1215,7 @@ class ForeignKey(ForeignObject):
}
)
def db_check(self, connection):
def db_check(self, connection, **overrides):
return None
def db_type(self, connection):
@ -2052,7 +2052,7 @@ class ManyToManyField(RelatedField):
defaults["initial"] = [i.pk for i in initial]
return super().formfield(**defaults)
def db_check(self, connection):
def db_check(self, connection, **overrides):
return None
def db_type(self, connection):

View file

@ -65,6 +65,7 @@ from django.db.models.functions import (
Upper,
)
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import In as InLookup
from django.db.transaction import TransactionManagementError, atomic
from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext, isolate_apps, register_lookup
@ -2486,6 +2487,7 @@ class SchemaTests(TransactionTestCase):
A custom CharField that automatically creates a db constraint to guarante
that the stored value respects the field's `choices`.
"""
@property
def non_db_attrs(self):
# Remove `choices` from non_db_attrs so that migrations that only change
@ -2493,11 +2495,12 @@ class SchemaTests(TransactionTestCase):
attrs = super().non_db_attrs
return tuple({*attrs} - {"choices"})
def db_check(self, connection):
def db_check(self, connection, **overrides):
if not self.choices:
return None
data = self.db_type_parameters(connection) | overrides
constraint = CheckConstraint(
condition=Q(**{f"{self.name}__in": dict(self.choices)}),
condition=InLookup(F(data["column"]), dict(self.choices)),
name="", # doesn't matter, Django will reassign one anyway
)
with connection.schema_editor() as schema_editor:
@ -2513,11 +2516,8 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.create_model(ModelWithCustomField)
constraints = self.get_constraints(ModelWithCustomField._meta.db_table)
self.assertEqual(
len(constraints),
1, # just the pk constraint
)
constraints = self.get_constraints_for_column(ModelWithCustomField, "f")
self.assertEqual(len(constraints), 0)
old_field = ModelWithCustomField._meta.get_field("f")
new_field = CharChoiceField(choices=[("a", "a")])
@ -2525,11 +2525,8 @@ class SchemaTests(TransactionTestCase):
with connection.schema_editor() as editor:
editor.alter_field(ModelWithCustomField, old_field, new_field, strict=True)
constraints = self.get_constraints(ModelWithCustomField._meta.db_table)
self.assertEqual(
len(constraints),
2, # pk + custom constraint
)
constraints = self.get_constraints_for_column(ModelWithCustomField, "f")
self.assertEqual(len(constraints), 1)
def _test_m2m_create(self, M2MFieldClass):
"""