Refs #33476 -- Reformatted code with Black.

This commit is contained in:
django-bot 2022-02-03 20:24:19 +01:00 committed by Mariusz Felisiak
parent f68fa8b45d
commit 9c19aff7c7
1992 changed files with 139577 additions and 96284 deletions

View file

@ -6,18 +6,18 @@ from django.db import connection
from django.test import SimpleTestCase, TestCase, modify_settings
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
pass
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
pass
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
# To locate the widget's template.
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
pass

View file

@ -4,18 +4,29 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
]
dependencies = []
operations = [
migrations.CreateModel(
name='IntegerArrayDefaultModel',
name="IntegerArrayDefaultModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"field",
django.contrib.postgres.fields.ArrayField(
models.IntegerField(), size=None
),
),
],
options={
},
options={},
bases=(models.Model,),
),
]

View file

@ -5,14 +5,16 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('postgres_tests', '0001_initial'),
("postgres_tests", "0001_initial"),
]
operations = [
migrations.AddField(
model_name='integerarraydefaultmodel',
name='field_2',
field=django.contrib.postgres.fields.ArrayField(models.IntegerField(), default=[], size=None),
model_name="integerarraydefaultmodel",
name="field_2",
field=django.contrib.postgres.fields.ArrayField(
models.IntegerField(), default=[], size=None
),
preserve_default=False,
),
]

View file

@ -4,22 +4,36 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
]
dependencies = []
operations = [
migrations.CreateModel(
name='CharTextArrayIndexModel',
name="CharTextArrayIndexModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('char', django.contrib.postgres.fields.ArrayField(
models.CharField(max_length=10), db_index=True, size=100)
),
('char2', models.CharField(max_length=11, db_index=True)),
('text', django.contrib.postgres.fields.ArrayField(models.TextField(), db_index=True)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"char",
django.contrib.postgres.fields.ArrayField(
models.CharField(max_length=10), db_index=True, size=100
),
),
("char2", models.CharField(max_length=11, db_index=True)),
(
"text",
django.contrib.postgres.fields.ArrayField(
models.TextField(), db_index=True
),
),
],
options={
},
options={},
bases=(models.Model,),
),
]

View file

@ -8,31 +8,41 @@ from django.db import models
try:
from django.contrib.postgres.fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField,
CITextField, DateRangeField, DateTimeRangeField, DecimalRangeField,
HStoreField, IntegerRangeField,
ArrayField,
BigIntegerRangeField,
CICharField,
CIEmailField,
CITextField,
DateRangeField,
DateTimeRangeField,
DecimalRangeField,
HStoreField,
IntegerRangeField,
)
from django.contrib.postgres.search import SearchVector, SearchVectorField
except ImportError:
class DummyArrayField(models.Field):
def __init__(self, base_field, size=None, **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs.update({
'base_field': '',
'size': 1,
})
kwargs.update(
{
"base_field": "",
"size": 1,
}
)
return name, path, args, kwargs
class DummyContinuousRangeField(models.Field):
def __init__(self, *args, default_bounds='[)', **kwargs):
def __init__(self, *args, default_bounds="[)", **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs['default_bounds'] = '[)'
kwargs["default_bounds"] = "[)"
return name, path, args, kwargs
ArrayField = DummyArrayField

View file

@ -1,5 +1,5 @@
SECRET_KEY = 'abcdefg'
SECRET_KEY = "abcdefg"
INSTALLED_APPS = [
'django.contrib.postgres',
"django.contrib.postgres",
]

View file

@ -4,8 +4,14 @@ from django.db import connection, migrations
try:
from django.contrib.postgres.operations import (
BloomExtension, BtreeGinExtension, BtreeGistExtension, CITextExtension,
CreateExtension, CryptoExtension, HStoreExtension, TrigramExtension,
BloomExtension,
BtreeGinExtension,
BtreeGistExtension,
CITextExtension,
CreateExtension,
CryptoExtension,
HStoreExtension,
TrigramExtension,
UnaccentExtension,
)
except ImportError:
@ -20,8 +26,7 @@ except ImportError:
needs_crypto_extension = False
else:
needs_crypto_extension = (
connection.vendor == 'postgresql' and
not connection.features.is_postgresql_13
connection.vendor == "postgresql" and not connection.features.is_postgresql_13
)
@ -34,7 +39,7 @@ class Migration(migrations.Migration):
CITextExtension(),
# Ensure CreateExtension quotes extension names by creating one with a
# dash in its name.
CreateExtension('uuid-ossp'),
CreateExtension("uuid-ossp"),
# CryptoExtension is required for RandomUUID() on PostgreSQL < 13.
CryptoExtension() if needs_crypto_extension else mock.Mock(),
HStoreExtension(),

View file

@ -1,9 +1,18 @@
from django.db import migrations, models
from ..fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
HStoreField, IntegerRangeField, SearchVectorField,
ArrayField,
BigIntegerRangeField,
CICharField,
CIEmailField,
CITextField,
DateRangeField,
DateTimeRangeField,
DecimalRangeField,
EnumField,
HStoreField,
IntegerRangeField,
SearchVectorField,
)
from ..models import TagField
@ -11,305 +20,538 @@ from ..models import TagField
class Migration(migrations.Migration):
dependencies = [
('postgres_tests', '0001_setup_extensions'),
("postgres_tests", "0001_setup_extensions"),
]
operations = [
migrations.CreateModel(
name='CharArrayModel',
name="CharArrayModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', ArrayField(models.CharField(max_length=10), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='DateTimeArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('datetimes', ArrayField(models.DateTimeField(), size=None)),
('dates', ArrayField(models.DateField(), size=None)),
('times', ArrayField(models.TimeField(), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='HStoreModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', HStoreField(blank=True, null=True)),
('array_field', ArrayField(HStoreField(), null=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='OtherTypesArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)),
('uuids', ArrayField(models.UUIDField(), size=None, default=list)),
('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)),
('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
('json', ArrayField(models.JSONField(default={}), default=[])),
('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)),
('bigint_ranges', ArrayField(BigIntegerRangeField(), null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='IntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', ArrayField(models.IntegerField(), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='NestedIntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', ArrayField(ArrayField(models.IntegerField(), size=None), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
migrations.CreateModel(
name='NullableIntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', ArrayField(models.IntegerField(), size=None, null=True, blank=True)),
(
'field_nested',
ArrayField(ArrayField(models.IntegerField(), size=None, null=True), size=None, null=True),
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
('order', models.IntegerField(null=True)),
("field", ArrayField(models.CharField(max_length=10), size=None)),
],
options={
'required_db_vendor': 'postgresql',
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name='CharFieldModel',
name="DateTimeArrayModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', models.CharField(max_length=64)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name='TextFieldModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', models.TextField()),
],
options=None,
bases=None,
),
migrations.CreateModel(
name='SmallAutoFieldModel',
fields=[
('id', models.SmallAutoField(verbose_name='ID', serialize=False, primary_key=True)),
],
options=None,
),
migrations.CreateModel(
name='BigAutoFieldModel',
fields=[
('id', models.BigAutoField(verbose_name='ID', serialize=False, primary_key=True)),
],
options=None,
),
migrations.CreateModel(
name='Scene',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('scene', models.TextField()),
('setting', models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name='Character',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('name', models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name='CITestModel',
fields=[
('name', CICharField(primary_key=True, max_length=255)),
('email', CIEmailField()),
('description', CITextField()),
('array_field', ArrayField(CITextField(), null=True)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("datetimes", ArrayField(models.DateTimeField(), size=None)),
("dates", ArrayField(models.DateField(), size=None)),
("times", ArrayField(models.TimeField(), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=None,
),
migrations.CreateModel(
name='Line',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('scene', models.ForeignKey('postgres_tests.Scene', on_delete=models.SET_NULL)),
('character', models.ForeignKey('postgres_tests.Character', on_delete=models.SET_NULL)),
('dialogue', models.TextField(blank=True, null=True)),
('dialogue_search_vector', SearchVectorField(blank=True, null=True)),
('dialogue_config', models.CharField(max_length=100, blank=True, null=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=None,
),
migrations.CreateModel(
name='LineSavedSearch',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('line', models.ForeignKey('postgres_tests.Line', on_delete=models.CASCADE)),
('query', models.CharField(max_length=100)),
],
options={
'required_db_vendor': 'postgresql',
},
),
migrations.CreateModel(
name='AggregateTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('boolean_field', models.BooleanField(null=True)),
('char_field', models.CharField(max_length=30, blank=True)),
('text_field', models.TextField(blank=True)),
('integer_field', models.IntegerField(null=True)),
('json_field', models.JSONField(null=True)),
],
options={
'required_db_vendor': 'postgresql',
},
),
migrations.CreateModel(
name='StatTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('int1', models.IntegerField()),
('int2', models.IntegerField()),
('related_field', models.ForeignKey(
'postgres_tests.AggregateTestModel',
models.SET_NULL,
null=True,
)),
],
options={
'required_db_vendor': 'postgresql',
},
),
migrations.CreateModel(
name='NowTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('when', models.DateTimeField(null=True, default=None)),
]
),
migrations.CreateModel(
name='UUIDTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('uuid', models.UUIDField(default=None, null=True)),
]
),
migrations.CreateModel(
name='RangesModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('ints', IntegerRangeField(null=True, blank=True)),
('bigints', BigIntegerRangeField(null=True, blank=True)),
('decimals', DecimalRangeField(null=True, blank=True)),
('timestamps', DateTimeRangeField(null=True, blank=True)),
('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
('dates', DateRangeField(null=True, blank=True)),
('dates_inner', DateRangeField(null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql'
},
bases=(models.Model,)
),
migrations.CreateModel(
name='RangeLookupsModel',
fields=[
('parent', models.ForeignKey(
'postgres_tests.RangesModel',
models.SET_NULL,
blank=True, null=True,
)),
('integer', models.IntegerField(blank=True, null=True)),
('big_integer', models.BigIntegerField(blank=True, null=True)),
('float', models.FloatField(blank=True, null=True)),
('timestamp', models.DateTimeField(blank=True, null=True)),
('date', models.DateField(blank=True, null=True)),
('small_integer', models.SmallIntegerField(blank=True, null=True)),
('decimal_field', models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True)),
],
options={
'required_db_vendor': 'postgresql',
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name='ArrayEnumModel',
name="HStoreModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('array_of_enums', ArrayField(EnumField(max_length=20), null=True, blank=True)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("field", HStoreField(blank=True, null=True)),
("array_field", ArrayField(HStoreField(), null=True)),
],
options={
'required_db_vendor': 'postgresql',
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name='Room',
name="OtherTypesArrayModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('number', models.IntegerField(unique=True)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"ips",
ArrayField(models.GenericIPAddressField(), size=None, default=list),
),
("uuids", ArrayField(models.UUIDField(), size=None, default=list)),
(
"decimals",
ArrayField(
models.DecimalField(max_digits=5, decimal_places=2),
size=None,
default=list,
),
),
("tags", ArrayField(TagField(), blank=True, null=True, size=None)),
("json", ArrayField(models.JSONField(default={}), default=[])),
("int_ranges", ArrayField(IntegerRangeField(), null=True, blank=True)),
(
"bigint_ranges",
ArrayField(BigIntegerRangeField(), null=True, blank=True),
),
],
options={
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="IntegerArrayModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("field", ArrayField(models.IntegerField(), size=None)),
],
options={
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="NestedIntegerArrayModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"field",
ArrayField(ArrayField(models.IntegerField(), size=None), size=None),
),
],
options={
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="NullableIntegerArrayModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"field",
ArrayField(models.IntegerField(), size=None, null=True, blank=True),
),
(
"field_nested",
ArrayField(
ArrayField(models.IntegerField(), size=None, null=True),
size=None,
null=True,
),
),
("order", models.IntegerField(null=True)),
],
options={
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="CharFieldModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("field", models.CharField(max_length=64)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name="TextFieldModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("field", models.TextField()),
],
options=None,
bases=None,
),
migrations.CreateModel(
name="SmallAutoFieldModel",
fields=[
(
"id",
models.SmallAutoField(
verbose_name="ID", serialize=False, primary_key=True
),
),
],
options=None,
),
migrations.CreateModel(
name="BigAutoFieldModel",
fields=[
(
"id",
models.BigAutoField(
verbose_name="ID", serialize=False, primary_key=True
),
),
],
options=None,
),
migrations.CreateModel(
name="Scene",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("scene", models.TextField()),
("setting", models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name="Character",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("name", models.CharField(max_length=255)),
],
options=None,
bases=None,
),
migrations.CreateModel(
name="CITestModel",
fields=[
("name", CICharField(primary_key=True, max_length=255)),
("email", CIEmailField()),
("description", CITextField()),
("array_field", ArrayField(CITextField(), null=True)),
],
options={
"required_db_vendor": "postgresql",
},
bases=None,
),
migrations.CreateModel(
name="Line",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"scene",
models.ForeignKey(
"postgres_tests.Scene", on_delete=models.SET_NULL
),
),
(
"character",
models.ForeignKey(
"postgres_tests.Character", on_delete=models.SET_NULL
),
),
("dialogue", models.TextField(blank=True, null=True)),
("dialogue_search_vector", SearchVectorField(blank=True, null=True)),
(
"dialogue_config",
models.CharField(max_length=100, blank=True, null=True),
),
],
options={
"required_db_vendor": "postgresql",
},
bases=None,
),
migrations.CreateModel(
name="LineSavedSearch",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"line",
models.ForeignKey("postgres_tests.Line", on_delete=models.CASCADE),
),
("query", models.CharField(max_length=100)),
],
options={
"required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
name="AggregateTestModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("boolean_field", models.BooleanField(null=True)),
("char_field", models.CharField(max_length=30, blank=True)),
("text_field", models.TextField(blank=True)),
("integer_field", models.IntegerField(null=True)),
("json_field", models.JSONField(null=True)),
],
options={
"required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
name="StatTestModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("int1", models.IntegerField()),
("int2", models.IntegerField()),
(
"related_field",
models.ForeignKey(
"postgres_tests.AggregateTestModel",
models.SET_NULL,
null=True,
),
),
],
options={
"required_db_vendor": "postgresql",
},
),
migrations.CreateModel(
name="NowTestModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("when", models.DateTimeField(null=True, default=None)),
],
),
migrations.CreateModel(
name='HotelReservation',
name="UUIDTestModel",
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)),
('datespan', DateRangeField()),
('start', models.DateTimeField()),
('end', models.DateTimeField()),
('cancelled', models.BooleanField(default=False)),
('requirements', models.JSONField(blank=True, null=True)),
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("uuid", models.UUIDField(default=None, null=True)),
],
),
migrations.CreateModel(
name="RangesModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("ints", IntegerRangeField(null=True, blank=True)),
("bigints", BigIntegerRangeField(null=True, blank=True)),
("decimals", DecimalRangeField(null=True, blank=True)),
("timestamps", DateTimeRangeField(null=True, blank=True)),
("timestamps_inner", DateTimeRangeField(null=True, blank=True)),
(
"timestamps_closed_bounds",
DateTimeRangeField(null=True, blank=True, default_bounds="[]"),
),
("dates", DateRangeField(null=True, blank=True)),
("dates_inner", DateRangeField(null=True, blank=True)),
],
options={"required_db_vendor": "postgresql"},
bases=(models.Model,),
),
migrations.CreateModel(
name="RangeLookupsModel",
fields=[
(
"parent",
models.ForeignKey(
"postgres_tests.RangesModel",
models.SET_NULL,
blank=True,
null=True,
),
),
("integer", models.IntegerField(blank=True, null=True)),
("big_integer", models.BigIntegerField(blank=True, null=True)),
("float", models.FloatField(blank=True, null=True)),
("timestamp", models.DateTimeField(blank=True, null=True)),
("date", models.DateField(blank=True, null=True)),
("small_integer", models.SmallIntegerField(blank=True, null=True)),
(
"decimal_field",
models.DecimalField(
max_digits=5, decimal_places=2, blank=True, null=True
),
),
],
options={
'required_db_vendor': 'postgresql',
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="ArrayEnumModel",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"array_of_enums",
ArrayField(EnumField(max_length=20), null=True, blank=True),
),
],
options={
"required_db_vendor": "postgresql",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="Room",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("number", models.IntegerField(unique=True)),
],
),
migrations.CreateModel(
name="HotelReservation",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
("room", models.ForeignKey("postgres_tests.Room", models.CASCADE)),
("datespan", DateRangeField()),
("start", models.DateTimeField()),
("end", models.DateTimeField()),
("cancelled", models.BooleanField(default=False)),
("requirements", models.JSONField(blank=True, null=True)),
],
options={
"required_db_vendor": "postgresql",
},
),
]

View file

@ -1,9 +1,18 @@
from django.db import models
from .fields import (
ArrayField, BigIntegerRangeField, CICharField, CIEmailField, CITextField,
DateRangeField, DateTimeRangeField, DecimalRangeField, EnumField,
HStoreField, IntegerRangeField, SearchVectorField,
ArrayField,
BigIntegerRangeField,
CICharField,
CIEmailField,
CITextField,
DateRangeField,
DateTimeRangeField,
DecimalRangeField,
EnumField,
HStoreField,
IntegerRangeField,
SearchVectorField,
)
@ -16,7 +25,6 @@ class Tag:
class TagField(models.SmallIntegerField):
def from_db_value(self, value, expression, connection):
if value is None:
return value
@ -36,7 +44,7 @@ class TagField(models.SmallIntegerField):
class PostgreSQLModel(models.Model):
class Meta:
abstract = True
required_db_vendor = 'postgresql'
required_db_vendor = "postgresql"
class IntegerArrayModel(PostgreSQLModel):
@ -66,7 +74,9 @@ class NestedIntegerArrayModel(PostgreSQLModel):
class OtherTypesArrayModel(PostgreSQLModel):
ips = ArrayField(models.GenericIPAddressField(), default=list)
uuids = ArrayField(models.UUIDField(), default=list)
decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list)
decimals = ArrayField(
models.DecimalField(max_digits=5, decimal_places=2), default=list
)
tags = ArrayField(TagField(), blank=True, null=True)
json = ArrayField(models.JSONField(default=dict), default=list)
int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
@ -117,15 +127,15 @@ class CITestModel(PostgreSQLModel):
class Line(PostgreSQLModel):
scene = models.ForeignKey('Scene', models.CASCADE)
character = models.ForeignKey('Character', models.CASCADE)
scene = models.ForeignKey("Scene", models.CASCADE)
character = models.ForeignKey("Character", models.CASCADE)
dialogue = models.TextField(blank=True, null=True)
dialogue_search_vector = SearchVectorField(blank=True, null=True)
dialogue_config = models.CharField(max_length=100, blank=True, null=True)
class LineSavedSearch(PostgreSQLModel):
line = models.ForeignKey('Line', models.CASCADE)
line = models.ForeignKey("Line", models.CASCADE)
query = models.CharField(max_length=100)
@ -136,7 +146,9 @@ class RangesModel(PostgreSQLModel):
timestamps = DateTimeRangeField(blank=True, null=True)
timestamps_inner = DateTimeRangeField(blank=True, null=True)
timestamps_closed_bounds = DateTimeRangeField(
blank=True, null=True, default_bounds='[]',
blank=True,
null=True,
default_bounds="[]",
)
dates = DateRangeField(blank=True, null=True)
dates_inner = DateRangeField(blank=True, null=True)
@ -150,7 +162,9 @@ class RangeLookupsModel(PostgreSQLModel):
timestamp = models.DateTimeField(blank=True, null=True)
date = models.DateField(blank=True, null=True)
small_integer = models.SmallIntegerField(blank=True, null=True)
decimal_field = models.DecimalField(max_digits=5, decimal_places=2, blank=True, null=True)
decimal_field = models.DecimalField(
max_digits=5, decimal_places=2, blank=True, null=True
)
class ArrayFieldSubclass(ArrayField):
@ -162,6 +176,7 @@ class AggregateTestModel(PostgreSQLModel):
"""
To test postgres-specific general aggregation functions
"""
char_field = models.CharField(max_length=30, blank=True)
text_field = models.TextField(blank=True)
integer_field = models.IntegerField(null=True)
@ -173,6 +188,7 @@ class StatTestModel(PostgreSQLModel):
"""
To test postgres-specific aggregation functions for statistics
"""
int1 = models.IntegerField()
int2 = models.IntegerField()
related_field = models.ForeignKey(AggregateTestModel, models.SET_NULL, null=True)
@ -191,7 +207,7 @@ class Room(models.Model):
class HotelReservation(PostgreSQLModel):
room = models.ForeignKey('Room', on_delete=models.CASCADE)
room = models.ForeignKey("Room", on_delete=models.CASCADE)
datespan = DateRangeField()
start = models.DateTimeField()
end = models.DateTimeField()

File diff suppressed because it is too large Load diff

View file

@ -7,12 +7,12 @@ from django.test.utils import modify_settings
from . import PostgreSQLTestCase
try:
from psycopg2.extras import (
DateRange, DateTimeRange, DateTimeTZRange, NumericRange,
)
from psycopg2.extras import DateRange, DateTimeRange, DateTimeTZRange, NumericRange
from django.contrib.postgres.fields import (
DateRangeField, DateTimeRangeField, DecimalRangeField,
DateRangeField,
DateTimeRangeField,
DecimalRangeField,
IntegerRangeField,
)
except ImportError:
@ -22,17 +22,24 @@ except ImportError:
class PostgresConfigTests(PostgreSQLTestCase):
def test_register_type_handlers_connection(self):
from django.contrib.postgres.signals import register_type_handlers
self.assertNotIn(register_type_handlers, connection_created._live_receivers(None))
with modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}):
self.assertIn(register_type_handlers, connection_created._live_receivers(None))
self.assertNotIn(register_type_handlers, connection_created._live_receivers(None))
self.assertNotIn(
register_type_handlers, connection_created._live_receivers(None)
)
with modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
self.assertIn(
register_type_handlers, connection_created._live_receivers(None)
)
self.assertNotIn(
register_type_handlers, connection_created._live_receivers(None)
)
def test_register_serializer_for_migrations(self):
tests = (
(DateRange(empty=True), DateRangeField),
(DateTimeRange(empty=True), DateRangeField),
(DateTimeTZRange(None, None, '[]'), DateTimeRangeField),
(NumericRange(Decimal('1.0'), Decimal('5.0'), '()'), DecimalRangeField),
(DateTimeTZRange(None, None, "[]"), DateTimeRangeField),
(NumericRange(Decimal("1.0"), Decimal("5.0"), "()"), DecimalRangeField),
(NumericRange(1, 10), IntegerRangeField),
)
@ -40,25 +47,31 @@ class PostgresConfigTests(PostgreSQLTestCase):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
with self.assertRaisesMessage(ValueError, 'Cannot serialize: %s' % default.__class__.__name__):
with self.assertRaisesMessage(
ValueError, "Cannot serialize: %s" % default.__class__.__name__
):
MigrationWriter.serialize(field)
assertNotSerializable()
with self.modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}):
with self.modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
for default, test_field in tests:
with self.subTest(default=default):
field = test_field(default=default)
serialized_field, imports = MigrationWriter.serialize(field)
self.assertEqual(imports, {
'import django.contrib.postgres.fields.ranges',
'import psycopg2.extras',
})
self.assertEqual(
imports,
{
"import django.contrib.postgres.fields.ranges",
"import psycopg2.extras",
},
)
self.assertIn(
'%s.%s(default=psycopg2.extras.%r)' % (
"%s.%s(default=psycopg2.extras.%r)"
% (
field.__module__,
field.__class__.__name__,
default,
),
serialized_field
serialized_field,
)
assertNotSerializable()

File diff suppressed because it is too large Load diff

View file

@ -2,8 +2,12 @@ from datetime import date
from . import PostgreSQLTestCase
from .models import (
HStoreModel, IntegerArrayModel, NestedIntegerArrayModel,
NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel,
HStoreModel,
IntegerArrayModel,
NestedIntegerArrayModel,
NullableIntegerArrayModel,
OtherTypesArrayModel,
RangesModel,
)
try:
@ -15,19 +19,28 @@ except ImportError:
class BulkSaveTests(PostgreSQLTestCase):
def test_bulk_update(self):
test_data = [
(IntegerArrayModel, 'field', [], [1, 2, 3]),
(NullableIntegerArrayModel, 'field', [1, 2, 3], None),
(NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]),
(HStoreModel, 'field', {}, {1: 2}),
(RangesModel, 'ints', None, NumericRange(lower=1, upper=10)),
(RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())),
(OtherTypesArrayModel, 'ips', [], ['1.2.3.4']),
(OtherTypesArrayModel, 'json', [], [{'a': 'b'}])
(IntegerArrayModel, "field", [], [1, 2, 3]),
(NullableIntegerArrayModel, "field", [1, 2, 3], None),
(NestedIntegerArrayModel, "field", [], [[1, 2, 3]]),
(HStoreModel, "field", {}, {1: 2}),
(RangesModel, "ints", None, NumericRange(lower=1, upper=10)),
(
RangesModel,
"dates",
None,
DateRange(lower=date.today(), upper=date.today()),
),
(OtherTypesArrayModel, "ips", [], ["1.2.3.4"]),
(OtherTypesArrayModel, "json", [], [{"a": "b"}]),
]
for Model, field, initial, new in test_data:
with self.subTest(model=Model, field=field):
instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20))
instances = Model.objects.bulk_create(
Model(**{field: initial}) for _ in range(20)
)
for instance in instances:
setattr(instance, field, new)
Model.objects.bulk_update(instances, [field])
self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances)
self.assertSequenceEqual(
Model.objects.filter(**{field: new}), instances
)

View file

@ -10,26 +10,35 @@ from . import PostgreSQLTestCase
from .models import CITestModel
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class CITextTestCase(PostgreSQLTestCase):
case_sensitive_lookups = ('contains', 'startswith', 'endswith', 'regex')
case_sensitive_lookups = ("contains", "startswith", "endswith", "regex")
@classmethod
def setUpTestData(cls):
cls.john = CITestModel.objects.create(
name='JoHn',
email='joHn@johN.com',
description='Average Joe named JoHn',
array_field=['JoE', 'jOhn'],
name="JoHn",
email="joHn@johN.com",
description="Average Joe named JoHn",
array_field=["JoE", "jOhn"],
)
def test_equal_lowercase(self):
"""
citext removes the need for iexact as the index is case-insensitive.
"""
self.assertEqual(CITestModel.objects.filter(name=self.john.name.lower()).count(), 1)
self.assertEqual(CITestModel.objects.filter(email=self.john.email.lower()).count(), 1)
self.assertEqual(CITestModel.objects.filter(description=self.john.description.lower()).count(), 1)
self.assertEqual(
CITestModel.objects.filter(name=self.john.name.lower()).count(), 1
)
self.assertEqual(
CITestModel.objects.filter(email=self.john.email.lower()).count(), 1
)
self.assertEqual(
CITestModel.objects.filter(
description=self.john.description.lower()
).count(),
1,
)
def test_fail_citext_primary_key(self):
"""
@ -37,27 +46,39 @@ class CITextTestCase(PostgreSQLTestCase):
clashes with an existing value isn't allowed.
"""
with self.assertRaises(IntegrityError):
CITestModel.objects.create(name='John')
CITestModel.objects.create(name="John")
def test_array_field(self):
instance = CITestModel.objects.get()
self.assertEqual(instance.array_field, self.john.array_field)
self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists())
self.assertTrue(
CITestModel.objects.filter(array_field__contains=["joe"]).exists()
)
def test_lookups_name_char(self):
for lookup in self.case_sensitive_lookups:
with self.subTest(lookup=lookup):
query = {'name__{}'.format(lookup): 'john'}
self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
query = {"name__{}".format(lookup): "john"}
self.assertSequenceEqual(
CITestModel.objects.filter(**query), [self.john]
)
def test_lookups_description_text(self):
for lookup, string in zip(self.case_sensitive_lookups, ('average', 'average joe', 'john', 'Joe.named')):
for lookup, string in zip(
self.case_sensitive_lookups, ("average", "average joe", "john", "Joe.named")
):
with self.subTest(lookup=lookup, string=string):
query = {'description__{}'.format(lookup): string}
self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
query = {"description__{}".format(lookup): string}
self.assertSequenceEqual(
CITestModel.objects.filter(**query), [self.john]
)
def test_lookups_email(self):
for lookup, string in zip(self.case_sensitive_lookups, ('john', 'john', 'john.com', 'john.com')):
for lookup, string in zip(
self.case_sensitive_lookups, ("john", "john", "john.com", "john.com")
):
with self.subTest(lookup=lookup, string=string):
query = {'email__{}'.format(lookup): string}
self.assertSequenceEqual(CITestModel.objects.filter(**query), [self.john])
query = {"email__{}".format(lookup): string}
self.assertSequenceEqual(
CITestModel.objects.filter(**query), [self.john]
)

File diff suppressed because it is too large Load diff

View file

@ -9,7 +9,6 @@ from .models import NowTestModel, UUIDTestModel
class TestTransactionNow(PostgreSQLTestCase):
def test_transaction_now(self):
"""
The test case puts everything under a transaction, so two models
@ -30,7 +29,6 @@ class TestTransactionNow(PostgreSQLTestCase):
class TestRandomUUID(PostgreSQLTestCase):
def test_random_uuid(self):
m1 = UUIDTestModel.objects.create()
m2 = UUIDTestModel.objects.create()

View file

@ -21,7 +21,7 @@ except ImportError:
class SimpleTests(PostgreSQLTestCase):
def test_save_load_success(self):
value = {'a': 'b'}
value = {"a": "b"}
instance = HStoreModel(field=value)
instance.save()
reloaded = HStoreModel.objects.get()
@ -34,15 +34,15 @@ class SimpleTests(PostgreSQLTestCase):
self.assertIsNone(reloaded.field)
def test_value_null(self):
value = {'a': None}
value = {"a": None}
instance = HStoreModel(field=value)
instance.save()
reloaded = HStoreModel.objects.get()
self.assertEqual(reloaded.field, value)
def test_key_val_cast_to_string(self):
value = {'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'}
expected_value = {'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'}
value = {"a": 1, "b": "B", 2: "c", "ï": "ê"}
expected_value = {"a": "1", "b": "B", "2": "c", "ï": "ê"}
instance = HStoreModel.objects.create(field=value)
instance = HStoreModel.objects.get()
@ -51,17 +51,17 @@ class SimpleTests(PostgreSQLTestCase):
instance = HStoreModel.objects.get(field__a=1)
self.assertEqual(instance.field, expected_value)
instance = HStoreModel.objects.get(field__has_keys=[2, 'a', 'ï'])
instance = HStoreModel.objects.get(field__has_keys=[2, "a", "ï"])
self.assertEqual(instance.field, expected_value)
def test_array_field(self):
value = [
{'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'},
{'a': 1, 'b': 'B', 2: 'c', 'ï': 'ê'},
{"a": 1, "b": "B", 2: "c", "ï": "ê"},
{"a": 1, "b": "B", 2: "c", "ï": "ê"},
]
expected_value = [
{'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'},
{'a': '1', 'b': 'B', '2': 'c', 'ï': 'ê'},
{"a": "1", "b": "B", "2": "c", "ï": "ê"},
{"a": "1", "b": "B", "2": "c", "ï": "ê"},
]
instance = HStoreModel.objects.create(array_field=value)
instance.refresh_from_db()
@ -69,231 +69,225 @@ class SimpleTests(PostgreSQLTestCase):
class TestQuerying(PostgreSQLTestCase):
@classmethod
def setUpTestData(cls):
cls.objs = HStoreModel.objects.bulk_create([
HStoreModel(field={'a': 'b'}),
HStoreModel(field={'a': 'b', 'c': 'd'}),
HStoreModel(field={'c': 'd'}),
HStoreModel(field={}),
HStoreModel(field=None),
HStoreModel(field={'cat': 'TigrOu', 'breed': 'birman'}),
HStoreModel(field={'cat': 'minou', 'breed': 'ragdoll'}),
HStoreModel(field={'cat': 'kitty', 'breed': 'Persian'}),
HStoreModel(field={'cat': 'Kit Kat', 'breed': 'persian'}),
])
cls.objs = HStoreModel.objects.bulk_create(
[
HStoreModel(field={"a": "b"}),
HStoreModel(field={"a": "b", "c": "d"}),
HStoreModel(field={"c": "d"}),
HStoreModel(field={}),
HStoreModel(field=None),
HStoreModel(field={"cat": "TigrOu", "breed": "birman"}),
HStoreModel(field={"cat": "minou", "breed": "ragdoll"}),
HStoreModel(field={"cat": "kitty", "breed": "Persian"}),
HStoreModel(field={"cat": "Kit Kat", "breed": "persian"}),
]
)
def test_exact(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__exact={'a': 'b'}),
self.objs[:1]
HStoreModel.objects.filter(field__exact={"a": "b"}), self.objs[:1]
)
def test_contained_by(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__contained_by={'a': 'b', 'c': 'd'}),
self.objs[:4]
HStoreModel.objects.filter(field__contained_by={"a": "b", "c": "d"}),
self.objs[:4],
)
def test_contains(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__contains={'a': 'b'}),
self.objs[:2]
HStoreModel.objects.filter(field__contains={"a": "b"}), self.objs[:2]
)
def test_in_generator(self):
def search():
yield {'a': 'b'}
yield {"a": "b"}
self.assertSequenceEqual(
HStoreModel.objects.filter(field__in=search()),
self.objs[:1]
HStoreModel.objects.filter(field__in=search()), self.objs[:1]
)
def test_has_key(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_key='c'),
self.objs[1:3]
HStoreModel.objects.filter(field__has_key="c"), self.objs[1:3]
)
def test_has_keys(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_keys=['a', 'c']),
self.objs[1:2]
HStoreModel.objects.filter(field__has_keys=["a", "c"]), self.objs[1:2]
)
def test_has_any_keys(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_any_keys=['a', 'c']),
self.objs[:3]
HStoreModel.objects.filter(field__has_any_keys=["a", "c"]), self.objs[:3]
)
def test_key_transform(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a='b'),
self.objs[:2]
HStoreModel.objects.filter(field__a="b"), self.objs[:2]
)
def test_key_transform_raw_expression(self):
expr = RawSQL('%s::hstore', ['x => b, y => c'])
expr = RawSQL("%s::hstore", ["x => b, y => c"])
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a=KeyTransform('x', expr)),
self.objs[:2]
HStoreModel.objects.filter(field__a=KeyTransform("x", expr)), self.objs[:2]
)
def test_key_transform_annotation(self):
qs = HStoreModel.objects.annotate(a=F('field__a'))
qs = HStoreModel.objects.annotate(a=F("field__a"))
self.assertCountEqual(
qs.values_list('a', flat=True),
['b', 'b', None, None, None, None, None, None, None],
qs.values_list("a", flat=True),
["b", "b", None, None, None, None, None, None, None],
)
def test_keys(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__keys=['a']),
self.objs[:1]
HStoreModel.objects.filter(field__keys=["a"]), self.objs[:1]
)
def test_values(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__values=['b']),
self.objs[:1]
HStoreModel.objects.filter(field__values=["b"]), self.objs[:1]
)
def test_field_chaining_contains(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a__contains='b'),
self.objs[:2]
HStoreModel.objects.filter(field__a__contains="b"), self.objs[:2]
)
def test_field_chaining_icontains(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__icontains='INo'),
HStoreModel.objects.filter(field__cat__icontains="INo"),
[self.objs[6]],
)
def test_field_chaining_startswith(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__startswith='kit'),
HStoreModel.objects.filter(field__cat__startswith="kit"),
[self.objs[7]],
)
def test_field_chaining_istartswith(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__istartswith='kit'),
HStoreModel.objects.filter(field__cat__istartswith="kit"),
self.objs[7:],
)
def test_field_chaining_endswith(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__endswith='ou'),
HStoreModel.objects.filter(field__cat__endswith="ou"),
[self.objs[6]],
)
def test_field_chaining_iendswith(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__iendswith='ou'),
HStoreModel.objects.filter(field__cat__iendswith="ou"),
self.objs[5:7],
)
def test_field_chaining_iexact(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__breed__iexact='persian'),
HStoreModel.objects.filter(field__breed__iexact="persian"),
self.objs[7:],
)
def test_field_chaining_regex(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__regex=r'ou$'),
HStoreModel.objects.filter(field__cat__regex=r"ou$"),
[self.objs[6]],
)
def test_field_chaining_iregex(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__cat__iregex=r'oU$'),
HStoreModel.objects.filter(field__cat__iregex=r"oU$"),
self.objs[5:7],
)
def test_order_by_field(self):
more_objs = (
HStoreModel.objects.create(field={'g': '637'}),
HStoreModel.objects.create(field={'g': '002'}),
HStoreModel.objects.create(field={'g': '042'}),
HStoreModel.objects.create(field={'g': '981'}),
HStoreModel.objects.create(field={"g": "637"}),
HStoreModel.objects.create(field={"g": "002"}),
HStoreModel.objects.create(field={"g": "042"}),
HStoreModel.objects.create(field={"g": "981"}),
)
self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_key='g').order_by('field__g'),
[more_objs[1], more_objs[2], more_objs[0], more_objs[3]]
HStoreModel.objects.filter(field__has_key="g").order_by("field__g"),
[more_objs[1], more_objs[2], more_objs[0], more_objs[3]],
)
def test_keys_contains(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__keys__contains=['a']),
self.objs[:2]
HStoreModel.objects.filter(field__keys__contains=["a"]), self.objs[:2]
)
def test_values_overlap(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(field__values__overlap=['b', 'd']),
self.objs[:3]
HStoreModel.objects.filter(field__values__overlap=["b", "d"]), self.objs[:3]
)
def test_key_isnull(self):
obj = HStoreModel.objects.create(field={'a': None})
obj = HStoreModel.objects.create(field={"a": None})
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a__isnull=True),
self.objs[2:9] + [obj],
)
self.assertSequenceEqual(
HStoreModel.objects.filter(field__a__isnull=False),
self.objs[:2]
HStoreModel.objects.filter(field__a__isnull=False), self.objs[:2]
)
def test_usage_in_subquery(self):
self.assertSequenceEqual(
HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a='b')),
self.objs[:2]
HStoreModel.objects.filter(id__in=HStoreModel.objects.filter(field__a="b")),
self.objs[:2],
)
def test_key_sql_injection(self):
with CaptureQueriesContext(connection) as queries:
self.assertFalse(
HStoreModel.objects.filter(**{
"field__test' = 'a') OR 1 = 1 OR ('d": 'x',
}).exists()
HStoreModel.objects.filter(
**{
"field__test' = 'a') OR 1 = 1 OR ('d": "x",
}
).exists()
)
self.assertIn(
"""."field" -> 'test'' = ''a'') OR 1 = 1 OR (''d') = 'x' """,
queries[0]['sql'],
queries[0]["sql"],
)
def test_obj_subquery_lookup(self):
qs = HStoreModel.objects.annotate(
value=Subquery(HStoreModel.objects.filter(pk=OuterRef('pk')).values('field')),
).filter(value__a='b')
value=Subquery(
HStoreModel.objects.filter(pk=OuterRef("pk")).values("field")
),
).filter(value__a="b")
self.assertSequenceEqual(qs, self.objs[:2])
@isolate_apps('postgres_tests')
@isolate_apps("postgres_tests")
class TestChecks(PostgreSQLSimpleTestCase):
def test_invalid_default(self):
class MyModel(PostgreSQLModel):
field = HStoreField(default={})
model = MyModel()
self.assertEqual(model.check(), [
checks.Warning(
msg=(
"HStoreField default should be a callable instead of an "
"instance so that it's not shared between all field "
"instances."
),
hint='Use a callable instead, e.g., use `dict` instead of `{}`.',
obj=MyModel._meta.get_field('field'),
id='fields.E010',
)
])
self.assertEqual(
model.check(),
[
checks.Warning(
msg=(
"HStoreField default should be a callable instead of an "
"instance so that it's not shared between all field "
"instances."
),
hint="Use a callable instead, e.g., use `dict` instead of `{}`.",
obj=MyModel._meta.get_field("field"),
id="fields.E010",
)
],
)
def test_valid_default(self):
class MyModel(PostgreSQLModel):
@ -303,83 +297,90 @@ class TestChecks(PostgreSQLSimpleTestCase):
class TestSerialization(PostgreSQLSimpleTestCase):
test_data = json.dumps([{
'model': 'postgres_tests.hstoremodel',
'pk': None,
'fields': {
'field': json.dumps({'a': 'b'}),
'array_field': json.dumps([
json.dumps({'a': 'b'}),
json.dumps({'b': 'a'}),
]),
},
}])
test_data = json.dumps(
[
{
"model": "postgres_tests.hstoremodel",
"pk": None,
"fields": {
"field": json.dumps({"a": "b"}),
"array_field": json.dumps(
[
json.dumps({"a": "b"}),
json.dumps({"b": "a"}),
]
),
},
}
]
)
def test_dumping(self):
instance = HStoreModel(field={'a': 'b'}, array_field=[{'a': 'b'}, {'b': 'a'}])
data = serializers.serialize('json', [instance])
instance = HStoreModel(field={"a": "b"}, array_field=[{"a": "b"}, {"b": "a"}])
data = serializers.serialize("json", [instance])
self.assertEqual(json.loads(data), json.loads(self.test_data))
def test_loading(self):
instance = list(serializers.deserialize('json', self.test_data))[0].object
self.assertEqual(instance.field, {'a': 'b'})
self.assertEqual(instance.array_field, [{'a': 'b'}, {'b': 'a'}])
instance = list(serializers.deserialize("json", self.test_data))[0].object
self.assertEqual(instance.field, {"a": "b"})
self.assertEqual(instance.array_field, [{"a": "b"}, {"b": "a"}])
def test_roundtrip_with_null(self):
instance = HStoreModel(field={'a': 'b', 'c': None})
data = serializers.serialize('json', [instance])
new_instance = list(serializers.deserialize('json', data))[0].object
instance = HStoreModel(field={"a": "b", "c": None})
data = serializers.serialize("json", [instance])
new_instance = list(serializers.deserialize("json", data))[0].object
self.assertEqual(instance.field, new_instance.field)
class TestValidation(PostgreSQLSimpleTestCase):
def test_not_a_string(self):
field = HStoreField()
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean({'a': 1}, None)
self.assertEqual(cm.exception.code, 'not_a_string')
self.assertEqual(cm.exception.message % cm.exception.params, 'The value of “a” is not a string or null.')
field.clean({"a": 1}, None)
self.assertEqual(cm.exception.code, "not_a_string")
self.assertEqual(
cm.exception.message % cm.exception.params,
"The value of “a” is not a string or null.",
)
def test_none_allowed_as_value(self):
field = HStoreField()
self.assertEqual(field.clean({'a': None}, None), {'a': None})
self.assertEqual(field.clean({"a": None}, None), {"a": None})
class TestFormField(PostgreSQLSimpleTestCase):
def test_valid(self):
field = forms.HStoreField()
value = field.clean('{"a": "b"}')
self.assertEqual(value, {'a': 'b'})
self.assertEqual(value, {"a": "b"})
def test_invalid_json(self):
field = forms.HStoreField()
with self.assertRaises(exceptions.ValidationError) as cm:
field.clean('{"a": "b"')
self.assertEqual(cm.exception.messages[0], 'Could not load JSON data.')
self.assertEqual(cm.exception.code, 'invalid_json')
self.assertEqual(cm.exception.messages[0], "Could not load JSON data.")
self.assertEqual(cm.exception.code, "invalid_json")
def test_non_dict_json(self):
field = forms.HStoreField()
msg = 'Input must be a JSON dictionary.'
msg = "Input must be a JSON dictionary."
with self.assertRaisesMessage(exceptions.ValidationError, msg) as cm:
field.clean('["a", "b", 1]')
self.assertEqual(cm.exception.code, 'invalid_format')
self.assertEqual(cm.exception.code, "invalid_format")
def test_not_string_values(self):
field = forms.HStoreField()
value = field.clean('{"a": 1}')
self.assertEqual(value, {'a': '1'})
self.assertEqual(value, {"a": "1"})
def test_none_value(self):
field = forms.HStoreField()
value = field.clean('{"a": null}')
self.assertEqual(value, {'a': None})
self.assertEqual(value, {"a": None})
def test_empty(self):
field = forms.HStoreField(required=False)
value = field.clean('')
value = field.clean("")
self.assertEqual(value, {})
def test_model_field_formfield(self):
@ -390,69 +391,71 @@ class TestFormField(PostgreSQLSimpleTestCase):
def test_field_has_changed(self):
class HStoreFormTest(Form):
f1 = forms.HStoreField()
form_w_hstore = HStoreFormTest()
self.assertFalse(form_w_hstore.has_changed())
form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'})
form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'})
self.assertTrue(form_w_hstore.has_changed())
form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': '{"a": 1}'})
form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": '{"a": 1}'})
self.assertFalse(form_w_hstore.has_changed())
form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': '{"a": 1}'})
form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": '{"a": 1}'})
self.assertTrue(form_w_hstore.has_changed())
form_w_hstore = HStoreFormTest({'f1': '{"a": 1}'}, initial={'f1': {"a": 1}})
form_w_hstore = HStoreFormTest({"f1": '{"a": 1}'}, initial={"f1": {"a": 1}})
self.assertFalse(form_w_hstore.has_changed())
form_w_hstore = HStoreFormTest({'f1': '{"a": 2}'}, initial={'f1': {"a": 1}})
form_w_hstore = HStoreFormTest({"f1": '{"a": 2}'}, initial={"f1": {"a": 1}})
self.assertTrue(form_w_hstore.has_changed())
class TestValidator(PostgreSQLSimpleTestCase):
def test_simple_valid(self):
validator = KeysValidator(keys=['a', 'b'])
validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
validator = KeysValidator(keys=["a", "b"])
validator({"a": "foo", "b": "bar", "c": "baz"})
def test_missing_keys(self):
validator = KeysValidator(keys=['a', 'b'])
validator = KeysValidator(keys=["a", "b"])
with self.assertRaises(exceptions.ValidationError) as cm:
validator({'a': 'foo', 'c': 'baz'})
self.assertEqual(cm.exception.messages[0], 'Some keys were missing: b')
self.assertEqual(cm.exception.code, 'missing_keys')
validator({"a": "foo", "c": "baz"})
self.assertEqual(cm.exception.messages[0], "Some keys were missing: b")
self.assertEqual(cm.exception.code, "missing_keys")
def test_strict_valid(self):
validator = KeysValidator(keys=['a', 'b'], strict=True)
validator({'a': 'foo', 'b': 'bar'})
validator = KeysValidator(keys=["a", "b"], strict=True)
validator({"a": "foo", "b": "bar"})
def test_extra_keys(self):
validator = KeysValidator(keys=['a', 'b'], strict=True)
validator = KeysValidator(keys=["a", "b"], strict=True)
with self.assertRaises(exceptions.ValidationError) as cm:
validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
self.assertEqual(cm.exception.code, 'extra_keys')
validator({"a": "foo", "b": "bar", "c": "baz"})
self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c")
self.assertEqual(cm.exception.code, "extra_keys")
def test_custom_messages(self):
messages = {
'missing_keys': 'Foobar',
"missing_keys": "Foobar",
}
validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages)
with self.assertRaises(exceptions.ValidationError) as cm:
validator({'a': 'foo', 'c': 'baz'})
self.assertEqual(cm.exception.messages[0], 'Foobar')
self.assertEqual(cm.exception.code, 'missing_keys')
validator({"a": "foo", "c": "baz"})
self.assertEqual(cm.exception.messages[0], "Foobar")
self.assertEqual(cm.exception.code, "missing_keys")
with self.assertRaises(exceptions.ValidationError) as cm:
validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
self.assertEqual(cm.exception.code, 'extra_keys')
validator({"a": "foo", "b": "bar", "c": "baz"})
self.assertEqual(cm.exception.messages[0], "Some unknown keys were provided: c")
self.assertEqual(cm.exception.code, "extra_keys")
def test_deconstruct(self):
messages = {
'missing_keys': 'Foobar',
"missing_keys": "Foobar",
}
validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
validator = KeysValidator(keys=["a", "b"], strict=True, messages=messages)
path, args, kwargs = validator.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.validators.KeysValidator')
self.assertEqual(path, "django.contrib.postgres.validators.KeysValidator")
self.assertEqual(args, ())
self.assertEqual(kwargs, {'keys': ['a', 'b'], 'strict': True, 'messages': messages})
self.assertEqual(
kwargs, {"keys": ["a", "b"], "strict": True, "messages": messages}
)

View file

@ -1,7 +1,13 @@
from unittest import mock
from django.contrib.postgres.indexes import (
BloomIndex, BrinIndex, BTreeIndex, GinIndex, GistIndex, HashIndex, OpClass,
BloomIndex,
BrinIndex,
BTreeIndex,
GinIndex,
GistIndex,
HashIndex,
OpClass,
SpGistIndex,
)
from django.db import NotSupportedError, connection
@ -16,191 +22,226 @@ from .models import CharFieldModel, IntegerArrayModel, Scene, TextFieldModel
class IndexTestMixin:
def test_name_auto_generation(self):
index = self.index_class(fields=['field'])
index = self.index_class(fields=["field"])
index.set_name_with_model(CharFieldModel)
self.assertRegex(index.name, r'postgres_te_field_[0-9a-f]{6}_%s' % self.index_class.suffix)
self.assertRegex(
index.name, r"postgres_te_field_[0-9a-f]{6}_%s" % self.index_class.suffix
)
def test_deconstruction_no_customization(self):
index = self.index_class(fields=['title'], name='test_title_%s' % self.index_class.suffix)
index = self.index_class(
fields=["title"], name="test_title_%s" % self.index_class.suffix
)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.%s' % self.index_class.__name__)
self.assertEqual(
path, "django.contrib.postgres.indexes.%s" % self.index_class.__name__
)
self.assertEqual(args, ())
self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_%s' % self.index_class.suffix})
self.assertEqual(
kwargs,
{"fields": ["title"], "name": "test_title_%s" % self.index_class.suffix},
)
def test_deconstruction_with_expressions_no_customization(self):
name = f'test_title_{self.index_class.suffix}'
index = self.index_class(Lower('title'), name=name)
name = f"test_title_{self.index_class.suffix}"
index = self.index_class(Lower("title"), name=name)
path, args, kwargs = index.deconstruct()
self.assertEqual(
path,
f'django.contrib.postgres.indexes.{self.index_class.__name__}',
f"django.contrib.postgres.indexes.{self.index_class.__name__}",
)
self.assertEqual(args, (Lower('title'),))
self.assertEqual(kwargs, {'name': name})
self.assertEqual(args, (Lower("title"),))
self.assertEqual(kwargs, {"name": name})
class BloomIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BloomIndex
def test_suffix(self):
self.assertEqual(BloomIndex.suffix, 'bloom')
self.assertEqual(BloomIndex.suffix, "bloom")
def test_deconstruction(self):
index = BloomIndex(fields=['title'], name='test_bloom', length=80, columns=[4])
index = BloomIndex(fields=["title"], name="test_bloom", length=80, columns=[4])
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.BloomIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.BloomIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'fields': ['title'],
'name': 'test_bloom',
'length': 80,
'columns': [4],
})
self.assertEqual(
kwargs,
{
"fields": ["title"],
"name": "test_bloom",
"length": 80,
"columns": [4],
},
)
def test_invalid_fields(self):
msg = 'Bloom indexes support a maximum of 32 fields.'
msg = "Bloom indexes support a maximum of 32 fields."
with self.assertRaisesMessage(ValueError, msg):
BloomIndex(fields=['title'] * 33, name='test_bloom')
BloomIndex(fields=["title"] * 33, name="test_bloom")
def test_invalid_columns(self):
msg = 'BloomIndex.columns must be a list or tuple.'
msg = "BloomIndex.columns must be a list or tuple."
with self.assertRaisesMessage(ValueError, msg):
BloomIndex(fields=['title'], name='test_bloom', columns='x')
msg = 'BloomIndex.columns cannot have more values than fields.'
BloomIndex(fields=["title"], name="test_bloom", columns="x")
msg = "BloomIndex.columns cannot have more values than fields."
with self.assertRaisesMessage(ValueError, msg):
BloomIndex(fields=['title'], name='test_bloom', columns=[4, 3])
BloomIndex(fields=["title"], name="test_bloom", columns=[4, 3])
def test_invalid_columns_value(self):
msg = 'BloomIndex.columns must contain integers from 1 to 4095.'
msg = "BloomIndex.columns must contain integers from 1 to 4095."
for length in (0, 4096):
with self.subTest(length), self.assertRaisesMessage(ValueError, msg):
BloomIndex(fields=['title'], name='test_bloom', columns=[length])
BloomIndex(fields=["title"], name="test_bloom", columns=[length])
def test_invalid_length(self):
msg = 'BloomIndex.length must be None or an integer from 1 to 4096.'
msg = "BloomIndex.length must be None or an integer from 1 to 4096."
for length in (0, 4097):
with self.subTest(length), self.assertRaisesMessage(ValueError, msg):
BloomIndex(fields=['title'], name='test_bloom', length=length)
BloomIndex(fields=["title"], name="test_bloom", length=length)
class BrinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BrinIndex
def test_suffix(self):
self.assertEqual(BrinIndex.suffix, 'brin')
self.assertEqual(BrinIndex.suffix, "brin")
def test_deconstruction(self):
index = BrinIndex(fields=['title'], name='test_title_brin', autosummarize=True, pages_per_range=16)
index = BrinIndex(
fields=["title"],
name="test_title_brin",
autosummarize=True,
pages_per_range=16,
)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.BrinIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.BrinIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'fields': ['title'],
'name': 'test_title_brin',
'autosummarize': True,
'pages_per_range': 16,
})
self.assertEqual(
kwargs,
{
"fields": ["title"],
"name": "test_title_brin",
"autosummarize": True,
"pages_per_range": 16,
},
)
def test_invalid_pages_per_range(self):
with self.assertRaisesMessage(ValueError, 'pages_per_range must be None or a positive integer'):
BrinIndex(fields=['title'], name='test_title_brin', pages_per_range=0)
with self.assertRaisesMessage(
ValueError, "pages_per_range must be None or a positive integer"
):
BrinIndex(fields=["title"], name="test_title_brin", pages_per_range=0)
class BTreeIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = BTreeIndex
def test_suffix(self):
self.assertEqual(BTreeIndex.suffix, 'btree')
self.assertEqual(BTreeIndex.suffix, "btree")
def test_deconstruction(self):
index = BTreeIndex(fields=['title'], name='test_title_btree', fillfactor=80)
index = BTreeIndex(fields=["title"], name="test_title_btree", fillfactor=80)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.BTreeIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.BTreeIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_btree', 'fillfactor': 80})
self.assertEqual(
kwargs, {"fields": ["title"], "name": "test_title_btree", "fillfactor": 80}
)
class GinIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = GinIndex
def test_suffix(self):
self.assertEqual(GinIndex.suffix, 'gin')
self.assertEqual(GinIndex.suffix, "gin")
def test_deconstruction(self):
index = GinIndex(
fields=['title'],
name='test_title_gin',
fields=["title"],
name="test_title_gin",
fastupdate=True,
gin_pending_list_limit=128,
)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.GinIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.GinIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'fields': ['title'],
'name': 'test_title_gin',
'fastupdate': True,
'gin_pending_list_limit': 128,
})
self.assertEqual(
kwargs,
{
"fields": ["title"],
"name": "test_title_gin",
"fastupdate": True,
"gin_pending_list_limit": 128,
},
)
class GistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = GistIndex
def test_suffix(self):
self.assertEqual(GistIndex.suffix, 'gist')
self.assertEqual(GistIndex.suffix, "gist")
def test_deconstruction(self):
index = GistIndex(fields=['title'], name='test_title_gist', buffering=False, fillfactor=80)
index = GistIndex(
fields=["title"], name="test_title_gist", buffering=False, fillfactor=80
)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.GistIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.GistIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'fields': ['title'],
'name': 'test_title_gist',
'buffering': False,
'fillfactor': 80,
})
self.assertEqual(
kwargs,
{
"fields": ["title"],
"name": "test_title_gist",
"buffering": False,
"fillfactor": 80,
},
)
class HashIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = HashIndex
def test_suffix(self):
self.assertEqual(HashIndex.suffix, 'hash')
self.assertEqual(HashIndex.suffix, "hash")
def test_deconstruction(self):
index = HashIndex(fields=['title'], name='test_title_hash', fillfactor=80)
index = HashIndex(fields=["title"], name="test_title_hash", fillfactor=80)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.HashIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.HashIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_hash', 'fillfactor': 80})
self.assertEqual(
kwargs, {"fields": ["title"], "name": "test_title_hash", "fillfactor": 80}
)
class SpGistIndexTests(IndexTestMixin, PostgreSQLSimpleTestCase):
index_class = SpGistIndex
def test_suffix(self):
self.assertEqual(SpGistIndex.suffix, 'spgist')
self.assertEqual(SpGistIndex.suffix, "spgist")
def test_deconstruction(self):
index = SpGistIndex(fields=['title'], name='test_title_spgist', fillfactor=80)
index = SpGistIndex(fields=["title"], name="test_title_spgist", fillfactor=80)
path, args, kwargs = index.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.indexes.SpGistIndex')
self.assertEqual(path, "django.contrib.postgres.indexes.SpGistIndex")
self.assertEqual(args, ())
self.assertEqual(kwargs, {'fields': ['title'], 'name': 'test_title_spgist', 'fillfactor': 80})
self.assertEqual(
kwargs, {"fields": ["title"], "name": "test_title_spgist", "fillfactor": 80}
)
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SchemaTests(PostgreSQLTestCase):
get_opclass_query = '''
get_opclass_query = """
SELECT opcname, c.relname FROM pg_opclass AS oc
JOIN pg_index as i on oc.oid = ANY(i.indclass)
JOIN pg_class as c on c.oid = i.indexrelid
WHERE c.relname = %s
'''
"""
def get_constraints(self, table):
"""
@ -211,229 +252,274 @@ class SchemaTests(PostgreSQLTestCase):
def test_gin_index(self):
# Ensure the table is there and doesn't have an index.
self.assertNotIn('field', self.get_constraints(IntegerArrayModel._meta.db_table))
self.assertNotIn(
"field", self.get_constraints(IntegerArrayModel._meta.db_table)
)
# Add the index
index_name = 'integer_array_model_field_gin'
index = GinIndex(fields=['field'], name=index_name)
index_name = "integer_array_model_field_gin"
index = GinIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
# Check gin index was added
self.assertEqual(constraints[index_name]['type'], GinIndex.suffix)
self.assertEqual(constraints[index_name]["type"], GinIndex.suffix)
# Drop the index
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
)
def test_gin_fastupdate(self):
index_name = 'integer_array_gin_fastupdate'
index = GinIndex(fields=['field'], name=index_name, fastupdate=False)
index_name = "integer_array_gin_fastupdate"
index = GinIndex(fields=["field"], name=index_name, fastupdate=False)
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], 'gin')
self.assertEqual(constraints[index_name]['options'], ['fastupdate=off'])
self.assertEqual(constraints[index_name]["type"], "gin")
self.assertEqual(constraints[index_name]["options"], ["fastupdate=off"])
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
)
def test_partial_gin_index(self):
with register_lookup(CharField, Length):
index_name = 'char_field_gin_partial_idx'
index = GinIndex(fields=['field'], name=index_name, condition=Q(field__length=40))
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], 'gin')
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
def test_partial_gin_index_with_tablespace(self):
with register_lookup(CharField, Length):
index_name = 'char_field_gin_partial_idx'
index_name = "char_field_gin_partial_idx"
index = GinIndex(
fields=['field'],
name=index_name,
condition=Q(field__length=40),
db_tablespace='pg_default',
fields=["field"], name=index_name, condition=Q(field__length=40)
)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
self.assertIn('TABLESPACE "pg_default" ', str(index.create_sql(CharFieldModel, editor)))
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], 'gin')
self.assertEqual(constraints[index_name]["type"], "gin")
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_partial_gin_index_with_tablespace(self):
with register_lookup(CharField, Length):
index_name = "char_field_gin_partial_idx"
index = GinIndex(
fields=["field"],
name=index_name,
condition=Q(field__length=40),
db_tablespace="pg_default",
)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
self.assertIn(
'TABLESPACE "pg_default" ',
str(index.create_sql(CharFieldModel, editor)),
)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]["type"], "gin")
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_gin_parameters(self):
index_name = 'integer_array_gin_params'
index = GinIndex(fields=['field'], name=index_name, fastupdate=True, gin_pending_list_limit=64)
index_name = "integer_array_gin_params"
index = GinIndex(
fields=["field"],
name=index_name,
fastupdate=True,
gin_pending_list_limit=64,
)
with connection.schema_editor() as editor:
editor.add_index(IntegerArrayModel, index)
constraints = self.get_constraints(IntegerArrayModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], 'gin')
self.assertEqual(constraints[index_name]['options'], ['gin_pending_list_limit=64', 'fastupdate=on'])
self.assertEqual(constraints[index_name]["type"], "gin")
self.assertEqual(
constraints[index_name]["options"],
["gin_pending_list_limit=64", "fastupdate=on"],
)
with connection.schema_editor() as editor:
editor.remove_index(IntegerArrayModel, index)
self.assertNotIn(index_name, self.get_constraints(IntegerArrayModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(IntegerArrayModel._meta.db_table)
)
def test_trigram_op_class_gin_index(self):
index_name = 'trigram_op_class_gin'
index = GinIndex(OpClass(F('scene'), name='gin_trgm_ops'), name=index_name)
index_name = "trigram_op_class_gin"
index = GinIndex(OpClass(F("scene"), name="gin_trgm_ops"), name=index_name)
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
self.assertCountEqual(cursor.fetchall(), [('gin_trgm_ops', index_name)])
self.assertCountEqual(cursor.fetchall(), [("gin_trgm_ops", index_name)])
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
self.assertIn(constraints[index_name]['type'], GinIndex.suffix)
self.assertIn(constraints[index_name]["type"], GinIndex.suffix)
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_cast_search_vector_gin_index(self):
index_name = 'cast_search_vector_gin'
index = GinIndex(Cast('field', SearchVectorField()), name=index_name)
index_name = "cast_search_vector_gin"
index = GinIndex(Cast("field", SearchVectorField()), name=index_name)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
sql = index.create_sql(TextFieldModel, editor)
table = TextFieldModel._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
self.assertIn(constraints[index_name]['type'], GinIndex.suffix)
self.assertIs(sql.references_column(table, 'field'), True)
self.assertIn('::tsvector', str(sql))
self.assertIn(constraints[index_name]["type"], GinIndex.suffix)
self.assertIs(sql.references_column(table, "field"), True)
self.assertIn("::tsvector", str(sql))
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_bloom_index(self):
index_name = 'char_field_model_field_bloom'
index = BloomIndex(fields=['field'], name=index_name)
index_name = "char_field_model_field_bloom"
index = BloomIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix)
self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix)
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_bloom_parameters(self):
index_name = 'char_field_model_field_bloom_params'
index = BloomIndex(fields=['field'], name=index_name, length=512, columns=[3])
index_name = "char_field_model_field_bloom_params"
index = BloomIndex(fields=["field"], name=index_name, length=512, columns=[3])
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], BloomIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['length=512', 'col1=3'])
self.assertEqual(constraints[index_name]["type"], BloomIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["length=512", "col1=3"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_brin_index(self):
index_name = 'char_field_model_field_brin'
index = BrinIndex(fields=['field'], name=index_name, pages_per_range=4)
index_name = "char_field_model_field_brin"
index = BrinIndex(fields=["field"], name=index_name, pages_per_range=4)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['pages_per_range=4'])
self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["pages_per_range=4"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_brin_parameters(self):
index_name = 'char_field_brin_params'
index = BrinIndex(fields=['field'], name=index_name, autosummarize=True)
index_name = "char_field_brin_params"
index = BrinIndex(fields=["field"], name=index_name, autosummarize=True)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], BrinIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['autosummarize=on'])
self.assertEqual(constraints[index_name]["type"], BrinIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["autosummarize=on"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_btree_index(self):
# Ensure the table is there and doesn't have an index.
self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
index_name = 'char_field_model_field_btree'
index = BTreeIndex(fields=['field'], name=index_name)
index_name = "char_field_model_field_btree"
index = BTreeIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix)
self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_btree_parameters(self):
index_name = 'integer_array_btree_fillfactor'
index = BTreeIndex(fields=['field'], name=index_name, fillfactor=80)
index_name = "integer_array_btree_fillfactor"
index = BTreeIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], BTreeIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
self.assertEqual(constraints[index_name]["type"], BTreeIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_gist_index(self):
# Ensure the table is there and doesn't have an index.
self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
index_name = 'char_field_model_field_gist'
index = GistIndex(fields=['field'], name=index_name)
index_name = "char_field_model_field_gist"
index = GistIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_gist_parameters(self):
index_name = 'integer_array_gist_buffering'
index = GistIndex(fields=['field'], name=index_name, buffering=True, fillfactor=80)
index_name = "integer_array_gist_buffering"
index = GistIndex(
fields=["field"], name=index_name, buffering=True, fillfactor=80
)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['buffering=on', 'fillfactor=80'])
self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
self.assertEqual(
constraints[index_name]["options"], ["buffering=on", "fillfactor=80"]
)
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
@skipUnlessDBFeature('supports_covering_gist_indexes')
@skipUnlessDBFeature("supports_covering_gist_indexes")
def test_gist_include(self):
index_name = 'scene_gist_include_setting'
index = GistIndex(name=index_name, fields=['scene'], include=['setting'])
index_name = "scene_gist_include_setting"
index = GistIndex(name=index_name, fields=["scene"], include=["setting"])
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
self.assertEqual(constraints[index_name]['type'], GistIndex.suffix)
self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting'])
self.assertEqual(constraints[index_name]["type"], GistIndex.suffix)
self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"])
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_gist_include_not_supported(self):
index_name = 'gist_include_exception'
index = GistIndex(fields=['scene'], name=index_name, include=['setting'])
msg = 'Covering GiST indexes require PostgreSQL 12+.'
index_name = "gist_include_exception"
index = GistIndex(fields=["scene"], name=index_name, include=["setting"])
msg = "Covering GiST indexes require PostgreSQL 12+."
with self.assertRaisesMessage(NotSupportedError, msg):
with mock.patch(
'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes',
"django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes",
False,
):
with connection.schema_editor() as editor:
@ -441,11 +527,11 @@ class SchemaTests(PostgreSQLTestCase):
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_tsvector_op_class_gist_index(self):
index_name = 'tsvector_op_class_gist'
index_name = "tsvector_op_class_gist"
index = GistIndex(
OpClass(
SearchVector('scene', 'setting', config='english'),
name='tsvector_ops',
SearchVector("scene", "setting", config="english"),
name="tsvector_ops",
),
name=index_name,
)
@ -455,90 +541,98 @@ class SchemaTests(PostgreSQLTestCase):
table = Scene._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
self.assertIn(constraints[index_name]['type'], GistIndex.suffix)
self.assertIs(sql.references_column(table, 'scene'), True)
self.assertIs(sql.references_column(table, 'setting'), True)
self.assertIn(constraints[index_name]["type"], GistIndex.suffix)
self.assertIs(sql.references_column(table, "scene"), True)
self.assertIs(sql.references_column(table, "setting"), True)
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_hash_index(self):
# Ensure the table is there and doesn't have an index.
self.assertNotIn('field', self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn("field", self.get_constraints(CharFieldModel._meta.db_table))
# Add the index.
index_name = 'char_field_model_field_hash'
index = HashIndex(fields=['field'], name=index_name)
index_name = "char_field_model_field_hash"
index = HashIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
# The index was added.
self.assertEqual(constraints[index_name]['type'], HashIndex.suffix)
self.assertEqual(constraints[index_name]["type"], HashIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_hash_parameters(self):
index_name = 'integer_array_hash_fillfactor'
index = HashIndex(fields=['field'], name=index_name, fillfactor=80)
index_name = "integer_array_hash_fillfactor"
index = HashIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(CharFieldModel, index)
constraints = self.get_constraints(CharFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], HashIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
self.assertEqual(constraints[index_name]["type"], HashIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(CharFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(CharFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(CharFieldModel._meta.db_table)
)
def test_spgist_index(self):
# Ensure the table is there and doesn't have an index.
self.assertNotIn('field', self.get_constraints(TextFieldModel._meta.db_table))
self.assertNotIn("field", self.get_constraints(TextFieldModel._meta.db_table))
# Add the index.
index_name = 'text_field_model_field_spgist'
index = SpGistIndex(fields=['field'], name=index_name)
index_name = "text_field_model_field_spgist"
index = SpGistIndex(fields=["field"], name=index_name)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
constraints = self.get_constraints(TextFieldModel._meta.db_table)
# The index was added.
self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
# Drop the index.
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(TextFieldModel._meta.db_table)
)
def test_spgist_parameters(self):
index_name = 'text_field_model_spgist_fillfactor'
index = SpGistIndex(fields=['field'], name=index_name, fillfactor=80)
index_name = "text_field_model_spgist_fillfactor"
index = SpGistIndex(fields=["field"], name=index_name, fillfactor=80)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
constraints = self.get_constraints(TextFieldModel._meta.db_table)
self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
self.assertEqual(constraints[index_name]['options'], ['fillfactor=80'])
self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
self.assertEqual(constraints[index_name]["options"], ["fillfactor=80"])
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(TextFieldModel._meta.db_table))
self.assertNotIn(
index_name, self.get_constraints(TextFieldModel._meta.db_table)
)
@skipUnlessDBFeature('supports_covering_spgist_indexes')
@skipUnlessDBFeature("supports_covering_spgist_indexes")
def test_spgist_include(self):
index_name = 'scene_spgist_include_setting'
index = SpGistIndex(name=index_name, fields=['scene'], include=['setting'])
index_name = "scene_spgist_include_setting"
index = SpGistIndex(name=index_name, fields=["scene"], include=["setting"])
with connection.schema_editor() as editor:
editor.add_index(Scene, index)
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIn(index_name, constraints)
self.assertEqual(constraints[index_name]['type'], SpGistIndex.suffix)
self.assertEqual(constraints[index_name]['columns'], ['scene', 'setting'])
self.assertEqual(constraints[index_name]["type"], SpGistIndex.suffix)
self.assertEqual(constraints[index_name]["columns"], ["scene", "setting"])
with connection.schema_editor() as editor:
editor.remove_index(Scene, index)
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_spgist_include_not_supported(self):
index_name = 'spgist_include_exception'
index = SpGistIndex(fields=['scene'], name=index_name, include=['setting'])
msg = 'Covering SP-GiST indexes require PostgreSQL 14+.'
index_name = "spgist_include_exception"
index = SpGistIndex(fields=["scene"], name=index_name, include=["setting"])
msg = "Covering SP-GiST indexes require PostgreSQL 14+."
with self.assertRaisesMessage(NotSupportedError, msg):
with mock.patch(
'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes',
"django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_spgist_indexes",
False,
):
with connection.schema_editor() as editor:
@ -546,27 +640,25 @@ class SchemaTests(PostgreSQLTestCase):
self.assertNotIn(index_name, self.get_constraints(Scene._meta.db_table))
def test_op_class(self):
index_name = 'test_op_class'
index_name = "test_op_class"
index = Index(
OpClass(Lower('field'), name='text_pattern_ops'),
OpClass(Lower("field"), name="text_pattern_ops"),
name=index_name,
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
def test_op_class_descending_collation(self):
collation = connection.features.test_collations.get('non_default')
collation = connection.features.test_collations.get("non_default")
if not collation:
self.skipTest(
'This backend does not support case-insensitive collations.'
)
index_name = 'test_op_class_descending_collation'
self.skipTest("This backend does not support case-insensitive collations.")
index_name = "test_op_class_descending_collation"
index = Index(
Collate(
OpClass(Lower('field'), name='text_pattern_ops').desc(nulls_last=True),
OpClass(Lower("field"), name="text_pattern_ops").desc(nulls_last=True),
collation=collation,
),
name=index_name,
@ -574,53 +666,53 @@ class SchemaTests(PostgreSQLTestCase):
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
self.assertIn(
'COLLATE %s' % editor.quote_name(collation),
"COLLATE %s" % editor.quote_name(collation),
str(index.create_sql(TextFieldModel, editor)),
)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
table = TextFieldModel._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(index_name, constraints)
self.assertEqual(constraints[index_name]['orders'], ['DESC'])
self.assertEqual(constraints[index_name]["orders"], ["DESC"])
with connection.schema_editor() as editor:
editor.remove_index(TextFieldModel, index)
self.assertNotIn(index_name, self.get_constraints(table))
def test_op_class_descending_partial(self):
index_name = 'test_op_class_descending_partial'
index_name = "test_op_class_descending_partial"
index = Index(
OpClass(Lower('field'), name='text_pattern_ops').desc(),
OpClass(Lower("field"), name="text_pattern_ops").desc(),
name=index_name,
condition=Q(field__contains='China'),
condition=Q(field__contains="China"),
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
constraints = self.get_constraints(TextFieldModel._meta.db_table)
self.assertIn(index_name, constraints)
self.assertEqual(constraints[index_name]['orders'], ['DESC'])
self.assertEqual(constraints[index_name]["orders"], ["DESC"])
def test_op_class_descending_partial_tablespace(self):
index_name = 'test_op_class_descending_partial_tablespace'
index_name = "test_op_class_descending_partial_tablespace"
index = Index(
OpClass(Lower('field').desc(), name='text_pattern_ops'),
OpClass(Lower("field").desc(), name="text_pattern_ops"),
name=index_name,
condition=Q(field__contains='China'),
db_tablespace='pg_default',
condition=Q(field__contains="China"),
db_tablespace="pg_default",
)
with connection.schema_editor() as editor:
editor.add_index(TextFieldModel, index)
self.assertIn(
'TABLESPACE "pg_default" ',
str(index.create_sql(TextFieldModel, editor))
str(index.create_sql(TextFieldModel, editor)),
)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [index_name])
self.assertCountEqual(cursor.fetchall(), [('text_pattern_ops', index_name)])
self.assertCountEqual(cursor.fetchall(), [("text_pattern_ops", index_name)])
constraints = self.get_constraints(TextFieldModel._meta.db_table)
self.assertIn(index_name, constraints)
self.assertEqual(constraints[index_name]['orders'], ['DESC'])
self.assertEqual(constraints[index_name]["orders"], ["DESC"])

View file

@ -8,15 +8,22 @@ from . import PostgreSQLSimpleTestCase
class PostgresIntegrationTests(PostgreSQLSimpleTestCase):
def test_check(self):
test_environ = os.environ.copy()
if 'DJANGO_SETTINGS_MODULE' in test_environ:
del test_environ['DJANGO_SETTINGS_MODULE']
test_environ['PYTHONPATH'] = os.path.join(os.path.dirname(__file__), '../../')
if "DJANGO_SETTINGS_MODULE" in test_environ:
del test_environ["DJANGO_SETTINGS_MODULE"]
test_environ["PYTHONPATH"] = os.path.join(os.path.dirname(__file__), "../../")
result = subprocess.run(
[sys.executable, '-m', 'django', 'check', '--settings', 'integration_settings'],
[
sys.executable,
"-m",
"django",
"check",
"--settings",
"integration_settings",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
cwd=os.path.dirname(__file__),
env=test_environ,
encoding='utf-8',
encoding="utf-8",
)
self.assertEqual(result.returncode, 0, msg=result.stderr)

View file

@ -6,12 +6,12 @@ from django.test.utils import modify_settings
from . import PostgreSQLTestCase
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class InspectDBTests(PostgreSQLTestCase):
def assertFieldsInModel(self, model, field_outputs):
out = StringIO()
call_command(
'inspectdb',
"inspectdb",
table_name_filter=lambda tn: tn.startswith(model),
stdout=out,
)
@ -21,12 +21,12 @@ class InspectDBTests(PostgreSQLTestCase):
def test_range_fields(self):
self.assertFieldsInModel(
'postgres_tests_rangesmodel',
"postgres_tests_rangesmodel",
[
'ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)',
'bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)',
'decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)',
'timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)',
'dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)',
"ints = django.contrib.postgres.fields.IntegerRangeField(blank=True, null=True)",
"bigints = django.contrib.postgres.fields.BigIntegerRangeField(blank=True, null=True)",
"decimals = django.contrib.postgres.fields.DecimalRangeField(blank=True, null=True)",
"timestamps = django.contrib.postgres.fields.DateTimeRangeField(blank=True, null=True)",
"dates = django.contrib.postgres.fields.DateRangeField(blank=True, null=True)",
],
)

View file

@ -3,9 +3,7 @@ from unittest import mock
from migrations.test_base import OperationTestBase
from django.db import (
IntegrityError, NotSupportedError, connection, transaction,
)
from django.db import IntegrityError, NotSupportedError, connection, transaction
from django.db.migrations.state import ProjectState
from django.db.models import CheckConstraint, Index, Q, UniqueConstraint
from django.db.utils import ProgrammingError
@ -17,262 +15,315 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
from django.contrib.postgres.operations import (
AddConstraintNotValid, AddIndexConcurrently, BloomExtension,
CreateCollation, CreateExtension, RemoveCollation,
RemoveIndexConcurrently, ValidateConstraint,
AddConstraintNotValid,
AddIndexConcurrently,
BloomExtension,
CreateCollation,
CreateExtension,
RemoveCollation,
RemoveIndexConcurrently,
ValidateConstraint,
)
except ImportError:
pass
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"})
class AddIndexConcurrentlyTests(OperationTestBase):
app_label = 'test_add_concurrently'
app_label = "test_add_concurrently"
def test_requires_atomic_false(self):
project_state = self.set_up_test_model(self.app_label)
new_state = project_state.clone()
operation = AddIndexConcurrently(
'Pony',
Index(fields=['pink'], name='pony_pink_idx'),
"Pony",
Index(fields=["pink"], name="pony_pink_idx"),
)
msg = (
'The AddIndexConcurrently operation cannot be executed inside '
'a transaction (set atomic = False on the migration).'
"The AddIndexConcurrently operation cannot be executed inside "
"a transaction (set atomic = False on the migration)."
)
with self.assertRaisesMessage(NotSupportedError, msg):
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
def test_add(self):
project_state = self.set_up_test_model(self.app_label, index=False)
table_name = '%s_pony' % self.app_label
index = Index(fields=['pink'], name='pony_pink_idx')
table_name = "%s_pony" % self.app_label
index = Index(fields=["pink"], name="pony_pink_idx")
new_state = project_state.clone()
operation = AddIndexConcurrently('Pony', index)
operation = AddIndexConcurrently("Pony", index)
self.assertEqual(
operation.describe(),
'Concurrently create index pony_pink_idx on field(s) pink of model Pony',
"Concurrently create index pony_pink_idx on field(s) pink of model Pony",
)
operation.state_forwards(self.app_label, new_state)
self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 1)
self.assertIndexNotExists(table_name, ['pink'])
self.assertEqual(
len(new_state.models[self.app_label, "pony"].options["indexes"]), 1
)
self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIndexExists(table_name, ['pink'])
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertIndexExists(table_name, ["pink"])
# Reversal.
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIndexNotExists(table_name, ['pink'])
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertIndexNotExists(table_name, ["pink"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'AddIndexConcurrently')
self.assertEqual(name, "AddIndexConcurrently")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'model_name': 'Pony', 'index': index})
self.assertEqual(kwargs, {"model_name": "Pony", "index": index})
def test_add_other_index_type(self):
project_state = self.set_up_test_model(self.app_label, index=False)
table_name = '%s_pony' % self.app_label
table_name = "%s_pony" % self.app_label
new_state = project_state.clone()
operation = AddIndexConcurrently(
'Pony',
BrinIndex(fields=['pink'], name='pony_pink_brin_idx'),
"Pony",
BrinIndex(fields=["pink"], name="pony_pink_brin_idx"),
)
self.assertIndexNotExists(table_name, ['pink'])
self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIndexExists(table_name, ['pink'], index_type='brin')
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertIndexExists(table_name, ["pink"], index_type="brin")
# Reversal.
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIndexNotExists(table_name, ['pink'])
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertIndexNotExists(table_name, ["pink"])
def test_add_with_options(self):
project_state = self.set_up_test_model(self.app_label, index=False)
table_name = '%s_pony' % self.app_label
table_name = "%s_pony" % self.app_label
new_state = project_state.clone()
index = BTreeIndex(fields=['pink'], name='pony_pink_btree_idx', fillfactor=70)
operation = AddIndexConcurrently('Pony', index)
self.assertIndexNotExists(table_name, ['pink'])
index = BTreeIndex(fields=["pink"], name="pony_pink_btree_idx", fillfactor=70)
operation = AddIndexConcurrently("Pony", index)
self.assertIndexNotExists(table_name, ["pink"])
# Add index.
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIndexExists(table_name, ['pink'], index_type='btree')
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertIndexExists(table_name, ["pink"], index_type="btree")
# Reversal.
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIndexNotExists(table_name, ['pink'])
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertIndexNotExists(table_name, ["pink"])
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"})
class RemoveIndexConcurrentlyTests(OperationTestBase):
app_label = 'test_rm_concurrently'
app_label = "test_rm_concurrently"
def test_requires_atomic_false(self):
project_state = self.set_up_test_model(self.app_label, index=True)
new_state = project_state.clone()
operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
operation = RemoveIndexConcurrently("Pony", "pony_pink_idx")
msg = (
'The RemoveIndexConcurrently operation cannot be executed inside '
'a transaction (set atomic = False on the migration).'
"The RemoveIndexConcurrently operation cannot be executed inside "
"a transaction (set atomic = False on the migration)."
)
with self.assertRaisesMessage(NotSupportedError, msg):
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
def test_remove(self):
project_state = self.set_up_test_model(self.app_label, index=True)
table_name = '%s_pony' % self.app_label
table_name = "%s_pony" % self.app_label
self.assertTableExists(table_name)
new_state = project_state.clone()
operation = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
operation = RemoveIndexConcurrently("Pony", "pony_pink_idx")
self.assertEqual(
operation.describe(),
'Concurrently remove index pony_pink_idx from Pony',
"Concurrently remove index pony_pink_idx from Pony",
)
operation.state_forwards(self.app_label, new_state)
self.assertEqual(len(new_state.models[self.app_label, 'pony'].options['indexes']), 0)
self.assertIndexExists(table_name, ['pink'])
self.assertEqual(
len(new_state.models[self.app_label, "pony"].options["indexes"]), 0
)
self.assertIndexExists(table_name, ["pink"])
# Remove index.
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
self.assertIndexNotExists(table_name, ['pink'])
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertIndexNotExists(table_name, ["pink"])
# Reversal.
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
self.assertIndexExists(table_name, ['pink'])
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertIndexExists(table_name, ["pink"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'RemoveIndexConcurrently')
self.assertEqual(name, "RemoveIndexConcurrently")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
self.assertEqual(kwargs, {"model_name": "Pony", "name": "pony_pink_idx"})
class NoMigrationRouter():
class NoMigrationRouter:
def allow_migrate(self, db, app_label, **hints):
return False
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class CreateExtensionTests(PostgreSQLTestCase):
app_label = 'test_allow_create_extention'
app_label = "test_allow_create_extention"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
operation = CreateExtension('tablefunc')
operation = CreateExtension("tablefunc")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 0)
def test_allow_migrate(self):
operation = CreateExtension('tablefunc')
self.assertEqual(operation.migration_name_fragment, 'create_extension_tablefunc')
operation = CreateExtension("tablefunc")
self.assertEqual(
operation.migration_name_fragment, "create_extension_tablefunc"
)
project_state = ProjectState()
new_state = project_state.clone()
# Create an extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 4)
self.assertIn('CREATE EXTENSION IF NOT EXISTS', captured_queries[1]['sql'])
self.assertIn("CREATE EXTENSION IF NOT EXISTS", captured_queries[1]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 2)
self.assertIn('DROP EXTENSION IF EXISTS', captured_queries[1]['sql'])
self.assertIn("DROP EXTENSION IF EXISTS", captured_queries[1]["sql"])
def test_create_existing_extension(self):
operation = BloomExtension()
self.assertEqual(operation.migration_name_fragment, 'create_extension_bloom')
self.assertEqual(operation.migration_name_fragment, "create_extension_bloom")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create an existing extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 3)
self.assertIn('SELECT', captured_queries[0]['sql'])
self.assertIn("SELECT", captured_queries[0]["sql"])
def test_drop_nonexistent_extension(self):
operation = CreateExtension('tablefunc')
operation = CreateExtension("tablefunc")
project_state = ProjectState()
new_state = project_state.clone()
# Don't drop a nonexistent extension.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, project_state, new_state)
operation.database_backwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('SELECT', captured_queries[0]['sql'])
self.assertIn("SELECT", captured_queries[0]["sql"])
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class CreateCollationTests(PostgreSQLTestCase):
app_label = 'test_allow_create_collation'
app_label = "test_allow_create_collation"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
operation = CreateCollation('C_test', locale='C')
operation = CreateCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 0)
def test_create(self):
operation = CreateCollation('C_test', locale='C')
self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test')
self.assertEqual(operation.describe(), 'Create collation C_test')
operation = CreateCollation("C_test", locale="C")
self.assertEqual(operation.migration_name_fragment, "create_collation_c_test")
self.assertEqual(operation.describe(), "Create collation C_test")
project_state = ProjectState()
new_state = project_state.clone()
# Create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Creating the same collation raises an exception.
with self.assertRaisesMessage(ProgrammingError, 'already exists'):
with self.assertRaisesMessage(ProgrammingError, "already exists"):
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'CreateCollation')
self.assertEqual(name, "CreateCollation")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
self.assertEqual(kwargs, {"name": "C_test", "locale": "C"})
@skipUnlessDBFeature('supports_non_deterministic_collations')
@skipUnlessDBFeature("supports_non_deterministic_collations")
def test_create_non_deterministic_collation(self):
operation = CreateCollation(
'case_insensitive_test',
'und-u-ks-level2',
provider='icu',
"case_insensitive_test",
"und-u-ks-level2",
provider="icu",
deterministic=False,
)
project_state = ProjectState()
@ -280,216 +331,253 @@ class CreateCollationTests(PostgreSQLTestCase):
# Create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'CreateCollation')
self.assertEqual(name, "CreateCollation")
self.assertEqual(args, [])
self.assertEqual(kwargs, {
'name': 'case_insensitive_test',
'locale': 'und-u-ks-level2',
'provider': 'icu',
'deterministic': False,
})
self.assertEqual(
kwargs,
{
"name": "case_insensitive_test",
"locale": "und-u-ks-level2",
"provider": "icu",
"deterministic": False,
},
)
def test_create_collation_alternate_provider(self):
operation = CreateCollation(
'german_phonebook_test',
provider='icu',
locale='de-u-co-phonebk',
"german_phonebook_test",
provider="icu",
locale="de-u-co-phonebk",
)
project_state = ProjectState()
new_state = project_state.clone()
# Create an collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
def test_nondeterministic_collation_not_supported(self):
operation = CreateCollation(
'case_insensitive_test',
provider='icu',
locale='und-u-ks-level2',
"case_insensitive_test",
provider="icu",
locale="und-u-ks-level2",
deterministic=False,
)
project_state = ProjectState()
new_state = project_state.clone()
msg = 'Non-deterministic collations require PostgreSQL 12+.'
msg = "Non-deterministic collations require PostgreSQL 12+."
with connection.schema_editor(atomic=False) as editor:
with mock.patch(
'django.db.backends.postgresql.features.DatabaseFeatures.'
'supports_non_deterministic_collations',
"django.db.backends.postgresql.features.DatabaseFeatures."
"supports_non_deterministic_collations",
False,
):
with self.assertRaisesMessage(NotSupportedError, msg):
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
class RemoveCollationTests(PostgreSQLTestCase):
app_label = 'test_allow_remove_collation'
app_label = "test_allow_remove_collation"
@override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
def test_no_allow_migrate(self):
operation = RemoveCollation('C_test', locale='C')
operation = RemoveCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
# Don't create a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 0)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 0)
def test_remove(self):
operation = CreateCollation('C_test', locale='C')
operation = CreateCollation("C_test", locale="C")
project_state = ProjectState()
new_state = project_state.clone()
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
operation = RemoveCollation('C_test', locale='C')
self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test')
self.assertEqual(operation.describe(), 'Remove collation C_test')
operation = RemoveCollation("C_test", locale="C")
self.assertEqual(operation.migration_name_fragment, "remove_collation_c_test")
self.assertEqual(operation.describe(), "Remove collation C_test")
project_state = ProjectState()
new_state = project_state.clone()
# Remove a collation.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
self.assertIn("DROP COLLATION", captured_queries[0]["sql"])
# Removing a nonexistent collation raises an exception.
with self.assertRaisesMessage(ProgrammingError, 'does not exist'):
with self.assertRaisesMessage(ProgrammingError, "does not exist"):
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
# Reversal.
with CaptureQueriesContext(connection) as captured_queries:
with connection.schema_editor(atomic=False) as editor:
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
self.assertEqual(len(captured_queries), 1)
self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
self.assertIn("CREATE COLLATION", captured_queries[0]["sql"])
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'RemoveCollation')
self.assertEqual(name, "RemoveCollation")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
self.assertEqual(kwargs, {"name": "C_test", "locale": "C"})
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"})
class AddConstraintNotValidTests(OperationTestBase):
app_label = 'test_add_constraint_not_valid'
app_label = "test_add_constraint_not_valid"
def test_non_check_constraint_not_supported(self):
constraint = UniqueConstraint(fields=['pink'], name='pony_pink_uniq')
msg = 'AddConstraintNotValid.constraint must be a check constraint.'
constraint = UniqueConstraint(fields=["pink"], name="pony_pink_uniq")
msg = "AddConstraintNotValid.constraint must be a check constraint."
with self.assertRaisesMessage(TypeError, msg):
AddConstraintNotValid(model_name='pony', constraint=constraint)
AddConstraintNotValid(model_name="pony", constraint=constraint)
def test_add(self):
table_name = f'{self.app_label}_pony'
constraint_name = 'pony_pink_gte_check'
table_name = f"{self.app_label}_pony"
constraint_name = "pony_pink_gte_check"
constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
operation = AddConstraintNotValid('Pony', constraint=constraint)
operation = AddConstraintNotValid("Pony", constraint=constraint)
project_state, new_state = self.make_test_state(self.app_label, operation)
self.assertEqual(
operation.describe(),
f'Create not valid constraint {constraint_name} on model Pony',
f"Create not valid constraint {constraint_name} on model Pony",
)
self.assertEqual(
operation.migration_name_fragment,
f'pony_{constraint_name}_not_valid',
f"pony_{constraint_name}_not_valid",
)
self.assertEqual(
len(new_state.models[self.app_label, 'pony'].options['constraints']),
len(new_state.models[self.app_label, "pony"].options["constraints"]),
1,
)
self.assertConstraintNotExists(table_name, constraint_name)
Pony = new_state.apps.get_model(self.app_label, 'Pony')
Pony = new_state.apps.get_model(self.app_label, "Pony")
self.assertEqual(len(Pony._meta.constraints), 1)
Pony.objects.create(pink=2, weight=1.0)
# Add constraint.
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
msg = f'check constraint "{constraint_name}"'
with self.assertRaisesMessage(IntegrityError, msg), transaction.atomic():
Pony.objects.create(pink=3, weight=1.0)
self.assertConstraintExists(table_name, constraint_name)
# Reversal.
with connection.schema_editor(atomic=True) as editor:
operation.database_backwards(self.app_label, editor, project_state, new_state)
operation.database_backwards(
self.app_label, editor, project_state, new_state
)
self.assertConstraintNotExists(table_name, constraint_name)
Pony.objects.create(pink=3, weight=1.0)
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'AddConstraintNotValid')
self.assertEqual(name, "AddConstraintNotValid")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'model_name': 'Pony', 'constraint': constraint})
self.assertEqual(kwargs, {"model_name": "Pony", "constraint": constraint})
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
@modify_settings(INSTALLED_APPS={'append': 'migrations'})
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"})
class ValidateConstraintTests(OperationTestBase):
app_label = 'test_validate_constraint'
app_label = "test_validate_constraint"
def test_validate(self):
constraint_name = 'pony_pink_gte_check'
constraint_name = "pony_pink_gte_check"
constraint = CheckConstraint(check=Q(pink__gte=4), name=constraint_name)
operation = AddConstraintNotValid('Pony', constraint=constraint)
operation = AddConstraintNotValid("Pony", constraint=constraint)
project_state, new_state = self.make_test_state(self.app_label, operation)
Pony = new_state.apps.get_model(self.app_label, 'Pony')
Pony = new_state.apps.get_model(self.app_label, "Pony")
obj = Pony.objects.create(pink=2, weight=1.0)
# Add constraint.
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
project_state = new_state
new_state = new_state.clone()
operation = ValidateConstraint('Pony', name=constraint_name)
operation = ValidateConstraint("Pony", name=constraint_name)
operation.state_forwards(self.app_label, new_state)
self.assertEqual(
operation.describe(),
f'Validate constraint {constraint_name} on model Pony',
f"Validate constraint {constraint_name} on model Pony",
)
self.assertEqual(
operation.migration_name_fragment,
f'pony_validate_{constraint_name}',
f"pony_validate_{constraint_name}",
)
# Validate constraint.
with connection.schema_editor(atomic=True) as editor:
msg = f'check constraint "{constraint_name}"'
with self.assertRaisesMessage(IntegrityError, msg):
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
obj.pink = 5
obj.save()
with connection.schema_editor(atomic=True) as editor:
operation.database_forwards(self.app_label, editor, project_state, new_state)
operation.database_forwards(
self.app_label, editor, project_state, new_state
)
# Reversal is a noop.
with connection.schema_editor() as editor:
with self.assertNumQueries(0):
operation.database_backwards(self.app_label, editor, new_state, project_state)
operation.database_backwards(
self.app_label, editor, new_state, project_state
)
# Deconstruction.
name, args, kwargs = operation.deconstruct()
self.assertEqual(name, 'ValidateConstraint')
self.assertEqual(name, "ValidateConstraint")
self.assertEqual(args, [])
self.assertEqual(kwargs, {'model_name': 'Pony', 'name': constraint_name})
self.assertEqual(kwargs, {"model_name": "Pony", "name": constraint_name})

File diff suppressed because it is too large Load diff

View file

@ -14,108 +14,125 @@ from .models import Character, Line, LineSavedSearch, Scene
try:
from django.contrib.postgres.search import (
SearchConfig, SearchHeadline, SearchQuery, SearchRank, SearchVector,
SearchConfig,
SearchHeadline,
SearchQuery,
SearchRank,
SearchVector,
)
except ImportError:
pass
class GrailTestData:
@classmethod
def setUpTestData(cls):
cls.robin = Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing')
cls.minstrel = Character.objects.create(name='Minstrel')
cls.robin = Scene.objects.create(
scene="Scene 10", setting="The dark forest of Ewing"
)
cls.minstrel = Character.objects.create(name="Minstrel")
verses = [
(
'Bravely bold Sir Robin, rode forth from Camelot. '
'He was not afraid to die, o Brave Sir Robin. '
'He was not at all afraid to be killed in nasty ways. '
'Brave, brave, brave, brave Sir Robin'
"Bravely bold Sir Robin, rode forth from Camelot. "
"He was not afraid to die, o Brave Sir Robin. "
"He was not at all afraid to be killed in nasty ways. "
"Brave, brave, brave, brave Sir Robin"
),
(
'He was not in the least bit scared to be mashed into a pulp, '
'Or to have his eyes gouged out, and his elbows broken. '
'To have his kneecaps split, and his body burned away, '
'And his limbs all hacked and mangled, brave Sir Robin!'
"He was not in the least bit scared to be mashed into a pulp, "
"Or to have his eyes gouged out, and his elbows broken. "
"To have his kneecaps split, and his body burned away, "
"And his limbs all hacked and mangled, brave Sir Robin!"
),
(
'His head smashed in and his heart cut out, '
'And his liver removed and his bowels unplugged, '
'And his nostrils ripped and his bottom burned off,'
'And his --'
"His head smashed in and his heart cut out, "
"And his liver removed and his bowels unplugged, "
"And his nostrils ripped and his bottom burned off,"
"And his --"
),
]
cls.verses = [Line.objects.create(
scene=cls.robin,
character=cls.minstrel,
dialogue=verse,
) for verse in verses]
cls.verses = [
Line.objects.create(
scene=cls.robin,
character=cls.minstrel,
dialogue=verse,
)
for verse in verses
]
cls.verse0, cls.verse1, cls.verse2 = cls.verses
cls.witch_scene = Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle")
bedemir = Character.objects.create(name='Bedemir')
crowd = Character.objects.create(name='Crowd')
witch = Character.objects.create(name='Witch')
duck = Character.objects.create(name='Duck')
cls.witch_scene = Scene.objects.create(
scene="Scene 5", setting="Sir Bedemir's Castle"
)
bedemir = Character.objects.create(name="Bedemir")
crowd = Character.objects.create(name="Crowd")
witch = Character.objects.create(name="Witch")
duck = Character.objects.create(name="Duck")
cls.bedemir0 = Line.objects.create(
scene=cls.witch_scene,
character=bedemir,
dialogue='We shall use my larger scales!',
dialogue_config='english',
dialogue="We shall use my larger scales!",
dialogue_config="english",
)
cls.bedemir1 = Line.objects.create(
scene=cls.witch_scene,
character=bedemir,
dialogue='Right, remove the supports!',
dialogue_config='english',
dialogue="Right, remove the supports!",
dialogue_config="english",
)
cls.duck = Line.objects.create(
scene=cls.witch_scene, character=duck, dialogue=None
)
cls.crowd = Line.objects.create(
scene=cls.witch_scene, character=crowd, dialogue="A witch! A witch!"
)
cls.witch = Line.objects.create(
scene=cls.witch_scene, character=witch, dialogue="It's a fair cop."
)
cls.duck = Line.objects.create(scene=cls.witch_scene, character=duck, dialogue=None)
cls.crowd = Line.objects.create(scene=cls.witch_scene, character=crowd, dialogue='A witch! A witch!')
cls.witch = Line.objects.create(scene=cls.witch_scene, character=witch, dialogue="It's a fair cop.")
trojan_rabbit = Scene.objects.create(scene='Scene 8', setting="The castle of Our Master Ruiz' de lu la Ramper")
guards = Character.objects.create(name='French Guards')
trojan_rabbit = Scene.objects.create(
scene="Scene 8", setting="The castle of Our Master Ruiz' de lu la Ramper"
)
guards = Character.objects.create(name="French Guards")
cls.french = Line.objects.create(
scene=trojan_rabbit,
character=guards,
dialogue='Oh. Un beau cadeau. Oui oui.',
dialogue_config='french',
dialogue="Oh. Un beau cadeau. Oui oui.",
dialogue_config="french",
)
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SimpleSearchTest(GrailTestData, PostgreSQLTestCase):
def test_simple(self):
searched = Line.objects.filter(dialogue__search='elbows')
searched = Line.objects.filter(dialogue__search="elbows")
self.assertSequenceEqual(searched, [self.verse1])
def test_non_exact_match(self):
searched = Line.objects.filter(dialogue__search='hearts')
searched = Line.objects.filter(dialogue__search="hearts")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms(self):
searched = Line.objects.filter(dialogue__search='heart bowel')
searched = Line.objects.filter(dialogue__search="heart bowel")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms_with_partial_match(self):
searched = Line.objects.filter(dialogue__search='Robin killed')
searched = Line.objects.filter(dialogue__search="Robin killed")
self.assertSequenceEqual(searched, [self.verse0])
def test_search_query_config(self):
searched = Line.objects.filter(
dialogue__search=SearchQuery('nostrils', config='simple'),
dialogue__search=SearchQuery("nostrils", config="simple"),
)
self.assertSequenceEqual(searched, [self.verse2])
def test_search_with_F_expression(self):
# Non-matching query.
LineSavedSearch.objects.create(line=self.verse1, query='hearts')
LineSavedSearch.objects.create(line=self.verse1, query="hearts")
# Matching query.
match = LineSavedSearch.objects.create(line=self.verse1, query='elbows')
for query_expression in [F('query'), SearchQuery(F('query'))]:
match = LineSavedSearch.objects.create(line=self.verse1, query="elbows")
for query_expression in [F("query"), SearchQuery(F("query"))]:
with self.subTest(query_expression):
searched = LineSavedSearch.objects.filter(
line__dialogue__search=query_expression,
@ -123,254 +140,296 @@ class SimpleSearchTest(GrailTestData, PostgreSQLTestCase):
self.assertSequenceEqual(searched, [match])
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SearchVectorFieldTest(GrailTestData, PostgreSQLTestCase):
def test_existing_vector(self):
Line.objects.update(dialogue_search_vector=SearchVector('dialogue'))
searched = Line.objects.filter(dialogue_search_vector=SearchQuery('Robin killed'))
Line.objects.update(dialogue_search_vector=SearchVector("dialogue"))
searched = Line.objects.filter(
dialogue_search_vector=SearchQuery("Robin killed")
)
self.assertSequenceEqual(searched, [self.verse0])
def test_existing_vector_config_explicit(self):
Line.objects.update(dialogue_search_vector=SearchVector('dialogue'))
searched = Line.objects.filter(dialogue_search_vector=SearchQuery('cadeaux', config='french'))
Line.objects.update(dialogue_search_vector=SearchVector("dialogue"))
searched = Line.objects.filter(
dialogue_search_vector=SearchQuery("cadeaux", config="french")
)
self.assertSequenceEqual(searched, [self.french])
def test_single_coalesce_expression(self):
searched = Line.objects.annotate(search=SearchVector('dialogue')).filter(search='cadeaux')
self.assertNotIn('COALESCE(COALESCE', str(searched.query))
searched = Line.objects.annotate(search=SearchVector("dialogue")).filter(
search="cadeaux"
)
self.assertNotIn("COALESCE(COALESCE", str(searched.query))
class SearchConfigTests(PostgreSQLSimpleTestCase):
def test_from_parameter(self):
self.assertIsNone(SearchConfig.from_parameter(None))
self.assertEqual(SearchConfig.from_parameter('foo'), SearchConfig('foo'))
self.assertEqual(SearchConfig.from_parameter(SearchConfig('bar')), SearchConfig('bar'))
self.assertEqual(SearchConfig.from_parameter("foo"), SearchConfig("foo"))
self.assertEqual(
SearchConfig.from_parameter(SearchConfig("bar")), SearchConfig("bar")
)
class MultipleFieldsTest(GrailTestData, PostgreSQLTestCase):
def test_simple_on_dialogue(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='elbows')
search=SearchVector("scene__setting", "dialogue"),
).filter(search="elbows")
self.assertSequenceEqual(searched, [self.verse1])
def test_simple_on_scene(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='Forest')
search=SearchVector("scene__setting", "dialogue"),
).filter(search="Forest")
self.assertCountEqual(searched, self.verses)
def test_non_exact_match(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='heart')
search=SearchVector("scene__setting", "dialogue"),
).filter(search="heart")
self.assertSequenceEqual(searched, [self.verse2])
def test_search_two_terms(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='heart forest')
search=SearchVector("scene__setting", "dialogue"),
).filter(search="heart forest")
self.assertSequenceEqual(searched, [self.verse2])
def test_terms_adjacent(self):
searched = Line.objects.annotate(
search=SearchVector('character__name', 'dialogue'),
).filter(search='minstrel')
search=SearchVector("character__name", "dialogue"),
).filter(search="minstrel")
self.assertCountEqual(searched, self.verses)
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='minstrelbravely')
search=SearchVector("scene__setting", "dialogue"),
).filter(search="minstrelbravely")
self.assertSequenceEqual(searched, [])
def test_search_with_null(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search='bedemir')
self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
search=SearchVector("scene__setting", "dialogue"),
).filter(search="bedemir")
self.assertCountEqual(
searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
)
def test_search_with_non_text(self):
searched = Line.objects.annotate(
search=SearchVector('id'),
search=SearchVector("id"),
).filter(search=str(self.crowd.id))
self.assertSequenceEqual(searched, [self.crowd])
def test_phrase_search(self):
line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
searched = line_qs.filter(search=SearchQuery('burned body his away', search_type='phrase'))
line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
searched = line_qs.filter(
search=SearchQuery("burned body his away", search_type="phrase")
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(search=SearchQuery('his body burned away', search_type='phrase'))
searched = line_qs.filter(
search=SearchQuery("his body burned away", search_type="phrase")
)
self.assertSequenceEqual(searched, [self.verse1])
def test_phrase_search_with_config(self):
line_qs = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config='french'),
search=SearchVector("scene__setting", "dialogue", config="french"),
)
searched = line_qs.filter(
search=SearchQuery('cadeau beau un', search_type='phrase', config='french'),
search=SearchQuery("cadeau beau un", search_type="phrase", config="french"),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
search=SearchQuery('un beau cadeau', search_type='phrase', config='french'),
search=SearchQuery("un beau cadeau", search_type="phrase", config="french"),
)
self.assertSequenceEqual(searched, [self.french])
def test_raw_search(self):
line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
searched = line_qs.filter(search=SearchQuery('Robin', search_type='raw'))
line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
searched = line_qs.filter(search=SearchQuery("Robin", search_type="raw"))
self.assertCountEqual(searched, [self.verse0, self.verse1])
searched = line_qs.filter(search=SearchQuery("Robin & !'Camelot'", search_type='raw'))
searched = line_qs.filter(
search=SearchQuery("Robin & !'Camelot'", search_type="raw")
)
self.assertSequenceEqual(searched, [self.verse1])
def test_raw_search_with_config(self):
line_qs = Line.objects.annotate(search=SearchVector('dialogue', config='french'))
line_qs = Line.objects.annotate(
search=SearchVector("dialogue", config="french")
)
searched = line_qs.filter(
search=SearchQuery("'cadeaux' & 'beaux'", search_type='raw', config='french'),
search=SearchQuery(
"'cadeaux' & 'beaux'", search_type="raw", config="french"
),
)
self.assertSequenceEqual(searched, [self.french])
@skipUnlessDBFeature('has_websearch_to_tsquery')
@skipUnlessDBFeature("has_websearch_to_tsquery")
def test_web_search(self):
line_qs = Line.objects.annotate(search=SearchVector('dialogue'))
line_qs = Line.objects.annotate(search=SearchVector("dialogue"))
searched = line_qs.filter(
search=SearchQuery(
'"burned body" "split kneecaps"',
search_type='websearch',
search_type="websearch",
),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
search=SearchQuery(
'"body burned" "kneecaps split" -"nostrils"',
search_type='websearch',
search_type="websearch",
),
)
self.assertSequenceEqual(searched, [self.verse1])
searched = line_qs.filter(
search=SearchQuery(
'"Sir Robin" ("kneecaps" OR "Camelot")',
search_type='websearch',
search_type="websearch",
),
)
self.assertSequenceEqual(searched, [self.verse0, self.verse1])
@skipUnlessDBFeature('has_websearch_to_tsquery')
@skipUnlessDBFeature("has_websearch_to_tsquery")
def test_web_search_with_config(self):
line_qs = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config='french'),
search=SearchVector("scene__setting", "dialogue", config="french"),
)
searched = line_qs.filter(
search=SearchQuery('cadeau -beau', search_type='websearch', config='french'),
search=SearchQuery(
"cadeau -beau", search_type="websearch", config="french"
),
)
self.assertSequenceEqual(searched, [])
searched = line_qs.filter(
search=SearchQuery('beau cadeau', search_type='websearch', config='french'),
search=SearchQuery("beau cadeau", search_type="websearch", config="french"),
)
self.assertSequenceEqual(searched, [self.french])
def test_bad_search_type(self):
with self.assertRaisesMessage(ValueError, "Unknown search_type argument 'foo'."):
SearchQuery('kneecaps', search_type='foo')
with self.assertRaisesMessage(
ValueError, "Unknown search_type argument 'foo'."
):
SearchQuery("kneecaps", search_type="foo")
def test_config_query_explicit(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config='french'),
).filter(search=SearchQuery('cadeaux', config='french'))
search=SearchVector("scene__setting", "dialogue", config="french"),
).filter(search=SearchQuery("cadeaux", config="french"))
self.assertSequenceEqual(searched, [self.french])
def test_config_query_implicit(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config='french'),
).filter(search='cadeaux')
search=SearchVector("scene__setting", "dialogue", config="french"),
).filter(search="cadeaux")
self.assertSequenceEqual(searched, [self.french])
def test_config_from_field_explicit(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')),
).filter(search=SearchQuery('cadeaux', config=F('dialogue_config')))
search=SearchVector(
"scene__setting", "dialogue", config=F("dialogue_config")
),
).filter(search=SearchQuery("cadeaux", config=F("dialogue_config")))
self.assertSequenceEqual(searched, [self.french])
def test_config_from_field_implicit(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue', config=F('dialogue_config')),
).filter(search='cadeaux')
search=SearchVector(
"scene__setting", "dialogue", config=F("dialogue_config")
),
).filter(search="cadeaux")
self.assertSequenceEqual(searched, [self.french])
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TestCombinations(GrailTestData, PostgreSQLTestCase):
def test_vector_add(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting') + SearchVector('character__name'),
).filter(search='bedemir')
self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
search=SearchVector("scene__setting") + SearchVector("character__name"),
).filter(search="bedemir")
self.assertCountEqual(
searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
)
def test_vector_add_multi(self):
searched = Line.objects.annotate(
search=(
SearchVector('scene__setting') +
SearchVector('character__name') +
SearchVector('dialogue')
SearchVector("scene__setting")
+ SearchVector("character__name")
+ SearchVector("dialogue")
),
).filter(search='bedemir')
self.assertCountEqual(searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck])
).filter(search="bedemir")
self.assertCountEqual(
searched, [self.bedemir0, self.bedemir1, self.crowd, self.witch, self.duck]
)
def test_vector_combined_mismatch(self):
msg = (
'SearchVector can only be combined with other SearchVector '
'instances, got NoneType.'
"SearchVector can only be combined with other SearchVector "
"instances, got NoneType."
)
with self.assertRaisesMessage(TypeError, msg):
Line.objects.filter(dialogue__search=None + SearchVector('character__name'))
Line.objects.filter(dialogue__search=None + SearchVector("character__name"))
def test_combine_different_vector_configs(self):
searched = Line.objects.annotate(
search=(
SearchVector('dialogue', config='english') +
SearchVector('dialogue', config='french')
SearchVector("dialogue", config="english")
+ SearchVector("dialogue", config="french")
),
).filter(
search=SearchQuery('cadeaux', config='french') | SearchQuery('nostrils')
search=SearchQuery("cadeaux", config="french") | SearchQuery("nostrils")
)
self.assertCountEqual(searched, [self.french, self.verse2])
def test_query_and(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search=SearchQuery('bedemir') & SearchQuery('scales'))
search=SearchVector("scene__setting", "dialogue"),
).filter(search=SearchQuery("bedemir") & SearchQuery("scales"))
self.assertSequenceEqual(searched, [self.bedemir0])
def test_query_multiple_and(self):
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search=SearchQuery('bedemir') & SearchQuery('scales') & SearchQuery('nostrils'))
search=SearchVector("scene__setting", "dialogue"),
).filter(
search=SearchQuery("bedemir")
& SearchQuery("scales")
& SearchQuery("nostrils")
)
self.assertSequenceEqual(searched, [])
searched = Line.objects.annotate(
search=SearchVector('scene__setting', 'dialogue'),
).filter(search=SearchQuery('shall') & SearchQuery('use') & SearchQuery('larger'))
search=SearchVector("scene__setting", "dialogue"),
).filter(
search=SearchQuery("shall") & SearchQuery("use") & SearchQuery("larger")
)
self.assertSequenceEqual(searched, [self.bedemir0])
def test_query_or(self):
searched = Line.objects.filter(dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils'))
searched = Line.objects.filter(
dialogue__search=SearchQuery("kneecaps") | SearchQuery("nostrils")
)
self.assertCountEqual(searched, [self.verse1, self.verse2])
def test_query_multiple_or(self):
searched = Line.objects.filter(
dialogue__search=SearchQuery('kneecaps') | SearchQuery('nostrils') | SearchQuery('Sir Robin')
dialogue__search=SearchQuery("kneecaps")
| SearchQuery("nostrils")
| SearchQuery("Sir Robin")
)
self.assertCountEqual(searched, [self.verse1, self.verse2, self.verse0])
def test_query_invert(self):
searched = Line.objects.filter(character=self.minstrel, dialogue__search=~SearchQuery('kneecaps'))
searched = Line.objects.filter(
character=self.minstrel, dialogue__search=~SearchQuery("kneecaps")
)
self.assertCountEqual(searched, [self.verse0, self.verse2])
def test_combine_different_configs(self):
searched = Line.objects.filter(
dialogue__search=(
SearchQuery('cadeau', config='french') |
SearchQuery('nostrils', config='english')
SearchQuery("cadeau", config="french")
| SearchQuery("nostrils", config="english")
)
)
self.assertCountEqual(searched, [self.french, self.verse2])
@ -378,8 +437,8 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase):
def test_combined_configs(self):
searched = Line.objects.filter(
dialogue__search=(
SearchQuery('nostrils', config='simple') &
SearchQuery('bowels', config='simple')
SearchQuery("nostrils", config="simple")
& SearchQuery("bowels", config="simple")
),
)
self.assertSequenceEqual(searched, [self.verse2])
@ -387,63 +446,96 @@ class TestCombinations(GrailTestData, PostgreSQLTestCase):
def test_combine_raw_phrase(self):
searched = Line.objects.filter(
dialogue__search=(
SearchQuery('burn:*', search_type='raw', config='simple') |
SearchQuery('rode forth from Camelot', search_type='phrase')
SearchQuery("burn:*", search_type="raw", config="simple")
| SearchQuery("rode forth from Camelot", search_type="phrase")
)
)
self.assertCountEqual(searched, [self.verse0, self.verse1, self.verse2])
def test_query_combined_mismatch(self):
msg = (
'SearchQuery can only be combined with other SearchQuery '
'instances, got NoneType.'
"SearchQuery can only be combined with other SearchQuery "
"instances, got NoneType."
)
with self.assertRaisesMessage(TypeError, msg):
Line.objects.filter(dialogue__search=None | SearchQuery('kneecaps'))
Line.objects.filter(dialogue__search=None | SearchQuery("kneecaps"))
with self.assertRaisesMessage(TypeError, msg):
Line.objects.filter(dialogue__search=None & SearchQuery('kneecaps'))
Line.objects.filter(dialogue__search=None & SearchQuery("kneecaps"))
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
def test_ranking(self):
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')),
).order_by('rank')
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank(
SearchVector("dialogue"), SearchQuery("brave sir robin")
),
)
.order_by("rank")
)
self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0])
def test_rank_passing_untyped_args(self):
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank('dialogue', 'brave sir robin'),
).order_by('rank')
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank("dialogue", "brave sir robin"),
)
.order_by("rank")
)
self.assertSequenceEqual(searched, [self.verse2, self.verse1, self.verse0])
def test_weights_in_vector(self):
vector = SearchVector('dialogue', weight='A') + SearchVector('character__name', weight='D')
searched = Line.objects.filter(scene=self.witch_scene).annotate(
rank=SearchRank(vector, SearchQuery('witch')),
).order_by('-rank')[:2]
vector = SearchVector("dialogue", weight="A") + SearchVector(
"character__name", weight="D"
)
searched = (
Line.objects.filter(scene=self.witch_scene)
.annotate(
rank=SearchRank(vector, SearchQuery("witch")),
)
.order_by("-rank")[:2]
)
self.assertSequenceEqual(searched, [self.crowd, self.witch])
vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A')
searched = Line.objects.filter(scene=self.witch_scene).annotate(
rank=SearchRank(vector, SearchQuery('witch')),
).order_by('-rank')[:2]
vector = SearchVector("dialogue", weight="D") + SearchVector(
"character__name", weight="A"
)
searched = (
Line.objects.filter(scene=self.witch_scene)
.annotate(
rank=SearchRank(vector, SearchQuery("witch")),
)
.order_by("-rank")[:2]
)
self.assertSequenceEqual(searched, [self.witch, self.crowd])
def test_ranked_custom_weights(self):
vector = SearchVector('dialogue', weight='D') + SearchVector('character__name', weight='A')
searched = Line.objects.filter(scene=self.witch_scene).annotate(
rank=SearchRank(vector, SearchQuery('witch'), weights=[1, 0, 0, 0.5]),
).order_by('-rank')[:2]
vector = SearchVector("dialogue", weight="D") + SearchVector(
"character__name", weight="A"
)
searched = (
Line.objects.filter(scene=self.witch_scene)
.annotate(
rank=SearchRank(vector, SearchQuery("witch"), weights=[1, 0, 0, 0.5]),
)
.order_by("-rank")[:2]
)
self.assertSequenceEqual(searched, [self.crowd, self.witch])
def test_ranking_chaining(self):
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank(SearchVector('dialogue'), SearchQuery('brave sir robin')),
).filter(rank__gt=0.3)
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank(
SearchVector("dialogue"), SearchQuery("brave sir robin")
),
)
.filter(rank__gt=0.3)
)
self.assertSequenceEqual(searched, [self.verse0])
def test_cover_density_ranking(self):
@ -451,17 +543,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
scene=self.robin,
character=self.minstrel,
dialogue=(
'Bravely taking to his feet, he beat a very brave retreat. '
'A brave retreat brave Sir Robin.'
)
)
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank(
SearchVector('dialogue'),
SearchQuery('brave robin'),
cover_density=True,
"Bravely taking to his feet, he beat a very brave retreat. "
"A brave retreat brave Sir Robin."
),
).order_by('rank', '-pk')
)
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank(
SearchVector("dialogue"),
SearchQuery("brave robin"),
cover_density=True,
),
)
.order_by("rank", "-pk")
)
self.assertSequenceEqual(
searched,
[self.verse2, not_dense_verse, self.verse1, self.verse0],
@ -471,16 +567,20 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
short_verse = Line.objects.create(
scene=self.robin,
character=self.minstrel,
dialogue='A brave retreat brave Sir Robin.',
dialogue="A brave retreat brave Sir Robin.",
)
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank(
SearchVector("dialogue"),
SearchQuery("brave sir robin"),
# Divide the rank by the document length.
normalization=2,
),
)
.order_by("rank")
)
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank(
SearchVector('dialogue'),
SearchQuery('brave sir robin'),
# Divide the rank by the document length.
normalization=2,
),
).order_by('rank')
self.assertSequenceEqual(
searched,
[self.verse2, self.verse1, self.verse0, short_verse],
@ -490,17 +590,21 @@ class TestRankingAndWeights(GrailTestData, PostgreSQLTestCase):
short_verse = Line.objects.create(
scene=self.robin,
character=self.minstrel,
dialogue='A brave retreat brave Sir Robin.',
dialogue="A brave retreat brave Sir Robin.",
)
searched = (
Line.objects.filter(character=self.minstrel)
.annotate(
rank=SearchRank(
SearchVector("dialogue"),
SearchQuery("brave sir robin"),
# Divide the rank by the document length and by the number of
# unique words in document.
normalization=Value(2).bitor(Value(8)),
),
)
.order_by("rank")
)
searched = Line.objects.filter(character=self.minstrel).annotate(
rank=SearchRank(
SearchVector('dialogue'),
SearchQuery('brave sir robin'),
# Divide the rank by the document length and by the number of
# unique words in document.
normalization=Value(2).bitor(Value(8)),
),
).order_by('rank')
self.assertSequenceEqual(
searched,
[self.verse2, self.verse1, self.verse0, short_verse],
@ -513,13 +617,16 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
# This test should be moved to test_indexes and use a functional
# index instead once support lands (see #26167).
query = Line.objects.all().query
resolved = SearchVector('id', 'dialogue', config='english').resolve_expression(query)
resolved = SearchVector("id", "dialogue", config="english").resolve_expression(
query
)
compiler = query.get_compiler(connection.alias)
sql, params = resolved.as_sql(compiler, connection)
# Indexed function must be IMMUTABLE.
with connection.cursor() as cursor:
cursor.execute(
'CREATE INDEX search_vector_index ON %s USING GIN (%s)' % (Line._meta.db_table, sql),
"CREATE INDEX search_vector_index ON %s USING GIN (%s)"
% (Line._meta.db_table, sql),
params,
)
@ -527,24 +634,26 @@ class SearchVectorIndexTests(PostgreSQLTestCase):
class SearchQueryTests(PostgreSQLSimpleTestCase):
def test_str(self):
tests = (
(~SearchQuery('a'), "~SearchQuery(Value('a'))"),
(~SearchQuery("a"), "~SearchQuery(Value('a'))"),
(
(SearchQuery('a') | SearchQuery('b')) & (SearchQuery('c') | SearchQuery('d')),
(SearchQuery("a") | SearchQuery("b"))
& (SearchQuery("c") | SearchQuery("d")),
"((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && "
"(SearchQuery(Value('c')) || SearchQuery(Value('d'))))",
),
(
SearchQuery('a') & (SearchQuery('b') | SearchQuery('c')),
SearchQuery("a") & (SearchQuery("b") | SearchQuery("c")),
"(SearchQuery(Value('a')) && (SearchQuery(Value('b')) || "
"SearchQuery(Value('c'))))",
),
(
(SearchQuery('a') | SearchQuery('b')) & SearchQuery('c'),
(SearchQuery("a") | SearchQuery("b")) & SearchQuery("c"),
"((SearchQuery(Value('a')) || SearchQuery(Value('b'))) && "
"SearchQuery(Value('c')))"
"SearchQuery(Value('c')))",
),
(
SearchQuery('a') & (SearchQuery('b') & (SearchQuery('c') | SearchQuery('d'))),
SearchQuery("a")
& (SearchQuery("b") & (SearchQuery("c") | SearchQuery("d"))),
"(SearchQuery(Value('a')) && (SearchQuery(Value('b')) && "
"(SearchQuery(Value('c')) || SearchQuery(Value('d')))))",
),
@ -554,109 +663,112 @@ class SearchQueryTests(PostgreSQLSimpleTestCase):
self.assertEqual(str(query), expected_str)
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase):
def test_headline(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
F('dialogue'),
SearchQuery('brave sir robin'),
config=SearchConfig('english'),
F("dialogue"),
SearchQuery("brave sir robin"),
config=SearchConfig("english"),
),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
'<b>Robin</b>. He was not at all afraid to be killed in nasty '
'ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> '
'<b>Sir</b> <b>Robin</b>',
"<b>Robin</b>. He was not at all afraid to be killed in nasty "
"ways. <b>Brave</b>, <b>brave</b>, <b>brave</b>, <b>brave</b> "
"<b>Sir</b> <b>Robin</b>",
)
def test_headline_untyped_args(self):
searched = Line.objects.annotate(
headline=SearchHeadline('dialogue', 'killed', config='english'),
headline=SearchHeadline("dialogue", "killed", config="english"),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
'Robin. He was not at all afraid to be <b>killed</b> in nasty '
'ways. Brave, brave, brave, brave Sir Robin',
"Robin. He was not at all afraid to be <b>killed</b> in nasty "
"ways. Brave, brave, brave, brave Sir Robin",
)
def test_headline_with_config(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
SearchQuery('cadeaux', config='french'),
config='french',
"dialogue",
SearchQuery("cadeaux", config="french"),
config="french",
),
).get(pk=self.french.pk)
self.assertEqual(
searched.headline,
'Oh. Un beau <b>cadeau</b>. Oui oui.',
"Oh. Un beau <b>cadeau</b>. Oui oui.",
)
def test_headline_with_config_from_field(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
SearchQuery('cadeaux', config=F('dialogue_config')),
config=F('dialogue_config'),
"dialogue",
SearchQuery("cadeaux", config=F("dialogue_config")),
config=F("dialogue_config"),
),
).get(pk=self.french.pk)
self.assertEqual(
searched.headline,
'Oh. Un beau <b>cadeau</b>. Oui oui.',
"Oh. Un beau <b>cadeau</b>. Oui oui.",
)
def test_headline_separator_options(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
'brave sir robin',
start_sel='<span>',
stop_sel='</span>',
"dialogue",
"brave sir robin",
start_sel="<span>",
stop_sel="</span>",
),
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
'<span>Robin</span>. He was not at all afraid to be killed in '
'nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave'
'</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>',
"<span>Robin</span>. He was not at all afraid to be killed in "
"nasty ways. <span>Brave</span>, <span>brave</span>, <span>brave"
"</span>, <span>brave</span> <span>Sir</span> <span>Robin</span>",
)
def test_headline_highlight_all_option(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
SearchQuery('brave sir robin', config='english'),
"dialogue",
SearchQuery("brave sir robin", config="english"),
highlight_all=True,
),
).get(pk=self.verse0.pk)
self.assertIn(
'<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from '
'Camelot. He was not afraid to die, o ',
"<b>Bravely</b> bold <b>Sir</b> <b>Robin</b>, rode forth from "
"Camelot. He was not afraid to die, o ",
searched.headline,
)
def test_headline_short_word_option(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
SearchQuery('Camelot', config='english'),
"dialogue",
SearchQuery("Camelot", config="english"),
short_word=5,
min_words=8,
),
).get(pk=self.verse0.pk)
self.assertEqual(searched.headline, (
'<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He '
'was not at all afraid'
))
self.assertEqual(
searched.headline,
(
"<b>Camelot</b>. He was not afraid to die, o Brave Sir Robin. He "
"was not at all afraid"
),
)
def test_headline_fragments_words_options(self):
searched = Line.objects.annotate(
headline=SearchHeadline(
'dialogue',
SearchQuery('brave sir robin', config='english'),
fragment_delimiter='...<br>',
"dialogue",
SearchQuery("brave sir robin", config="english"),
fragment_delimiter="...<br>",
max_fragments=4,
max_words=3,
min_words=1,
@ -664,8 +776,8 @@ class SearchHeadlineTests(GrailTestData, PostgreSQLTestCase):
).get(pk=self.verse0.pk)
self.assertEqual(
searched.headline,
'<b>Sir</b> <b>Robin</b>, rode...<br>'
'<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>'
'<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>'
'<b>brave</b> <b>Sir</b> <b>Robin</b>',
"<b>Sir</b> <b>Robin</b>, rode...<br>"
"<b>Brave</b> <b>Sir</b> <b>Robin</b>...<br>"
"<b>Brave</b>, <b>brave</b>, <b>brave</b>...<br>"
"<b>brave</b> <b>Sir</b> <b>Robin</b>",
)

View file

@ -4,14 +4,15 @@ from . import PostgreSQLTestCase
try:
from django.contrib.postgres.signals import (
get_citext_oids, get_hstore_oids, register_type_handlers,
get_citext_oids,
get_hstore_oids,
register_type_handlers,
)
except ImportError:
pass # pyscogp2 isn't installed.
class OIDTests(PostgreSQLTestCase):
def assertOIDs(self, oids):
self.assertIsInstance(oids, tuple)
self.assertGreater(len(oids), 0)

View file

@ -5,65 +5,74 @@ from .models import CharFieldModel, TextFieldModel
try:
from django.contrib.postgres.search import (
TrigramDistance, TrigramSimilarity, TrigramWordDistance,
TrigramDistance,
TrigramSimilarity,
TrigramWordDistance,
TrigramWordSimilarity,
)
except ImportError:
pass
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class TrigramTest(PostgreSQLTestCase):
Model = CharFieldModel
@classmethod
def setUpTestData(cls):
cls.Model.objects.bulk_create([
cls.Model(field='Matthew'),
cls.Model(field='Cat sat on mat.'),
cls.Model(field='Dog sat on rug.'),
])
cls.Model.objects.bulk_create(
[
cls.Model(field="Matthew"),
cls.Model(field="Cat sat on mat."),
cls.Model(field="Dog sat on rug."),
]
)
def test_trigram_search(self):
self.assertQuerysetEqual(
self.Model.objects.filter(field__trigram_similar='Mathew'),
['Matthew'],
self.Model.objects.filter(field__trigram_similar="Mathew"),
["Matthew"],
transform=lambda instance: instance.field,
)
def test_trigram_word_search(self):
obj = self.Model.objects.create(
field='Gumby rides on the path of Middlesbrough',
field="Gumby rides on the path of Middlesbrough",
)
self.assertSequenceEqual(
self.Model.objects.filter(field__trigram_word_similar='Middlesborough'),
self.Model.objects.filter(field__trigram_word_similar="Middlesborough"),
[obj],
)
def test_trigram_similarity(self):
search = 'Bat sat on cat.'
search = "Bat sat on cat."
# Round result of similarity because PostgreSQL 12+ uses greater
# precision.
self.assertQuerysetEqual(
self.Model.objects.filter(
field__trigram_similar=search,
).annotate(similarity=TrigramSimilarity('field', search)).order_by('-similarity'),
[('Cat sat on mat.', 0.625), ('Dog sat on rug.', 0.333333)],
)
.annotate(similarity=TrigramSimilarity("field", search))
.order_by("-similarity"),
[("Cat sat on mat.", 0.625), ("Dog sat on rug.", 0.333333)],
transform=lambda instance: (instance.field, round(instance.similarity, 6)),
ordered=True,
)
def test_trigram_word_similarity(self):
search = 'mat'
search = "mat"
self.assertSequenceEqual(
self.Model.objects.filter(
field__trigram_word_similar=search,
).annotate(
word_similarity=TrigramWordSimilarity(search, 'field'),
).values('field', 'word_similarity').order_by('-word_similarity'),
)
.annotate(
word_similarity=TrigramWordSimilarity(search, "field"),
)
.values("field", "word_similarity")
.order_by("-word_similarity"),
[
{'field': 'Cat sat on mat.', 'word_similarity': 1.0},
{'field': 'Matthew', 'word_similarity': 0.75},
{"field": "Cat sat on mat.", "word_similarity": 1.0},
{"field": "Matthew", "word_similarity": 0.75},
],
)
@ -72,9 +81,11 @@ class TrigramTest(PostgreSQLTestCase):
# precision.
self.assertQuerysetEqual(
self.Model.objects.annotate(
distance=TrigramDistance('field', 'Bat sat on cat.'),
).filter(distance__lte=0.7).order_by('distance'),
[('Cat sat on mat.', 0.375), ('Dog sat on rug.', 0.666667)],
distance=TrigramDistance("field", "Bat sat on cat."),
)
.filter(distance__lte=0.7)
.order_by("distance"),
[("Cat sat on mat.", 0.375), ("Dog sat on rug.", 0.666667)],
transform=lambda instance: (instance.field, round(instance.distance, 6)),
ordered=True,
)
@ -82,13 +93,16 @@ class TrigramTest(PostgreSQLTestCase):
def test_trigram_word_similarity_alternate(self):
self.assertSequenceEqual(
self.Model.objects.annotate(
word_distance=TrigramWordDistance('mat', 'field'),
).filter(
word_distance=TrigramWordDistance("mat", "field"),
)
.filter(
word_distance__lte=0.7,
).values('field', 'word_distance').order_by('word_distance'),
)
.values("field", "word_distance")
.order_by("word_distance"),
[
{'field': 'Cat sat on mat.', 'word_distance': 0},
{'field': 'Matthew', 'word_distance': 0.25},
{"field": "Cat sat on mat.", "word_distance": 0},
{"field": "Matthew", "word_distance": 0.25},
],
)
@ -97,4 +111,5 @@ class TrigramTextFieldTest(TrigramTest):
"""
TextField has the same behavior as CharField regarding trigram lookups.
"""
Model = TextFieldModel

View file

@ -5,25 +5,27 @@ from . import PostgreSQLTestCase
from .models import CharFieldModel, TextFieldModel
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class UnaccentTest(PostgreSQLTestCase):
Model = CharFieldModel
@classmethod
def setUpTestData(cls):
cls.Model.objects.bulk_create([
cls.Model(field="àéÖ"),
cls.Model(field="aeO"),
cls.Model(field="aeo"),
])
cls.Model.objects.bulk_create(
[
cls.Model(field="àéÖ"),
cls.Model(field="aeO"),
cls.Model(field="aeo"),
]
)
def test_unaccent(self):
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent="aeO"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
ordered=False
ordered=False,
)
def test_unaccent_chained(self):
@ -35,39 +37,39 @@ class UnaccentTest(PostgreSQLTestCase):
self.Model.objects.filter(field__unaccent__iexact="aeO"),
["àéÖ", "aeO", "aeo"],
transform=lambda instance: instance.field,
ordered=False
ordered=False,
)
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent__endswith="éÖ"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
ordered=False
ordered=False,
)
def test_unaccent_with_conforming_strings_off(self):
"""SQL is valid when standard_conforming_strings is off."""
with connection.cursor() as cursor:
cursor.execute('SHOW standard_conforming_strings')
disable_conforming_strings = cursor.fetchall()[0][0] == 'on'
cursor.execute("SHOW standard_conforming_strings")
disable_conforming_strings = cursor.fetchall()[0][0] == "on"
if disable_conforming_strings:
cursor.execute('SET standard_conforming_strings TO off')
cursor.execute("SET standard_conforming_strings TO off")
try:
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent__endswith='éÖ'),
['àéÖ', 'aeO'],
self.Model.objects.filter(field__unaccent__endswith="éÖ"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
ordered=False,
)
finally:
if disable_conforming_strings:
cursor.execute('SET standard_conforming_strings TO on')
cursor.execute("SET standard_conforming_strings TO on")
def test_unaccent_accentuated_needle(self):
self.assertQuerysetEqual(
self.Model.objects.filter(field__unaccent="aéÖ"),
["àéÖ", "aeO"],
transform=lambda instance: instance.field,
ordered=False
ordered=False,
)
@ -76,4 +78,5 @@ class UnaccentTextFieldTest(UnaccentTest):
TextField should have the exact same behavior as CharField
regarding unaccent lookups.
"""
Model = TextFieldModel