From 9f772f03b0ce17fb8993ddf20c3daf53ad37cff4 Mon Sep 17 00:00:00 2001 From: rhanqtl Date: Sun, 11 Oct 2020 14:43:51 +0800 Subject: [PATCH] Add support for Recursive CTEs (#278) i.e. `WITH RECURSIVE ... AS ( ... ) SELECT` - see https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#with-clause Fixes #277 --- src/ast/mod.rs | 2 +- src/ast/query.rs | 26 ++++++++++++++++++---- src/parser.rs | 12 +++++----- src/tokenizer.rs | 5 +---- tests/sqlparser_common.rs | 47 ++++++++++++++++++++++++++++++++++----- 5 files changed, 73 insertions(+), 19 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8a58207d..a726b299 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -30,7 +30,7 @@ pub use self::ddl::{ pub use self::operator::{BinaryOperator, UnaryOperator}; pub use self::query::{ Cte, Fetch, Join, JoinConstraint, JoinOperator, Offset, OffsetRows, OrderByExpr, Query, Select, - SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, + SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, With, }; pub use self::value::{DateTimeField, Value}; diff --git a/src/ast/query.rs b/src/ast/query.rs index e0dbe4c7..06ea9c5b 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -20,8 +20,8 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Query { /// WITH (common table expressions, or CTEs) - pub ctes: Vec, - /// SELECT or UNION / EXCEPT / INTECEPT + pub with: Option, + /// SELECT or UNION / EXCEPT / INTERSECT pub body: SetExpr, /// ORDER BY pub order_by: Vec, @@ -35,8 +35,8 @@ pub struct Query { impl fmt::Display for Query { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if !self.ctes.is_empty() { - write!(f, "WITH {} ", display_comma_separated(&self.ctes))?; + if let Some(ref with) = self.with { + write!(f, "{} ", with)?; } write!(f, "{}", self.body)?; if !self.order_by.is_empty() { @@ -157,6 +157,24 @@ impl fmt::Display for Select { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct With { + pub recursive: bool, + pub cte_tables: Vec, +} + +impl fmt::Display for With { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "WITH {}{}", + if self.recursive { "RECURSIVE " } else { "" }, + display_comma_separated(&self.cte_tables) + ) + } +} + /// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )` /// The names in the column list before `AS`, when specified, replace the names /// of the columns returned by the query. The parser does not validate that the diff --git a/src/parser.rs b/src/parser.rs index 438c9f1e..5d98f9c8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1795,11 +1795,13 @@ impl<'a> Parser<'a> { /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed pub fn parse_query(&mut self) -> Result { - let ctes = if self.parse_keyword(Keyword::WITH) { - // TODO: optional RECURSIVE - self.parse_comma_separated(Parser::parse_cte)? + let with = if self.parse_keyword(Keyword::WITH) { + Some(With { + recursive: self.parse_keyword(Keyword::RECURSIVE), + cte_tables: self.parse_comma_separated(Parser::parse_cte)?, + }) } else { - vec![] + None }; let body = self.parse_query_body(0)?; @@ -1829,7 +1831,7 @@ impl<'a> Parser<'a> { }; Ok(Query { - ctes, + with, body, limit, order_by, diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 2496e63a..70587f18 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -382,10 +382,7 @@ impl<'a> Tokenizer<'a> { // numbers '0'..='9' => { // TODO: https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#unsigned-numeric-literal - let s = peeking_take_while(chars, |ch| match ch { - '0'..='9' | '.' => true, - _ => false, - }); + let s = peeking_take_while(chars, |ch| matches!(ch, '0'..='9' | '.')); Ok(Some(Token::Number(s))) } // punctuation diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 6b032d92..411b77f6 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2389,7 +2389,7 @@ fn parse_ctes() { fn assert_ctes_in_select(expected: &[&str], sel: &Query) { for (i, exp) in expected.iter().enumerate() { - let Cte { alias, query } = &sel.ctes[i]; + let Cte { alias, query } = &sel.with.as_ref().unwrap().cte_tables[i]; assert_eq!(*exp, query.to_string()); assert_eq!( if i == 0 { @@ -2432,7 +2432,7 @@ fn parse_ctes() { // CTE in a CTE... let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with); let select = verified_query(sql); - assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query); + assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query); } #[test] @@ -2441,10 +2441,47 @@ fn parse_cte_renamed_columns() { let query = all_dialects().verified_query(sql); assert_eq!( vec![Ident::new("col1"), Ident::new("col2")], - query.ctes.first().unwrap().alias.columns + query + .with + .unwrap() + .cte_tables + .first() + .unwrap() + .alias + .columns ); } +#[test] +fn parse_recursive_cte() { + let cte_query = "SELECT 1 UNION ALL SELECT val + 1 FROM nums WHERE val < 10".to_owned(); + let sql = &format!( + "WITH RECURSIVE nums (val) AS ({}) SELECT * FROM nums", + cte_query + ); + + let cte_query = verified_query(&cte_query); + let query = verified_query(sql); + + let with = query.with.as_ref().unwrap(); + assert!(with.recursive); + assert_eq!(with.cte_tables.len(), 1); + let expected = Cte { + alias: TableAlias { + name: Ident { + value: "nums".to_string(), + quote_style: None, + }, + columns: vec![Ident { + value: "val".to_string(), + quote_style: None, + }], + }, + query: cte_query, + }; + assert_eq!(with.cte_tables.first().unwrap(), &expected); +} + #[test] fn parse_derived_tables() { let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b"; @@ -3266,8 +3303,8 @@ fn parse_drop_index() { fn all_keywords_sorted() { // assert!(ALL_KEYWORDS.is_sorted()) let mut copy = Vec::from(ALL_KEYWORDS); - copy.sort(); - assert!(copy == ALL_KEYWORDS) + copy.sort_unstable(); + assert_eq!(copy, ALL_KEYWORDS) } fn parse_sql_statements(sql: &str) -> Result, ParserError> {