Handle quoting identifiers properly

This commit is contained in:
Kacper Madej 2024-12-17 22:35:05 +01:00
parent 783ec65c77
commit 9e01c22a5e
4 changed files with 55 additions and 19 deletions

View file

@ -97,12 +97,13 @@ fn bind_column_references(
return Ok(());
}
let mut match_result = None;
let normalized_id = normalize_ident(id.0.as_str());
for (tbl_idx, table) in referenced_tables.iter().enumerate() {
let col_idx = table
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
.position(|c| c.name.eq_ignore_ascii_case(&normalized_id));
if col_idx.is_some() {
if match_result.is_some() {
crate::bail_parse_error!("Column {} is ambiguous", id.0);
@ -124,20 +125,23 @@ fn bind_column_references(
Ok(())
}
ast::Expr::Qualified(tbl, id) => {
let matching_tbl_idx = referenced_tables
.iter()
.position(|t| t.table_identifier.eq_ignore_ascii_case(&tbl.0));
let normalized_table_name = normalize_ident(tbl.0.as_str());
let matching_tbl_idx = referenced_tables.iter().position(|t| {
t.table_identifier
.eq_ignore_ascii_case(&normalized_table_name)
});
if matching_tbl_idx.is_none() {
crate::bail_parse_error!("Table {} not found", tbl.0);
crate::bail_parse_error!("Table {} not found", normalized_table_name);
}
let tbl_idx = matching_tbl_idx.unwrap();
let normalized_id = normalize_ident(id.0.as_str());
let col_idx = referenced_tables[tbl_idx]
.table
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(&id.0));
.position(|c| c.name.eq_ignore_ascii_case(&normalized_id));
if col_idx.is_none() {
crate::bail_parse_error!("Column {} not found", id.0);
crate::bail_parse_error!("Column {} not found", normalized_id);
}
let col = referenced_tables[tbl_idx]
.table
@ -504,8 +508,9 @@ fn parse_from(
let first_table = match *from.select.unwrap() {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
let normalized_qualified_name = normalize_ident(qualified_name.name.0.as_str());
let Some(table) = schema.get_table(&normalized_qualified_name) else {
crate::bail_parse_error!("Table {} not found", normalized_qualified_name);
};
let alias = maybe_alias
.map(|a| match a {
@ -516,7 +521,7 @@ fn parse_from(
BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_identifier: alias.unwrap_or(normalized_qualified_name),
table_index: 0,
}
}
@ -570,8 +575,9 @@ fn parse_join(
let table = match table {
ast::SelectTable::Table(qualified_name, maybe_alias, _) => {
let Some(table) = schema.get_table(&qualified_name.name.0) else {
crate::bail_parse_error!("Table {} not found", qualified_name.name.0);
let normalized_name = normalize_ident(qualified_name.name.0.as_str());
let Some(table) = schema.get_table(&normalized_name) else {
crate::bail_parse_error!("Table {} not found", normalized_name);
};
let alias = maybe_alias
.map(|a| match a {
@ -581,7 +587,7 @@ fn parse_join(
.map(|a| a.0);
BTreeTableReference {
table: table.clone(),
table_identifier: alias.unwrap_or(qualified_name.name.0),
table_identifier: alias.unwrap_or(normalized_name),
table_index,
}
}

View file

@ -7,12 +7,19 @@ use crate::{
Result, RowResult, Rows, IO,
};
pub fn normalize_ident(ident: &str) -> String {
(if ident.starts_with('"') && ident.ends_with('"') {
&ident[1..ident.len() - 1]
// https://sqlite.org/lang_keywords.html
const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')];
pub fn normalize_ident(identifier: &str) -> String {
let quote_pair = QUOTE_PAIRS
.iter()
.find(|&(start, end)| identifier.starts_with(*start) && identifier.ends_with(*end));
if let Some(&(start, end)) = quote_pair {
&identifier[1..identifier.len() - 1]
} else {
ident
})
identifier
}
.to_lowercase()
}
@ -65,7 +72,6 @@ fn cmp_numeric_strings(num_str: &str, other: &str) -> bool {
}
}
const QUOTE_PAIRS: &[(char, char)] = &[('"', '"'), ('[', ']'), ('`', '`')];
pub fn check_ident_equivalency(ident1: &str, ident2: &str) -> bool {
fn strip_quotes(identifier: &str) -> &str {
for &(start, end) in QUOTE_PAIRS {
@ -276,7 +282,17 @@ pub fn exprs_are_equivalent(expr1: &Expr, expr2: &Expr) -> bool {
#[cfg(test)]
pub mod tests {
use super::*;
use sqlite3_parser::ast::{self, Expr, Id, Literal, Operator::*, Type};
#[test]
fn test_normalize_ident() {
assert_eq!(normalize_ident("foo"), "foo");
assert_eq!(normalize_ident("`foo`"), "foo");
assert_eq!(normalize_ident("[foo]"), "foo");
assert_eq!(normalize_ident("\"foo\""), "foo");
}
#[test]
fn test_basic_addition_exprs_are_equivalent() {
let expr1 = Expr::Binary(

View file

@ -240,6 +240,12 @@ do_execsql_test join-using-multiple {
Cindy|Salazar|cap
Tommy|Perry|shirt"}
do_execsql_test join-using-multiple-with-quoting {
select u.first_name, u.last_name, p.name from users u join users u2 using(id) join [products] p using(`id`) limit 3;
} {"Jamie|Foster|hat
Cindy|Salazar|cap
Tommy|Perry|shirt"}
# NATURAL JOIN desugars to JOIN USING (common_column1, common_column2...)
do_execsql_test join-using {
select * from users natural join products limit 3;

View file

@ -51,6 +51,14 @@ do_execsql_test table-star-2 {
select p.*, u.first_name from users u join products p on u.id = p.id limit 1;
} {1|hat|79.0|Jamie}
do_execsql_test select_with_quoting {
select `users`.id from [users] where users.[id] = 5;
} {5}
do_execsql_test select_with_quoting_2 {
select "users".`id` from users where `users`.[id] = 5;
} {5}
do_execsql_test seekrowid {
select * from users u where u.id = 5;
} {"5|Edward|Miller|christiankramer@example.com|725-281-1033|08522 English Plain|Lake Keith|ID|23283|15"}