Enhancing Trailing Comma Option (#1212)

This commit is contained in:
Mohamed Abdeen 2024-06-07 13:44:04 +03:00 committed by GitHub
parent a0f511cb21
commit 6d4776b482
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 118 additions and 12 deletions

View file

@ -22,6 +22,10 @@ impl Dialect for BigQueryDialect {
ch == '`' ch == '`'
} }
fn supports_projection_trailing_commas(&self) -> bool {
true
}
fn is_identifier_start(&self, ch: char) -> bool { fn is_identifier_start(&self, ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
} }

View file

@ -18,6 +18,10 @@ pub struct DuckDbDialect;
// In most cases the redshift dialect is identical to [`PostgresSqlDialect`]. // In most cases the redshift dialect is identical to [`PostgresSqlDialect`].
impl Dialect for DuckDbDialect { impl Dialect for DuckDbDialect {
fn supports_trailing_commas(&self) -> bool {
true
}
fn is_identifier_start(&self, ch: char) -> bool { fn is_identifier_start(&self, ch: char) -> bool {
ch.is_alphabetic() || ch == '_' ch.is_alphabetic() || ch == '_'
} }

View file

@ -251,6 +251,14 @@ pub trait Dialect: Debug + Any {
// return None to fall back to the default behavior // return None to fall back to the default behavior
None None
} }
/// Does the dialect support trailing commas throughout the query (not only in the projection list)?
fn supports_trailing_commas(&self) -> bool {
false
}
/// Does the dialect support trailing commas in the projection list?
fn supports_projection_trailing_commas(&self) -> bool {
self.supports_trailing_commas()
}
/// Dialect-specific infix parser override /// Dialect-specific infix parser override
fn parse_infix( fn parse_infix(
&self, &self,

View file

@ -38,6 +38,10 @@ impl Dialect for SnowflakeDialect {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'
} }
fn supports_projection_trailing_commas(&self) -> bool {
true
}
fn is_identifier_part(&self, ch: char) -> bool { fn is_identifier_part(&self, ch: char) -> bool {
ch.is_ascii_lowercase() ch.is_ascii_lowercase()
|| ch.is_ascii_uppercase() || ch.is_ascii_uppercase()

View file

@ -305,7 +305,7 @@ impl<'a> Parser<'a> {
state: ParserState::Normal, state: ParserState::Normal,
dialect, dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH), recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
options: ParserOptions::default(), options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
} }
} }
@ -3225,7 +3225,7 @@ impl<'a> Parser<'a> {
// This pattern could be captured better with RAII type semantics, but it's quite a bit of // This pattern could be captured better with RAII type semantics, but it's quite a bit of
// code to add for just one case, so we'll just do it manually here. // code to add for just one case, so we'll just do it manually here.
let old_value = self.options.trailing_commas; let old_value = self.options.trailing_commas;
self.options.trailing_commas |= dialect_of!(self is BigQueryDialect | SnowflakeDialect); self.options.trailing_commas |= self.dialect.supports_projection_trailing_commas();
let ret = self.parse_comma_separated(|p| p.parse_select_item()); let ret = self.parse_comma_separated(|p| p.parse_select_item());
self.options.trailing_commas = old_value; self.options.trailing_commas = old_value;
@ -5413,12 +5413,17 @@ impl<'a> Parser<'a> {
} else { } else {
return self.expected("column name or constraint definition", self.peek_token()); return self.expected("column name or constraint definition", self.peek_token());
} }
let comma = self.consume_token(&Token::Comma); let comma = self.consume_token(&Token::Comma);
if self.consume_token(&Token::RParen) { let rparen = self.peek_token().token == Token::RParen;
// allow a trailing comma, even though it's not in standard
break; if !comma && !rparen {
} else if !comma {
return self.expected("',' or ')' after column definition", self.peek_token()); return self.expected("',' or ')' after column definition", self.peek_token());
};
if rparen && (!comma || self.options.trailing_commas) {
let _ = self.consume_token(&Token::RParen);
break;
} }
} }
@ -9411,6 +9416,9 @@ impl<'a> Parser<'a> {
with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES), with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES),
} }
} else { } else {
let old_value = self.options.trailing_commas;
self.options.trailing_commas = false;
let (actions, err): (Vec<_>, Vec<_>) = self let (actions, err): (Vec<_>, Vec<_>) = self
.parse_comma_separated(Parser::parse_grant_permission)? .parse_comma_separated(Parser::parse_grant_permission)?
.into_iter() .into_iter()
@ -9434,6 +9442,8 @@ impl<'a> Parser<'a> {
}) })
.partition(Result::is_ok); .partition(Result::is_ok);
self.options.trailing_commas = old_value;
if !err.is_empty() { if !err.is_empty() {
let errors: Vec<Keyword> = err.into_iter().filter_map(|x| x.err()).collect(); let errors: Vec<Keyword> = err.into_iter().filter_map(|x| x.err()).collect();
return Err(ParserError::ParserError(format!( return Err(ParserError::ParserError(format!(
@ -9939,6 +9949,12 @@ impl<'a> Parser<'a> {
Expr::Wildcard => Ok(SelectItem::Wildcard( Expr::Wildcard => Ok(SelectItem::Wildcard(
self.parse_wildcard_additional_options()?, self.parse_wildcard_additional_options()?,
)), )),
Expr::Identifier(v) if v.value.to_lowercase() == "from" => {
parser_err!(
format!("Expected an expression, found: {}", v),
self.peek_token().location
)
}
expr => self expr => self
.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS) .parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
.map(|alias| match alias { .map(|alias| match alias {

View file

@ -3552,8 +3552,13 @@ fn parse_create_table_clone() {
#[test] #[test]
fn parse_create_table_trailing_comma() { fn parse_create_table_trailing_comma() {
let sql = "CREATE TABLE foo (bar int,)"; let dialect = TestedDialects {
all_dialects().one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)"); dialects: vec![Box::new(DuckDbDialect {})],
options: None,
};
let sql = "CREATE TABLE foo (bar int,);";
dialect.one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)");
} }
#[test] #[test]
@ -4418,7 +4423,7 @@ fn parse_window_clause() {
ORDER BY C3"; ORDER BY C3";
verified_only_select(sql); verified_only_select(sql);
let sql = "SELECT from mytable WINDOW window1 AS window2"; let sql = "SELECT * from mytable WINDOW window1 AS window2";
let dialects = all_dialects_except(|d| d.is::<BigQueryDialect>() || d.is::<GenericDialect>()); let dialects = all_dialects_except(|d| d.is::<BigQueryDialect>() || d.is::<GenericDialect>());
let res = dialects.parse_sql_statements(sql); let res = dialects.parse_sql_statements(sql);
assert_eq!( assert_eq!(
@ -8846,9 +8851,11 @@ fn parse_non_latin_identifiers() {
#[test] #[test]
fn parse_trailing_comma() { fn parse_trailing_comma() {
// At the moment, DuckDB is the only dialect that allows
// trailing commas anywhere in the query
let trailing_commas = TestedDialects { let trailing_commas = TestedDialects {
dialects: vec![Box::new(GenericDialect {})], dialects: vec![Box::new(DuckDbDialect {})],
options: Some(ParserOptions::new().with_trailing_commas(true)), options: None,
}; };
trailing_commas.one_statement_parses_to( trailing_commas.one_statement_parses_to(
@ -8866,11 +8873,74 @@ fn parse_trailing_comma() {
"SELECT DISTINCT ON (album_id) name FROM track", "SELECT DISTINCT ON (album_id) name FROM track",
); );
trailing_commas.one_statement_parses_to(
"CREATE TABLE employees (name text, age int,)",
"CREATE TABLE employees (name TEXT, age INT)",
);
trailing_commas.verified_stmt("SELECT album_id, name FROM track"); trailing_commas.verified_stmt("SELECT album_id, name FROM track");
trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds"); trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds");
trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track"); trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track");
// The generic dialect doesn't allow any trailing commas
let trailing_commas = TestedDialects {
dialects: vec![Box::new(GenericDialect {})],
options: None,
};
assert_eq!(
trailing_commas
.parse_sql_statements("SELECT name, age, from employees;")
.unwrap_err(),
ParserError::ParserError("Expected an expression, found: from".to_string())
);
assert_eq!(
trailing_commas
.parse_sql_statements("CREATE TABLE employees (name text, age int,)")
.unwrap_err(),
ParserError::ParserError(
"Expected column name or constraint definition, found: )".to_string()
)
);
}
#[test]
fn parse_projection_trailing_comma() {
// Some dialects allow trailing commas only in the projection
let trailing_commas = TestedDialects {
dialects: vec![Box::new(SnowflakeDialect {}), Box::new(BigQueryDialect {})],
options: None,
};
trailing_commas.one_statement_parses_to(
"SELECT album_id, name, FROM track",
"SELECT album_id, name FROM track",
);
trailing_commas.verified_stmt("SELECT album_id, name FROM track");
trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds");
trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track");
assert_eq!(
trailing_commas
.parse_sql_statements("SELECT * FROM track ORDER BY milliseconds,")
.unwrap_err(),
ParserError::ParserError("Expected an expression:, found: EOF".to_string())
);
assert_eq!(
trailing_commas
.parse_sql_statements("CREATE TABLE employees (name text, age int,)")
.unwrap_err(),
ParserError::ParserError(
"Expected column name or constraint definition, found: )".to_string()
),
);
} }
#[test] #[test]

View file

@ -3701,7 +3701,7 @@ fn parse_create_table_with_alias() {
int2_col INT2, int2_col INT2,
float8_col FLOAT8, float8_col FLOAT8,
float4_col FLOAT4, float4_col FLOAT4,
bool_col BOOL, bool_col BOOL
);"; );";
match pg_and_generic().one_statement_parses_to(sql, "") { match pg_and_generic().one_statement_parses_to(sql, "") {
Statement::CreateTable(CreateTable { Statement::CreateTable(CreateTable {