Added support for parameters in SELECT clauses.

This commit is contained in:
Aymeric Augustin 2013-02-13 14:47:44 +01:00
parent b4351d2890
commit 924a144ef8
14 changed files with 79 additions and 64 deletions

View file

@ -56,12 +56,13 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
lookup_info = self.geometry_functions.get(lookup_type, False) lookup_info = self.geometry_functions.get(lookup_type, False)
if lookup_info: if lookup_info:
return "%s(%s, %s)" % (lookup_info, geo_col, sql = "%s(%s, %s)" % (lookup_info, geo_col,
self.get_geom_placeholder(value, field.srid)) self.get_geom_placeholder(value, field.srid))
return sql, []
# TODO: Is this really necessary? MySQL can't handle NULL geometries # TODO: Is this really necessary? MySQL can't handle NULL geometries
# in its spatial indexes anyways. # in its spatial indexes anyways.
if lookup_type == 'isnull': if lookup_type == 'isnull':
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View file

@ -262,7 +262,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value)) return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value))
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
# Handling 'isnull' lookup type # Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
@ -288,7 +288,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
def spatial_ref_sys(self): def spatial_ref_sys(self):
from django.contrib.gis.db.backends.oracle.models import SpatialRefSys from django.contrib.gis.db.backends.oracle.models import SpatialRefSys
return SpatialRefSys return SpatialRefSys
def modify_insert_params(self, placeholders, params): def modify_insert_params(self, placeholders, params):
"""Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
backend due to #10888 backend due to #10888

View file

@ -560,7 +560,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
# Handling 'isnull' lookup type # Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View file

@ -358,7 +358,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
return op.as_sql(geo_col, self.get_geom_placeholder(field, geom)) return op.as_sql(geo_col, self.get_geom_placeholder(field, geom))
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
# Handling 'isnull' lookup type # Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View file

@ -16,7 +16,7 @@ class SpatialOperation(object):
self.extra = kwargs self.extra = kwargs
def as_sql(self, geo_col, geometry='%s'): def as_sql(self, geo_col, geometry='%s'):
return self.sql_template % self.params(geo_col, geometry) return self.sql_template % self.params(geo_col, geometry), []
def params(self, geo_col, geometry): def params(self, geo_col, geometry):
params = {'function' : self.function, params = {'function' : self.function,

View file

@ -22,13 +22,15 @@ class GeoAggregate(Aggregate):
raise ValueError('Geospatial aggregates only allowed on geometry fields.') raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL." "Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle: if connection.ops.oracle:
self.extra['tolerance'] = self.tolerance self.extra['tolerance'] = self.tolerance
params = []
if hasattr(self.col, 'as_sql'): if hasattr(self.col, 'as_sql'):
field_name = self.col.as_sql(qn, connection) field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)): elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col]) field_name = '.'.join([qn(c) for c in self.col])
else: else:
@ -36,13 +38,13 @@ class GeoAggregate(Aggregate):
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self) sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
params = { substitutions = {
'function': sql_function, 'function': sql_function,
'field': field_name 'field': field_name
} }
params.update(self.extra) substitutions.update(self.extra)
return sql_template % params return sql_template % substitutions, params
class Collect(GeoAggregate): class Collect(GeoAggregate):
pass pass

View file

@ -33,6 +33,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
qn2 = self.connection.ops.quote_name qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias)) result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
for alias, col in six.iteritems(self.query.extra_select)] for alias, col in six.iteritems(self.query.extra_select)]
params = []
aliases = set(self.query.extra_select.keys()) aliases = set(self.query.extra_select.keys())
if with_aliases: if with_aliases:
col_aliases = aliases.copy() col_aliases = aliases.copy()
@ -63,7 +64,9 @@ class GeoSQLCompiler(compiler.SQLCompiler):
aliases.add(r) aliases.add(r)
col_aliases.add(col[1]) col_aliases.add(col[1])
else: else:
result.append(col.as_sql(qn, self.connection)) col_sql, col_params = col.as_sql(qn, self.connection)
result.append(col_sql)
params.extend(col_params)
if hasattr(col, 'alias'): if hasattr(col, 'alias'):
aliases.add(col.alias) aliases.add(col.alias)
@ -76,15 +79,13 @@ class GeoSQLCompiler(compiler.SQLCompiler):
aliases.update(new_aliases) aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length() max_name_length = self.connection.ops.max_name_length()
result.extend([ for alias, aggregate in self.query.aggregate_select.items():
'%s%s' % ( agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection), if alias is None:
alias is not None result.append(agg_sql)
and ' AS %s' % qn(truncate_name(alias, max_name_length)) else:
or '' result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
) params.extend(agg_params)
for alias, aggregate in self.query.aggregate_select.items()
])
# This loop customized for GeoQuery. # This loop customized for GeoQuery.
for (table, col), field in self.query.related_select_cols: for (table, col), field in self.query.related_select_cols:
@ -100,7 +101,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
col_aliases.add(col) col_aliases.add(col)
self._select_aliases = aliases self._select_aliases = aliases
return result return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None, def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None): start_alias=None, opts=None, as_pairs=False, from_parent=None):

View file

@ -44,8 +44,9 @@ class GeoWhereNode(WhereNode):
lvalue, lookup_type, value_annot, params_or_value = child lvalue, lookup_type, value_annot, params_or_value = child
if isinstance(lvalue, GeoConstraint): if isinstance(lvalue, GeoConstraint):
data, params = lvalue.process(lookup_type, params_or_value, connection) data, params = lvalue.process(lookup_type, params_or_value, connection)
spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn) spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
return spatial_sql, params data, lookup_type, params_or_value, lvalue.field, qn)
return spatial_sql, spatial_params + params
else: else:
return super(GeoWhereNode, self).make_atom(child, qn, connection) return super(GeoWhereNode, self).make_atom(child, qn, connection)

View file

@ -25,7 +25,7 @@ class QueryWrapper(object):
parameters. Can be used to pass opaque data to a where-clause, for example. parameters. Can be used to pass opaque data to a where-clause, for example.
""" """
def __init__(self, sql, params): def __init__(self, sql, params):
self.data = sql, params self.data = sql, list(params)
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
return self.data return self.data

View file

@ -73,22 +73,23 @@ class Aggregate(object):
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL." "Return the aggregate, rendered as SQL with parameters."
params = []
if hasattr(self.col, 'as_sql'): if hasattr(self.col, 'as_sql'):
field_name = self.col.as_sql(qn, connection) field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)): elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col]) field_name = '.'.join([qn(c) for c in self.col])
else: else:
field_name = self.col field_name = self.col
params = { substitutions = {
'function': self.sql_function, 'function': self.sql_function,
'field': field_name 'field': field_name
} }
params.update(self.extra) substitutions.update(self.extra)
return self.sql_template % params return self.sql_template % substitutions, params
class Avg(Aggregate): class Avg(Aggregate):

View file

@ -74,7 +74,7 @@ class SQLCompiler(object):
# as the pre_sql_setup will modify query state in a way that forbids # as the pre_sql_setup will modify query state in a way that forbids
# another run of it. # another run of it.
self.refcounts_before = self.query.alias_refcount.copy() self.refcounts_before = self.query.alias_refcount.copy()
out_cols = self.get_columns(with_col_aliases) out_cols, s_params = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering() ordering, ordering_group_by = self.get_ordering()
distinct_fields = self.get_distinct() distinct_fields = self.get_distinct()
@ -97,6 +97,7 @@ class SQLCompiler(object):
result.append(self.connection.ops.distinct_sql(distinct_fields)) result.append(self.connection.ops.distinct_sql(distinct_fields))
result.append(', '.join(out_cols + self.query.ordering_aliases)) result.append(', '.join(out_cols + self.query.ordering_aliases))
params.extend(s_params)
result.append('FROM') result.append('FROM')
result.extend(from_) result.extend(from_)
@ -164,9 +165,10 @@ class SQLCompiler(object):
def get_columns(self, with_aliases=False): def get_columns(self, with_aliases=False):
""" """
Returns the list of columns to use in the select statement. If no Returns the list of columns to use in the select statement, as well as
columns have been specified, returns all columns relating to fields in a list any extra parameters that need to be included. If no columns
the model. have been specified, returns all columns relating to fields in the
model.
If 'with_aliases' is true, any column names that are duplicated If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in (without the table names) are given unique aliases. This is needed in
@ -175,6 +177,7 @@ class SQLCompiler(object):
qn = self.quote_name_unless_alias qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
aliases = set(self.query.extra_select.keys()) aliases = set(self.query.extra_select.keys())
if with_aliases: if with_aliases:
col_aliases = aliases.copy() col_aliases = aliases.copy()
@ -204,7 +207,9 @@ class SQLCompiler(object):
aliases.add(r) aliases.add(r)
col_aliases.add(col[1]) col_aliases.add(col[1])
else: else:
result.append(col.as_sql(qn, self.connection)) col_sql, col_params = col.as_sql(qn, self.connection)
result.append(col_sql)
params.extend(col_params)
if hasattr(col, 'alias'): if hasattr(col, 'alias'):
aliases.add(col.alias) aliases.add(col.alias)
@ -217,15 +222,13 @@ class SQLCompiler(object):
aliases.update(new_aliases) aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length() max_name_length = self.connection.ops.max_name_length()
result.extend([ for alias, aggregate in self.query.aggregate_select.items():
'%s%s' % ( agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
aggregate.as_sql(qn, self.connection), if alias is None:
alias is not None result.append(agg_sql)
and ' AS %s' % qn(truncate_name(alias, max_name_length)) else:
or '' result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
) params.extend(agg_params)
for alias, aggregate in self.query.aggregate_select.items()
])
for (table, col), _ in self.query.related_select_cols: for (table, col), _ in self.query.related_select_cols:
r = '%s.%s' % (qn(table), qn(col)) r = '%s.%s' % (qn(table), qn(col))
@ -240,7 +243,7 @@ class SQLCompiler(object):
col_aliases.add(col) col_aliases.add(col)
self._select_aliases = aliases self._select_aliases = aliases
return result return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None, def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None): start_alias=None, opts=None, as_pairs=False, from_parent=None):
@ -545,14 +548,16 @@ class SQLCompiler(object):
seen = set() seen = set()
cols = self.query.group_by + select_cols cols = self.query.group_by + select_cols
for col in cols: for col in cols:
col_params = ()
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1])) sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'): elif hasattr(col, 'as_sql'):
sql = col.as_sql(qn, self.connection) sql, col_params = col.as_sql(qn, self.connection)
else: else:
sql = '(%s)' % str(col) sql = '(%s)' % str(col)
if sql not in seen: if sql not in seen:
result.append(sql) result.append(sql)
params.extend(col_params)
seen.add(sql) seen.add(sql)
# Still, we need to add all stuff in ordering (except if the backend can # Still, we need to add all stuff in ordering (except if the backend can
@ -991,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler):
if qn is None: if qn is None:
qn = self.quote_name_unless_alias qn = self.quote_name_unless_alias
sql = ('SELECT %s FROM (%s) subquery' % ( sql, params = [], []
', '.join([ for aggregate in self.query.aggregate_select.values():
aggregate.as_sql(qn, self.connection) agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
for aggregate in self.query.aggregate_select.values() sql.append(agg_sql)
]), params.extend(agg_params)
self.query.subquery) sql = ', '.join(sql)
) params = tuple(params)
params = self.query.sub_params
return (sql, params) sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
params = params + self.query.sub_params
return sql, params
class SQLDateCompiler(SQLCompiler): class SQLDateCompiler(SQLCompiler):
def results_iter(self): def results_iter(self):

View file

@ -42,7 +42,7 @@ class Date(object):
col = '%s.%s' % tuple([qn(c) for c in self.col]) col = '%s.%s' % tuple([qn(c) for c in self.col])
else: else:
col = self.col col = self.col
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col) return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), []
class DateTime(Date): class DateTime(Date):
""" """

View file

@ -94,9 +94,9 @@ class SQLEvaluator(object):
if col is None: if col is None:
raise ValueError("Given node not found") raise ValueError("Given node not found")
if hasattr(col, 'as_sql'): if hasattr(col, 'as_sql'):
return col.as_sql(qn, connection), () return col.as_sql(qn, connection)
else: else:
return '%s.%s' % (qn(col[0]), qn(col[1])), () return '%s.%s' % (qn(col[0]), qn(col[1])), []
def evaluate_date_modifier_node(self, node, qn, connection): def evaluate_date_modifier_node(self, node, qn, connection):
timedelta = node.children.pop() timedelta = node.children.pop()

View file

@ -172,10 +172,10 @@ class WhereNode(tree.Node):
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.
field_sql = self.sql_for_columns(lvalue, qn, connection) field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), []
else: else:
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql = lvalue.as_sql(qn, connection) field_sql, field_params = lvalue.as_sql(qn, connection)
is_datetime_field = value_annotation is datetime.datetime is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
@ -186,6 +186,8 @@ class WhereNode(tree.Node):
else: else:
extra = '' extra = ''
params = field_params + params
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
and connection.features.interprets_empty_strings_as_nulls): and connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull' lookup_type = 'isnull'
@ -245,7 +247,7 @@ class WhereNode(tree.Node):
""" """
Returns the SQL fragment used for the left-hand side of a column Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause constraint (for example, the "T1.foo" portion in the clause
"WHERE ... T1.foo = 6"). "WHERE ... T1.foo = 6") and a list of parameters.
""" """
table_alias, name, db_type = data table_alias, name, db_type = data
if table_alias: if table_alias:
@ -338,7 +340,7 @@ class ExtraWhere(object):
def as_sql(self, qn=None, connection=None): def as_sql(self, qn=None, connection=None):
sqls = ["(%s)" % sql for sql in self.sqls] sqls = ["(%s)" % sql for sql in self.sqls]
return " AND ".join(sqls), tuple(self.params or ()) return " AND ".join(sqls), list(self.params or ())
def clone(self): def clone(self):
return self return self