diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index 6668b785..b220770e 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -18,14 +18,7 @@ impl ToString for SQLQuery { fn to_string(&self) -> String { let mut s = String::new(); if !self.ctes.is_empty() { - s += &format!( - "WITH {} ", - self.ctes - .iter() - .map(|cte| format!("{} AS ({})", cte.alias, cte.query.to_string())) - .collect::>() - .join(", ") - ) + s += &format!("WITH {} ", comma_separated_string(&self.ctes)) } s += &self.body.to_string(); if let Some(ref order_by) = self.order_by { @@ -144,11 +137,25 @@ impl ToString for SQLSelect { } } -/// A single CTE (used after `WITH`): `alias AS ( query )` +/// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )` +/// The names in the column list before `AS`, when specified, replace the names +/// of the columns returned by the query. The parser does not validate that the +/// number of columns in the query matches the number of columns in the query. #[derive(Debug, Clone, PartialEq)] pub struct Cte { pub alias: SQLIdent, pub query: SQLQuery, + pub renamed_columns: Vec, +} + +impl ToString for Cte { + fn to_string(&self) -> String { + let mut s = self.alias.clone(); + if !self.renamed_columns.is_empty() { + s += &format!(" ({})", comma_separated_string(&self.renamed_columns)); + } + s + &format!(" AS ({})", self.query.to_string()) + } } /// One item of the comma-separated list following `SELECT` diff --git a/src/sqlparser.rs b/src/sqlparser.rs index d29a7da2..a19eab08 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -1215,12 +1215,19 @@ impl Parser { let mut cte = vec![]; loop { let alias = self.parse_identifier()?; - // TODO: Optional `( )` + let renamed_columns = if self.consume_token(&Token::LParen) { + let cols = self.parse_column_names()?; + self.expect_token(&Token::RParen)?; + cols + } else { + vec![] + }; self.expect_keyword("AS")?; self.expect_token(&Token::LParen)?; cte.push(Cte { alias, query: self.parse_query()?, + renamed_columns, }); self.expect_token(&Token::RParen)?; if !self.consume_token(&Token::Comma) { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index e44b3894..df08adce 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1027,9 +1027,14 @@ fn parse_ctes() { fn assert_ctes_in_select(expected: &[&str], sel: &SQLQuery) { let mut i = 0; for exp in expected { - let Cte { query, alias } = &sel.ctes[i]; + let Cte { + query, + alias, + renamed_columns, + } = &sel.ctes[i]; assert_eq!(*exp, query.to_string()); assert_eq!(if i == 0 { "a" } else { "b" }, alias); + assert!(renamed_columns.is_empty()); i += 1; } } @@ -1066,6 +1071,16 @@ fn parse_ctes() { assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query); } +#[test] +fn parse_cte_renamed_columns() { + let sql = "WITH cte (col1, col2) AS (SELECT foo, bar FROM baz) SELECT * FROM cte"; + let query = all_dialects().verified_query(sql); + assert_eq!( + vec!["col1", "col2"], + query.ctes.first().unwrap().renamed_columns + ); +} + #[test] fn parse_derived_tables() { let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";