Fixed #31709 -- Added support for opclasses in ExclusionConstraint.

This commit is contained in:
Hannes Ljungberg 2020-06-14 20:50:39 +02:00 committed by Mariusz Felisiak
parent dcb4d79ef7
commit 0d6d4e78b1
4 changed files with 179 additions and 5 deletions

View file

@ -12,7 +12,7 @@ class ExclusionConstraint(BaseConstraint):
def __init__(
self, *, name, expressions, index_type=None, condition=None,
deferrable=None, include=None,
deferrable=None, include=None, opclasses=(),
):
if index_type and index_type.lower() not in {'gist', 'spgist'}:
raise ValueError(
@ -48,20 +48,37 @@ class ExclusionConstraint(BaseConstraint):
raise ValueError(
'Covering exclusion constraints only support GiST indexes.'
)
if not isinstance(opclasses, (list, tuple)):
raise ValueError(
'ExclusionConstraint.opclasses must be a list or tuple.'
)
if opclasses and len(expressions) != len(opclasses):
raise ValueError(
'ExclusionConstraint.expressions and '
'ExclusionConstraint.opclasses must have the same number of '
'elements.'
)
self.expressions = expressions
self.index_type = index_type or 'GIST'
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
super().__init__(name=name)
def _get_expression_sql(self, compiler, connection, query):
expressions = []
for expression, operator in self.expressions:
for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str):
expression = F(expression)
expression = expression.resolve_expression(query=query)
sql, params = expression.as_sql(compiler, connection)
try:
opclass = self.opclasses[idx]
if opclass:
sql = '%s %s' % (sql, opclass)
except IndexError:
pass
expressions.append('%s WITH %s' % (sql % params, operator))
return expressions
@ -119,6 +136,8 @@ class ExclusionConstraint(BaseConstraint):
kwargs['deferrable'] = self.deferrable
if self.include:
kwargs['include'] = self.include
if self.opclasses:
kwargs['opclasses'] = self.opclasses
return path, args, kwargs
def __eq__(self, other):
@ -129,16 +148,18 @@ class ExclusionConstraint(BaseConstraint):
self.expressions == other.expressions and
self.condition == other.condition and
self.deferrable == other.deferrable and
self.include == other.include
self.include == other.include and
self.opclasses == other.opclasses
)
return super().__eq__(other)
def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s%s%s>' % (
return '<%s: index_type=%s, expressions=%s%s%s%s%s>' % (
self.__class__.__qualname__,
self.index_type,
self.expressions,
'' if self.condition is None else ', condition=%s' % self.condition,
'' if self.deferrable is None else ', deferrable=%s' % self.deferrable,
'' if not self.include else ', include=%s' % repr(self.include),
'' if not self.opclasses else ', opclasses=%s' % repr(self.opclasses),
)