From 10519003ed06defa48b1c9ecc734b3d5c92c297d Mon Sep 17 00:00:00 2001 From: tomershaniii <65544633+tomershaniii@users.noreply.github.com> Date: Sat, 23 Nov 2024 13:33:14 +0200 Subject: [PATCH] recursive select calls are parsed with bad trailing_commas parameter (#1521) --- src/parser/mod.rs | 39 +++++++++++++++++++++++++----------- tests/sqlparser_snowflake.rs | 17 ++++++++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 35c763e9..c8358767 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3532,16 +3532,11 @@ impl<'a> Parser<'a> { // e.g. `SELECT 1, 2, FROM t` // https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#trailing_commas // https://docs.snowflake.com/en/release-notes/2024/8_11#select-supports-trailing-commas - // - // This pattern could be captured better with RAII type semantics, but it's quite a bit of - // code to add for just one case, so we'll just do it manually here. - let old_value = self.options.trailing_commas; - self.options.trailing_commas |= self.dialect.supports_projection_trailing_commas(); - let ret = self.parse_comma_separated(|p| p.parse_select_item()); - self.options.trailing_commas = old_value; + let trailing_commas = + self.options.trailing_commas | self.dialect.supports_projection_trailing_commas(); - ret + self.parse_comma_separated_with_trailing_commas(|p| p.parse_select_item(), trailing_commas) } pub fn parse_actions_list(&mut self) -> Result, ParserError> { @@ -3568,11 +3563,12 @@ impl<'a> Parser<'a> { } /// Parse the comma of a comma-separated syntax element. + /// Allows for control over trailing commas /// Returns true if there is a next element - fn is_parse_comma_separated_end(&mut self) -> bool { + fn is_parse_comma_separated_end_with_trailing_commas(&mut self, trailing_commas: bool) -> bool { if !self.consume_token(&Token::Comma) { true - } else if self.options.trailing_commas { + } else if trailing_commas { let token = self.peek_token().token; match token { Token::Word(ref kw) @@ -3590,15 +3586,34 @@ impl<'a> Parser<'a> { } } + /// Parse the comma of a comma-separated syntax element. + /// Returns true if there is a next element + fn is_parse_comma_separated_end(&mut self) -> bool { + self.is_parse_comma_separated_end_with_trailing_commas(self.options.trailing_commas) + } + /// Parse a comma-separated list of 1+ items accepted by `F` - pub fn parse_comma_separated(&mut self, mut f: F) -> Result, ParserError> + pub fn parse_comma_separated(&mut self, f: F) -> Result, ParserError> + where + F: FnMut(&mut Parser<'a>) -> Result, + { + self.parse_comma_separated_with_trailing_commas(f, self.options.trailing_commas) + } + + /// Parse a comma-separated list of 1+ items accepted by `F` + /// Allows for control over trailing commas + fn parse_comma_separated_with_trailing_commas( + &mut self, + mut f: F, + trailing_commas: bool, + ) -> Result, ParserError> where F: FnMut(&mut Parser<'a>) -> Result, { let mut values = vec![]; loop { values.push(f(self)?); - if self.is_parse_comma_separated_end() { + if self.is_parse_comma_separated_end_with_trailing_commas(trailing_commas) { break; } } diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 1f1c00e7..1d053bb0 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -2846,3 +2846,20 @@ fn test_parse_show_columns_sql() { snowflake().verified_stmt("SHOW COLUMNS IN TABLE abc"); snowflake().verified_stmt("SHOW COLUMNS LIKE '%xyz%' IN TABLE abc"); } + +#[test] +fn test_projection_with_nested_trailing_commas() { + let sql = "SELECT a, FROM b, LATERAL FLATTEN(input => events)"; + let _ = snowflake().parse_sql_statements(sql).unwrap(); + + //Single nesting + let sql = "SELECT (SELECT a, FROM b, LATERAL FLATTEN(input => events))"; + let _ = snowflake().parse_sql_statements(sql).unwrap(); + + //Double nesting + let sql = "SELECT (SELECT (SELECT a, FROM b, LATERAL FLATTEN(input => events)))"; + let _ = snowflake().parse_sql_statements(sql).unwrap(); + + let sql = "SELECT a, b, FROM c, (SELECT d, e, FROM f, LATERAL FLATTEN(input => events))"; + let _ = snowflake().parse_sql_statements(sql).unwrap(); +}