diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index b343b837..f77058b3 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -20,7 +20,9 @@ mod sqltype; mod table_key; mod value; -pub use self::query::{Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLSelect, TableFactor}; +pub use self::query::{ + Cte, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, TableFactor, +}; pub use self::sqltype::SQLType; pub use self::table_key::{AlterOperation, Key, TableKey}; pub use self::value::Value; @@ -76,7 +78,7 @@ pub enum ASTNode { }, /// A parenthesized subquery `(SELECT ...)`, used in expression like /// `SELECT (subquery) AS x` or `WHERE (subquery) = x` - SQLSubquery(Box), + SQLSubquery(Box), } impl ToString for ASTNode { @@ -139,7 +141,7 @@ impl ToString for ASTNode { #[derive(Debug, Clone, PartialEq)] pub enum SQLStatement { /// SELECT - SQLSelect(SQLSelect), + SQLSelect(SQLQuery), /// INSERT SQLInsert { /// TABLE @@ -177,7 +179,7 @@ pub enum SQLStatement { SQLCreateView { /// View name name: SQLObjectName, - query: SQLSelect, + query: SQLQuery, }, /// CREATE TABLE SQLCreateTable { diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index 887042f9..28f221bd 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -1,5 +1,53 @@ use super::*; +/// The most complete variant of a `SELECT` query expression, optionally +/// including `WITH`, `UNION` / other set operations, and `ORDER BY`. 
+#[derive(Debug, Clone, PartialEq)] +pub struct SQLQuery { + /// WITH (common table expressions, or CTEs) + pub ctes: Vec, + /// SELECT or UNION / EXCEPT / INTERSECT + pub body: SQLSelect, + /// ORDER BY + pub order_by: Option>, + /// LIMIT + pub limit: Option, +} + +impl ToString for SQLQuery { + fn to_string(&self) -> String { + let mut s = String::new(); + if !self.ctes.is_empty() { + s += &format!( + "WITH {} ", + self.ctes + .iter() + .map(|cte| format!("{} AS ({})", cte.alias, cte.query.to_string())) + .collect::>() + .join(", ") + ) + } + s += &self.body.to_string(); + if let Some(ref order_by) = self.order_by { + s += &format!( + " ORDER BY {}", + order_by + .iter() + .map(|o| o.to_string()) + .collect::>() + .join(", ") + ); + } + if let Some(ref limit) = self.limit { + s += &format!(" LIMIT {}", limit.to_string()); + } + s + } +} + +/// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may +/// appear either as the only body item of an `SQLQuery`, or as an operand +/// to a set operation like `UNION`. 
#[derive(Debug, Clone, PartialEq)] pub struct SQLSelect { /// projection expressions @@ -10,14 +58,10 @@ pub struct SQLSelect { pub joins: Vec, /// WHERE pub selection: Option, - /// ORDER BY - pub order_by: Option>, /// GROUP BY pub group_by: Option>, /// HAVING pub having: Option, - /// LIMIT - pub limit: Option, } impl ToString for SQLSelect { @@ -52,23 +96,17 @@ impl ToString for SQLSelect { if let Some(ref having) = self.having { s += &format!(" HAVING {}", having.to_string()); } - if let Some(ref order_by) = self.order_by { - s += &format!( - " ORDER BY {}", - order_by - .iter() - .map(|o| o.to_string()) - .collect::>() - .join(", ") - ); - } - if let Some(ref limit) = self.limit { - s += &format!(" LIMIT {}", limit.to_string()); - } s } } +/// A single CTE (used after `WITH`): `alias AS ( query )` +#[derive(Debug, Clone, PartialEq)] +pub struct Cte { + pub alias: SQLIdent, + pub query: SQLQuery, +} + /// A table name or a parenthesized subquery with an optional alias #[derive(Debug, Clone, PartialEq)] pub enum TableFactor { @@ -77,7 +115,7 @@ pub enum TableFactor { alias: Option, }, Derived { - subquery: Box, + subquery: Box, alias: Option, }, } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index a6784aa1..d04be6e0 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -88,7 +88,10 @@ impl Parser { match self.next_token() { Some(t) => match t { Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() { - "SELECT" => Ok(SQLStatement::SQLSelect(self.parse_select()?)), + "SELECT" | "WITH" => { + self.prev_token(); + Ok(SQLStatement::SQLSelect(self.parse_query()?)) + } "CREATE" => Ok(self.parse_create()?), "DELETE" => Ok(self.parse_delete()?), "INSERT" => Ok(self.parse_insert()?), @@ -198,8 +201,9 @@ impl Parser { self.parse_sql_value() } Token::LParen => { - let expr = if self.parse_keyword("SELECT") { - ASTNode::SQLSubquery(Box::new(self.parse_select()?)) + let expr = if self.parse_keyword("SELECT") || self.parse_keyword("WITH") { + 
self.prev_token(); + ASTNode::SQLSubquery(Box::new(self.parse_query()?)) } else { ASTNode::SQLNested(Box::new(self.parse_expr()?)) }; @@ -568,8 +572,7 @@ impl Parser { // Some dialects allow WITH here, followed by some keywords (e.g. MS SQL) // or `(k1=v1, k2=v2, ...)` (Postgres) self.expect_keyword("AS")?; - self.expect_keyword("SELECT")?; - let query = self.parse_select()?; + let query = self.parse_query()?; // Optional `WITH [ CASCADED | LOCAL ] CHECK OPTION` is widely supported here. Ok(SQLStatement::SQLCreateView { name, query }) } @@ -673,18 +676,9 @@ impl Parser { let table_name = self.parse_object_name()?; let operation: Result = if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) { - match self.next_token() { - Some(Token::SQLWord(ref id)) => { - let table_key = self.parse_table_key(id.as_sql_ident())?; - Ok(AlterOperation::AddConstraint(table_key)) - } - _ => { - return parser_err!(format!( - "Expecting identifier, found : {:?}", - self.peek_token() - )); - } - } + let constraint_name = self.parse_identifier()?; + let table_key = self.parse_table_key(constraint_name)?; + Ok(AlterOperation::AddConstraint(table_key)) } else { return parser_err!(format!( "Expecting ADD CONSTRAINT, found :{:?}", @@ -1079,6 +1073,14 @@ impl Parser { Ok(SQLObjectName(self.parse_list_of_ids(&Token::Period)?)) } + /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) + pub fn parse_identifier(&mut self) -> Result { + match self.next_token() { + Some(Token::SQLWord(w)) => Ok(w.as_sql_ident()), + unexpected => parser_err!(format!("Expected identifier, found {:?}", unexpected)), + } + } + /// Parse a comma-separated list of unqualified, possibly quoted identifiers pub fn parse_column_names(&mut self) -> Result, ParserError> { Ok(self.parse_list_of_ids(&Token::Comma)?) @@ -1132,7 +1134,64 @@ impl Parser { }) } - /// Parse a SELECT statement + /// Parse a query expression, i.e. 
a `SELECT` statement optionally + /// preceded with some `WITH` CTE declarations and optionally followed + /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't + /// expect the initial keyword to be already consumed + pub fn parse_query(&mut self) -> Result { + let ctes = if self.parse_keyword("WITH") { + // TODO: optional RECURSIVE + self.parse_cte_list()? + } else { + vec![] + }; + + self.expect_keyword("SELECT")?; + let body = self.parse_select()?; + + let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) { + Some(self.parse_order_by_expr_list()?) + } else { + None + }; + + let limit = if self.parse_keyword("LIMIT") { + self.parse_limit()? + } else { + None + }; + + Ok(SQLQuery { + ctes, + body, + limit, + order_by, + }) + } + + /// Parse one or more (comma-separated) `alias AS (subquery)` CTEs, + /// assuming the initial `WITH` was already consumed. + fn parse_cte_list(&mut self) -> Result, ParserError> { + let mut cte = vec![]; + loop { + let alias = self.parse_identifier()?; + // TODO: Optional `( )` + self.expect_keyword("AS")?; + self.expect_token(&Token::LParen)?; + cte.push(Cte { + alias, + query: self.parse_query()?, + }); + self.expect_token(&Token::RParen)?; + if !self.consume_token(&Token::Comma) { + break; + } + } + return Ok(cte); + } + + /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`), + /// assuming the initial `SELECT` was already consumed pub fn parse_select(&mut self) -> Result { let projection = self.parse_expr_list()?; @@ -1145,8 +1204,7 @@ impl Parser { }; let selection = if self.parse_keyword("WHERE") { - let expr = self.parse_expr()?; - Some(expr) + Some(self.parse_expr()?) } else { None }; @@ -1163,25 +1221,11 @@ impl Parser { None }; - let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) { - Some(self.parse_order_by_expr_list()?) - } else { - None - }; - - let limit = if self.parse_keyword("LIMIT") { - self.parse_limit()? 
- } else { - None - }; - Ok(SQLSelect { projection, selection, relation, joins, - limit, - order_by, group_by, having, }) @@ -1190,18 +1234,14 @@ impl Parser { /// A table name or a parenthesized subquery, followed by optional `[AS] alias` pub fn parse_table_factor(&mut self) -> Result { if self.consume_token(&Token::LParen) { - self.expect_keyword("SELECT")?; - let subquery = self.parse_select()?; + let subquery = Box::new(self.parse_query()?); self.expect_token(&Token::RParen)?; - Ok(TableFactor::Derived { - subquery: Box::new(subquery), - alias: self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?, - }) + let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Derived { subquery, alias }) } else { - Ok(TableFactor::Table { - name: self.parse_object_name()?, - alias: self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?, - }) + let name = self.parse_object_name()?; + let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Table { name, alias }) } } diff --git a/tests/sqlparser_ansi.rs b/tests/sqlparser_ansi.rs index 7e4900f5..871046b1 100644 --- a/tests/sqlparser_ansi.rs +++ b/tests/sqlparser_ansi.rs @@ -11,7 +11,10 @@ fn parse_simple_select() { let ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap(); assert_eq!(1, ast.len()); match ast.first().unwrap() { - SQLStatement::SQLSelect(SQLSelect { projection, .. }) => { + SQLStatement::SQLSelect(SQLQuery { + body: SQLSelect { projection, .. }, + .. + }) => { assert_eq!(3, projection.len()); } _ => assert!(false), diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 9784eed7..e99f0460 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -57,32 +57,22 @@ fn parse_simple_select() { #[test] fn parse_select_wildcard() { - let sql = String::from("SELECT * FROM customer"); - match verified_stmt(&sql) { - SQLStatement::SQLSelect(SQLSelect { projection, .. 
}) => { - assert_eq!(1, projection.len()); - assert_eq!(ASTNode::SQLWildcard, projection[0]); - } - _ => assert!(false), - } + let sql = "SELECT * FROM foo"; + let select = verified_only_select(sql); + assert_eq!(&ASTNode::SQLWildcard, only(&select.projection)); } #[test] fn parse_select_count_wildcard() { - let sql = String::from("SELECT COUNT(*) FROM customer"); - match verified_stmt(&sql) { - SQLStatement::SQLSelect(SQLSelect { projection, .. }) => { - assert_eq!(1, projection.len()); - assert_eq!( - ASTNode::SQLFunction { - id: "COUNT".to_string(), - args: vec![ASTNode::SQLWildcard], - }, - projection[0] - ); - } - _ => assert!(false), - } + let sql = "SELECT COUNT(*) FROM customer"; + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLFunction { + id: "COUNT".to_string(), + args: vec![ASTNode::SQLWildcard], + }, + expr_from_projection(only(&select.projection)) + ); } #[test] @@ -652,6 +642,59 @@ fn parse_join_syntax_variants() { ); } +#[test] +fn parse_ctes() { + // To be valid SQL this needs aliases for the derived columns, but + // we don't support them yet in the context of a SELECT's projection. 
+ let cte_sqls = vec!["SELECT 1", "SELECT 2"]; + let with = &format!( + "WITH a AS ({}), b AS ({}) SELECT foo + bar FROM a, b", + cte_sqls[0], cte_sqls[1] + ); + + // Checks that `sel` declares exactly the CTEs in `expected`, in order, + // under the aliases `a` and `b`. + fn assert_ctes_in_select(expected: &[&str], sel: &SQLQuery) { + assert_eq!(expected.len(), sel.ctes.len()); + for (i, exp) in expected.iter().enumerate() { + let Cte { + ref query, + ref alias, + } = sel.ctes[i]; + assert_eq!(*exp, query.to_string()); + assert_eq!(if i == 0 { "a" } else { "b" }, alias); + } + } + + // Top-level CTE + assert_ctes_in_select(&cte_sqls, &verified_query(with)); + // CTE in a subquery + let sql = &format!("SELECT ({})", with); + let select = verified_only_select(sql); + match expr_from_projection(only(&select.projection)) { + &ASTNode::SQLSubquery(ref subquery) => { + assert_ctes_in_select(&cte_sqls, subquery.as_ref()); + } + _ => panic!("Expected subquery"), + } + // CTE in a derived table + let sql = &format!("SELECT * FROM ({})", with); + let select = verified_only_select(sql); + match select.relation { + Some(TableFactor::Derived { subquery, .. }) => { + assert_ctes_in_select(&cte_sqls, subquery.as_ref()) + } + _ => panic!("Expected derived table"), + } + // CTE in a view + let sql = &format!("CREATE VIEW v AS {}", with); + match verified_stmt(sql) { + SQLStatement::SQLCreateView { query, .. } => assert_ctes_in_select(&cte_sqls, &query), + _ => panic!("Expected CREATE VIEW"), + } + // CTE in a CTE... 
+ let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with); + let select = verified_query(sql); + assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query); +} + #[test] fn parse_derived_tables() { let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b"; @@ -730,7 +773,7 @@ fn only<'a, T>(v: &'a Vec) -> &'a T { v.first().unwrap() } -fn verified_query(query: &str) -> SQLSelect { +fn verified_query(query: &str) -> SQLQuery { match verified_stmt(query) { SQLStatement::SQLSelect(select) => select, _ => panic!("Expected SELECT"), @@ -742,7 +785,7 @@ fn expr_from_projection(item: &ASTNode) -> &ASTNode { } fn verified_only_select(query: &str) -> SQLSelect { - verified_query(query) + verified_query(query).body } fn verified_stmt(query: &str) -> SQLStatement {