django/tests/constraints/tests.py
utkarsh.arya@zomato.com 1efb80ab0f Allow in_bulk() for fields with UniqueConstraint
Previously, in_bulk() required fields to have unique=True. Now, fields with an unconditional UniqueConstraint are also supported, enabling lookups like in_bulk(field_name="slug") for such fields.

Fixes an issue where a field made unique only through a UniqueConstraint could not be used with in_bulk().
2025-11-15 22:58:45 +00:00

278 lines
12 KiB
Python

from unittest import mock
from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import (
ChildModel, Product, UniqueConstraintConditionProduct,
UniqueConstraintInBulkProduct, UniqueConstraintProduct,
)
def get_constraints(table):
    """Return what the backend's introspection reports as constraints of ``table``.

    A cursor is opened on the imported ``connection`` and handed, together
    with the table name, to ``connection.introspection.get_constraints()``.
    """
    with connection.cursor() as db_cursor:
        found = connection.introspection.get_constraints(db_cursor, table)
    return found
class BaseConstraintTests(SimpleTestCase):
    """BaseConstraint's SQL hooks raise NotImplementedError when not overridden."""

    def test_constraint_sql(self):
        """constraint_sql() raises NotImplementedError on the base class."""
        expected = 'This method must be implemented by a subclass.'
        base = BaseConstraint('name')
        with self.assertRaisesMessage(NotImplementedError, expected):
            base.constraint_sql(None, None)

    def test_create_sql(self):
        """create_sql() raises NotImplementedError on the base class."""
        expected = 'This method must be implemented by a subclass.'
        base = BaseConstraint('name')
        with self.assertRaisesMessage(NotImplementedError, expected):
            base.create_sql(None, None)

    def test_remove_sql(self):
        """remove_sql() raises NotImplementedError on the base class."""
        expected = 'This method must be implemented by a subclass.'
        base = BaseConstraint('name')
        with self.assertRaisesMessage(NotImplementedError, expected):
            base.remove_sql(None, None)
class CheckConstraintTests(TestCase):
    """Exercise CheckConstraint: equality, repr, argument validation,
    deconstruction, database enforcement, and generated constraint names."""

    def test_eq(self):
        """Constraints compare equal only when both check and name match."""
        gt_check = models.Q(price__gt=models.F('discounted_price'))
        lt_check = models.Q(price__lt=models.F('discounted_price'))
        baseline = models.CheckConstraint(check=gt_check, name='price')
        self.assertEqual(
            baseline,
            models.CheckConstraint(check=gt_check, name='price'),
        )
        self.assertEqual(baseline, mock.ANY)
        # A different name or a different check breaks equality.
        self.assertNotEqual(
            baseline,
            models.CheckConstraint(check=gt_check, name='price2'),
        )
        self.assertNotEqual(
            baseline,
            models.CheckConstraint(check=lt_check, name='price'),
        )
        # Comparison against an unrelated type is simply unequal.
        self.assertNotEqual(baseline, 1)

    def test_repr(self):
        """repr() shows the check expression and the constraint name."""
        check = models.Q(price__gt=models.F('discounted_price'))
        name = 'price_gt_discounted_price'
        expected = "<CheckConstraint: check='{}' name='{}'>".format(check, name)
        self.assertEqual(
            repr(models.CheckConstraint(check=check, name=name)),
            expected,
        )

    def test_invalid_check_types(self):
        """Passing a bare F() expression as ``check`` raises TypeError."""
        expected = 'CheckConstraint.check must be a Q instance or boolean expression.'
        with self.assertRaisesMessage(TypeError, expected):
            models.CheckConstraint(check=models.F('discounted_price'), name='check')

    def test_deconstruction(self):
        """deconstruct() returns the import path, no args, and check/name kwargs."""
        check = models.Q(price__gt=models.F('discounted_price'))
        name = 'price_gt_discounted_price'
        deconstructed = models.CheckConstraint(check=check, name=name).deconstruct()
        path, args, kwargs = deconstructed
        self.assertEqual(path, 'django.db.models.CheckConstraint')
        self.assertEqual(args, ())
        self.assertEqual(kwargs, {'check': check, 'name': name})

    @skipUnlessDBFeature('supports_table_check_constraints')
    def test_database_constraint(self):
        """price=10/discounted_price=5 is stored; discounted_price=20 raises IntegrityError."""
        Product.objects.create(price=10, discounted_price=5)
        with self.assertRaises(IntegrityError):
            Product.objects.create(price=10, discounted_price=20)

    @skipUnlessDBFeature('supports_table_check_constraints')
    def test_database_constraint_expression(self):
        """price=999 is stored; price=1000 raises IntegrityError."""
        Product.objects.create(price=999, discounted_price=5)
        with self.assertRaises(IntegrityError):
            Product.objects.create(price=1000, discounted_price=5)

    @skipUnlessDBFeature('supports_table_check_constraints')
    def test_database_constraint_expressionwrapper(self):
        """price=499 is stored; price=500 raises IntegrityError."""
        Product.objects.create(price=499, discounted_price=5)
        with self.assertRaises(IntegrityError):
            Product.objects.create(price=500, discounted_price=5)

    @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
    def test_name(self):
        """Each expected constraint name is reported for Product's table."""
        introspected = get_constraints(Product._meta.db_table)
        expected_names = (
            'price_gt_discounted_price',
            'constraints_price_lt_1000_raw',
            'constraints_price_neq_500_wrap',
            'constraints_product_price_gt_0',
        )
        for constraint_name in expected_names:
            with self.subTest(constraint_name):
                self.assertIn(constraint_name, introspected)

    @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
    def test_abstract_name(self):
        """'constraints_childmodel_adult' is reported for ChildModel's table."""
        introspected = get_constraints(ChildModel._meta.db_table)
        self.assertIn('constraints_childmodel_adult', introspected)
class UniqueConstraintTests(TestCase):
    """Exercise UniqueConstraint: equality, repr, deconstruction, database
    enforcement, model validation, and in_bulk() lookups."""

    @classmethod
    def setUpTestData(cls):
        """Bulk-create the two products (``p1``, ``p2``) shared by the tests."""
        products = [
            UniqueConstraintProduct(name='p1', color='red'),
            UniqueConstraintProduct(name='p2'),
        ]
        cls.p1, cls.p2 = UniqueConstraintProduct.objects.bulk_create(products)

    def test_eq(self):
        """Constraints compare equal only when both fields and name match."""
        baseline = models.UniqueConstraint(fields=['foo', 'bar'], name='unique')
        self.assertEqual(
            baseline,
            models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
        )
        self.assertEqual(baseline, mock.ANY)
        # A different name or a different field list breaks equality.
        self.assertNotEqual(
            baseline,
            models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),
        )
        self.assertNotEqual(
            baseline,
            models.UniqueConstraint(fields=['foo', 'baz'], name='unique'),
        )
        # Comparison against an unrelated type is simply unequal.
        self.assertNotEqual(baseline, 1)

    def test_eq_with_condition(self):
        """The condition takes part in equality comparisons."""
        def constrained(on_field):
            # Fresh constraint (and fresh Q/F objects) on every call.
            return models.UniqueConstraint(
                fields=['foo', 'bar'],
                name='unique',
                condition=models.Q(foo=models.F(on_field)),
            )

        self.assertEqual(constrained('bar'), constrained('bar'))
        self.assertNotEqual(constrained('bar'), constrained('baz'))

    def test_repr(self):
        """repr() shows the fields tuple and the constraint name."""
        fields = ['foo', 'bar']
        name = 'unique_fields'
        expected = "<UniqueConstraint: fields=('foo', 'bar') name='unique_fields'>"
        self.assertEqual(
            repr(models.UniqueConstraint(fields=fields, name=name)),
            expected,
        )

    def test_repr_with_condition(self):
        """repr() also shows the condition when one is set."""
        expected = (
            "<UniqueConstraint: fields=('foo', 'bar') name='unique_fields' "
            "condition=(AND: ('foo', F(bar)))>"
        )
        conditional = models.UniqueConstraint(
            fields=['foo', 'bar'],
            name='unique_fields',
            condition=models.Q(foo=models.F('bar')),
        )
        self.assertEqual(repr(conditional), expected)

    def test_deconstruction(self):
        """deconstruct() returns the import path, no args, and fields/name kwargs."""
        fields = ['foo', 'bar']
        name = 'unique_fields'
        deconstructed = models.UniqueConstraint(fields=fields, name=name).deconstruct()
        path, args, kwargs = deconstructed
        self.assertEqual(path, 'django.db.models.UniqueConstraint')
        self.assertEqual(args, ())
        self.assertEqual(kwargs, {'fields': tuple(fields), 'name': name})

    def test_deconstruction_with_condition(self):
        """deconstruct() includes the condition in kwargs when one is set."""
        fields = ['foo', 'bar']
        name = 'unique_fields'
        condition = models.Q(foo=models.F('bar'))
        deconstructed = models.UniqueConstraint(
            fields=fields, name=name, condition=condition,
        ).deconstruct()
        path, args, kwargs = deconstructed
        self.assertEqual(path, 'django.db.models.UniqueConstraint')
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {'fields': tuple(fields), 'name': name, 'condition': condition},
        )

    def test_database_constraint(self):
        """Creating a product with p1's name and color raises IntegrityError."""
        duplicate = {'name': self.p1.name, 'color': self.p1.color}
        with self.assertRaises(IntegrityError):
            UniqueConstraintProduct.objects.create(**duplicate)

    @skipUnlessDBFeature('supports_partial_indexes')
    def test_database_constraint_with_condition(self):
        """A second color-less product named 'p1' raises IntegrityError."""
        for product_name in ('p1', 'p2'):
            UniqueConstraintConditionProduct.objects.create(name=product_name)
        with self.assertRaises(IntegrityError):
            UniqueConstraintConditionProduct.objects.create(name='p1')

    def test_model_validation(self):
        """validate_unique() reports a clash on name and color."""
        expected = 'Unique constraint product with this Name and Color already exists.'
        clash = UniqueConstraintProduct(name=self.p1.name, color=self.p1.color)
        with self.assertRaisesMessage(ValidationError, expected):
            clash.validate_unique()

    @skipUnlessDBFeature('supports_partial_indexes')
    def test_model_validation_with_condition(self):
        """Model.validate_unique() ignores partial unique constraints."""
        colored = UniqueConstraintConditionProduct.objects.create(name='p1', color='red')
        colorless = UniqueConstraintConditionProduct.objects.create(name='p2')
        # Neither call is expected to raise, even though the names repeat.
        UniqueConstraintConditionProduct(name=colored.name, color='blue').validate_unique()
        UniqueConstraintConditionProduct(name=colorless.name).validate_unique()

    def test_name(self):
        """'name_color_uniq' is reported for UniqueConstraintProduct's table."""
        introspected = get_constraints(UniqueConstraintProduct._meta.db_table)
        self.assertIn('name_color_uniq', introspected)

    def test_condition_must_be_q(self):
        """A non-Q ``condition`` raises ValueError."""
        expected = 'UniqueConstraint.condition must be a Q instance.'
        with self.assertRaisesMessage(ValueError, expected):
            models.UniqueConstraint(name='uniq', fields=['name'], condition='invalid')

    def test_in_bulk_with_unique_constraint(self):
        """in_bulk() accepts a field_name made unique by a single-field UniqueConstraint."""
        names = ['product1', 'product2', 'product3']
        # Dict order follows ``names``, so rows are created in the listed order.
        created = {
            product_name: UniqueConstraintInBulkProduct.objects.create(name=product_name)
            for product_name in names
        }
        fetched = UniqueConstraintInBulkProduct.objects.in_bulk(
            names,
            field_name='name',
        )
        self.assertEqual(len(fetched), 3)
        for product_name, product in created.items():
            self.assertEqual(fetched[product_name], product)

    def test_in_bulk_without_id_list_with_unique_constraint(self):
        """in_bulk() without an id_list returns every row keyed by the constrained field."""
        product_a = UniqueConstraintInBulkProduct.objects.create(name='product_a')
        product_b = UniqueConstraintInBulkProduct.objects.create(name='product_b')
        fetched = UniqueConstraintInBulkProduct.objects.in_bulk(field_name='name')
        self.assertEqual(len(fetched), 2)
        self.assertEqual(fetched['product_a'], product_a)
        self.assertEqual(fetched['product_b'], product_b)

    @skipUnlessDBFeature('supports_partial_indexes')
    def test_in_bulk_with_partial_unique_constraint_fails(self):
        """in_bulk() raises ValueError for a field covered only by a conditional UniqueConstraint."""
        UniqueConstraintConditionProduct.objects.create(name='p1')
        expected = "in_bulk()'s field_name must be a unique field but 'name' isn't."
        with self.assertRaisesMessage(ValueError, expected):
            UniqueConstraintConditionProduct.objects.in_bulk(['p1'], field_name='name')