This commit is contained in:
Samriddh Tripathi 2025-11-17 13:45:10 +01:00 committed by GitHub
commit 734de7f157
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 22 additions and 18 deletions

View file

@ -429,7 +429,7 @@ class GeometryType(GeoFuncMixin, Transform):
lookup_name = "geom_type"
def as_oracle(self, compiler, connection, **extra_context):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
sql = (
"(SELECT DECODE("
f"SDO_GEOMETRY.GET_GTYPE({lhs}),"

View file

@ -8,7 +8,7 @@ from django.utils.regex_helper import _lazy_re_compile
class RasterBandTransform(Transform):
def as_sql(self, compiler, connection):
return compiler.compile(self.lhs)
return self.process_lhs(compiler, connection)
class GISLookup(Lookup):

View file

@ -304,7 +304,7 @@ class ArrayLenTransform(Transform):
output_field = IntegerField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
# Distinguish NULL and empty arrays
return (
"CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
@ -336,7 +336,7 @@ class IndexTransform(Transform):
self.base_field = base_field
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
if not lhs.endswith("]"):
lhs = "(%s)" % lhs
return "%s[%%s]" % lhs, (*params, self.index)
@ -362,7 +362,7 @@ class SliceTransform(Transform):
self.end = end
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
# self.start is set to 1 if slice start is not provided.
if self.end is None:
return f"({lhs})[%s:]", (*params, self.start)

View file

@ -87,7 +87,7 @@ class KeyTransform(Transform):
self.key_name = key_name
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
return "(%s -> %%s)" % lhs, (*params, self.key_name)

View file

@ -51,7 +51,7 @@ class Extract(TimezoneMixin, Transform):
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
sql, params = self.process_lhs(compiler, connection)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
@ -258,7 +258,7 @@ class TruncBase(TimezoneMixin, Transform):
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
sql, params = self.process_lhs(compiler, connection)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
@ -405,7 +405,7 @@ class TruncDate(TruncBase):
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
sql, params = compiler.compile(self.lhs)
sql, params = self.process_lhs(compiler, connection)
tzname = self.get_tzname()
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
@ -417,7 +417,7 @@ class TruncTime(TruncBase):
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
sql, params = compiler.compile(self.lhs)
sql, params = self.process_lhs(compiler, connection)
tzname = self.get_tzname()
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)

View file

@ -215,6 +215,10 @@ class Transform(RegisterLookupMixin, Func):
def lhs(self):
return self.get_source_expressions()[0]
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
return compiler.compile(lhs)
def get_bilateral_transforms(self):
if hasattr(self.lhs, "get_bilateral_transforms"):
bilateral_transforms = self.lhs.get_bilateral_transforms()

View file

@ -34,11 +34,11 @@ class Div3Transform(models.Transform):
lookup_name = "div3"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
lhs, lhs_params = self.process_lhs(compiler, connection)
return "(%s) %%%% 3" % lhs, lhs_params
def as_oracle(self, compiler, connection, **extra_context):
lhs, lhs_params = compiler.compile(self.lhs)
lhs, lhs_params = self.process_lhs(compiler, connection)
return "mod(%s, 3)" % lhs, lhs_params
@ -51,7 +51,7 @@ class Mult3BilateralTransform(models.Transform):
lookup_name = "mult3"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
lhs, lhs_params = self.process_lhs(compiler, connection)
return "3 * (%s)" % lhs, lhs_params
@ -59,7 +59,7 @@ class LastDigitTransform(models.Transform):
lookup_name = "lastdigit"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
lhs, lhs_params = self.process_lhs(compiler, connection)
return "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % lhs, lhs_params
@ -68,7 +68,7 @@ class UpperBilateralTransform(models.Transform):
lookup_name = "upper"
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
lhs, lhs_params = self.process_lhs(compiler, connection)
return "UPPER(%s)" % lhs, lhs_params
@ -77,7 +77,7 @@ class YearTransform(models.Transform):
lookup_name = "testyear"
def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs)
lhs_sql, params = self.process_lhs(compiler, connection)
return connection.ops.date_extract_sql("year", lhs_sql, params)
@property
@ -219,7 +219,7 @@ class DateTimeTransform(models.Transform):
return models.DateTimeField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
lhs, params = self.process_lhs(compiler, connection)
return "from_unixtime({})".format(lhs), params
@ -618,7 +618,7 @@ class TrackCallsYearTransform(YearTransform):
call_order = []
def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs)
lhs_sql, params = self.process_lhs(compiler, connection)
return connection.ops.date_extract_sql("year", lhs_sql), params
@property