Support basic CTEs (WITH)

Some unsupported features are noted as TODOs.
This commit is contained in:
Nickolay Ponomarev 2019-02-03 06:40:17 +03:00
parent f958e9d3cf
commit bf0c07bb1b
5 changed files with 217 additions and 91 deletions

View file

@ -20,7 +20,9 @@ mod sqltype;
mod table_key; mod table_key;
mod value; mod value;
pub use self::query::{Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLSelect, TableFactor}; pub use self::query::{
Cte, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, TableFactor,
};
pub use self::sqltype::SQLType; pub use self::sqltype::SQLType;
pub use self::table_key::{AlterOperation, Key, TableKey}; pub use self::table_key::{AlterOperation, Key, TableKey};
pub use self::value::Value; pub use self::value::Value;
@ -76,7 +78,7 @@ pub enum ASTNode {
}, },
/// A parenthesized subquery `(SELECT ...)`, used in expression like /// A parenthesized subquery `(SELECT ...)`, used in expression like
/// `SELECT (subquery) AS x` or `WHERE (subquery) = x` /// `SELECT (subquery) AS x` or `WHERE (subquery) = x`
SQLSubquery(Box<SQLSelect>), SQLSubquery(Box<SQLQuery>),
} }
impl ToString for ASTNode { impl ToString for ASTNode {
@ -139,7 +141,7 @@ impl ToString for ASTNode {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum SQLStatement { pub enum SQLStatement {
/// SELECT /// SELECT
SQLSelect(SQLSelect), SQLSelect(SQLQuery),
/// INSERT /// INSERT
SQLInsert { SQLInsert {
/// TABLE /// TABLE
@ -177,7 +179,7 @@ pub enum SQLStatement {
SQLCreateView { SQLCreateView {
/// View name /// View name
name: SQLObjectName, name: SQLObjectName,
query: SQLSelect, query: SQLQuery,
}, },
/// CREATE TABLE /// CREATE TABLE
SQLCreateTable { SQLCreateTable {

View file

@ -1,5 +1,53 @@
use super::*; use super::*;
/// The most complete variant of a `SELECT` query expression, optionally
/// including `WITH`, `UNION` / other set operations, and `ORDER BY`.
#[derive(Debug, Clone, PartialEq)]
pub struct SQLQuery {
/// WITH (common table expressions, or CTEs)
pub ctes: Vec<Cte>,
/// SELECT or UNION / EXCEPT / INTECEPT
pub body: SQLSelect,
/// ORDER BY
pub order_by: Option<Vec<SQLOrderByExpr>>,
/// LIMIT
pub limit: Option<ASTNode>,
}
impl ToString for SQLQuery {
fn to_string(&self) -> String {
let mut s = String::new();
if !self.ctes.is_empty() {
s += &format!(
"WITH {} ",
self.ctes
.iter()
.map(|cte| format!("{} AS ({})", cte.alias, cte.query.to_string()))
.collect::<Vec<String>>()
.join(", ")
)
}
s += &self.body.to_string();
if let Some(ref order_by) = self.order_by {
s += &format!(
" ORDER BY {}",
order_by
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", ")
);
}
if let Some(ref limit) = self.limit {
s += &format!(" LIMIT {}", limit.to_string());
}
s
}
}
/// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may
/// appear either as the only body item of an `SQLQuery`, or as an operand
/// to a set operation like `UNION`.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct SQLSelect { pub struct SQLSelect {
/// projection expressions /// projection expressions
@ -10,14 +58,10 @@ pub struct SQLSelect {
pub joins: Vec<Join>, pub joins: Vec<Join>,
/// WHERE /// WHERE
pub selection: Option<ASTNode>, pub selection: Option<ASTNode>,
/// ORDER BY
pub order_by: Option<Vec<SQLOrderByExpr>>,
/// GROUP BY /// GROUP BY
pub group_by: Option<Vec<ASTNode>>, pub group_by: Option<Vec<ASTNode>>,
/// HAVING /// HAVING
pub having: Option<ASTNode>, pub having: Option<ASTNode>,
/// LIMIT
pub limit: Option<ASTNode>,
} }
impl ToString for SQLSelect { impl ToString for SQLSelect {
@ -52,23 +96,17 @@ impl ToString for SQLSelect {
if let Some(ref having) = self.having { if let Some(ref having) = self.having {
s += &format!(" HAVING {}", having.to_string()); s += &format!(" HAVING {}", having.to_string());
} }
if let Some(ref order_by) = self.order_by {
s += &format!(
" ORDER BY {}",
order_by
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", ")
);
}
if let Some(ref limit) = self.limit {
s += &format!(" LIMIT {}", limit.to_string());
}
s s
} }
} }
/// A single CTE (used after `WITH`): `alias AS ( query )`
#[derive(Debug, Clone, PartialEq)]
pub struct Cte {
pub alias: SQLIdent,
pub query: SQLQuery,
}
/// A table name or a parenthesized subquery with an optional alias /// A table name or a parenthesized subquery with an optional alias
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum TableFactor { pub enum TableFactor {
@ -77,7 +115,7 @@ pub enum TableFactor {
alias: Option<SQLIdent>, alias: Option<SQLIdent>,
}, },
Derived { Derived {
subquery: Box<SQLSelect>, subquery: Box<SQLQuery>,
alias: Option<SQLIdent>, alias: Option<SQLIdent>,
}, },
} }

View file

@ -88,7 +88,10 @@ impl Parser {
match self.next_token() { match self.next_token() {
Some(t) => match t { Some(t) => match t {
Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() { Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() {
"SELECT" => Ok(SQLStatement::SQLSelect(self.parse_select()?)), "SELECT" | "WITH" => {
self.prev_token();
Ok(SQLStatement::SQLSelect(self.parse_query()?))
}
"CREATE" => Ok(self.parse_create()?), "CREATE" => Ok(self.parse_create()?),
"DELETE" => Ok(self.parse_delete()?), "DELETE" => Ok(self.parse_delete()?),
"INSERT" => Ok(self.parse_insert()?), "INSERT" => Ok(self.parse_insert()?),
@ -198,8 +201,9 @@ impl Parser {
self.parse_sql_value() self.parse_sql_value()
} }
Token::LParen => { Token::LParen => {
let expr = if self.parse_keyword("SELECT") { let expr = if self.parse_keyword("SELECT") || self.parse_keyword("WITH") {
ASTNode::SQLSubquery(Box::new(self.parse_select()?)) self.prev_token();
ASTNode::SQLSubquery(Box::new(self.parse_query()?))
} else { } else {
ASTNode::SQLNested(Box::new(self.parse_expr()?)) ASTNode::SQLNested(Box::new(self.parse_expr()?))
}; };
@ -568,8 +572,7 @@ impl Parser {
// Some dialects allow WITH here, followed by some keywords (e.g. MS SQL) // Some dialects allow WITH here, followed by some keywords (e.g. MS SQL)
// or `(k1=v1, k2=v2, ...)` (Postgres) // or `(k1=v1, k2=v2, ...)` (Postgres)
self.expect_keyword("AS")?; self.expect_keyword("AS")?;
self.expect_keyword("SELECT")?; let query = self.parse_query()?;
let query = self.parse_select()?;
// Optional `WITH [ CASCADED | LOCAL ] CHECK OPTION` is widely supported here. // Optional `WITH [ CASCADED | LOCAL ] CHECK OPTION` is widely supported here.
Ok(SQLStatement::SQLCreateView { name, query }) Ok(SQLStatement::SQLCreateView { name, query })
} }
@ -673,18 +676,9 @@ impl Parser {
let table_name = self.parse_object_name()?; let table_name = self.parse_object_name()?;
let operation: Result<AlterOperation, ParserError> = let operation: Result<AlterOperation, ParserError> =
if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) { if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) {
match self.next_token() { let constraint_name = self.parse_identifier()?;
Some(Token::SQLWord(ref id)) => { let table_key = self.parse_table_key(constraint_name)?;
let table_key = self.parse_table_key(id.as_sql_ident())?;
Ok(AlterOperation::AddConstraint(table_key)) Ok(AlterOperation::AddConstraint(table_key))
}
_ => {
return parser_err!(format!(
"Expecting identifier, found : {:?}",
self.peek_token()
));
}
}
} else { } else {
return parser_err!(format!( return parser_err!(format!(
"Expecting ADD CONSTRAINT, found :{:?}", "Expecting ADD CONSTRAINT, found :{:?}",
@ -1079,6 +1073,14 @@ impl Parser {
Ok(SQLObjectName(self.parse_list_of_ids(&Token::Period)?)) Ok(SQLObjectName(self.parse_list_of_ids(&Token::Period)?))
} }
/// Parse a simple one-word identifier (possibly quoted, possibly a keyword)
pub fn parse_identifier(&mut self) -> Result<SQLIdent, ParserError> {
match self.next_token() {
Some(Token::SQLWord(w)) => Ok(w.as_sql_ident()),
unexpected => parser_err!(format!("Expected identifier, found {:?}", unexpected)),
}
}
/// Parse a comma-separated list of unqualified, possibly quoted identifiers /// Parse a comma-separated list of unqualified, possibly quoted identifiers
pub fn parse_column_names(&mut self) -> Result<Vec<SQLIdent>, ParserError> { pub fn parse_column_names(&mut self) -> Result<Vec<SQLIdent>, ParserError> {
Ok(self.parse_list_of_ids(&Token::Comma)?) Ok(self.parse_list_of_ids(&Token::Comma)?)
@ -1132,7 +1134,64 @@ impl Parser {
}) })
} }
/// Parse a SELECT statement /// Parse a query expression, i.e. a `SELECT` statement optionally
/// preceeded with some `WITH` CTE declarations and optionally followed
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<SQLQuery, ParserError> {
let ctes = if self.parse_keyword("WITH") {
// TODO: optional RECURSIVE
self.parse_cte_list()?
} else {
vec![]
};
self.expect_keyword("SELECT")?;
let body = self.parse_select()?;
let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) {
Some(self.parse_order_by_expr_list()?)
} else {
None
};
let limit = if self.parse_keyword("LIMIT") {
self.parse_limit()?
} else {
None
};
Ok(SQLQuery {
ctes,
body,
limit,
order_by,
})
}
/// Parse one or more (comma-separated) `alias AS (subquery)` CTEs,
/// assuming the initial `WITH` was already consumed.
fn parse_cte_list(&mut self) -> Result<Vec<Cte>, ParserError> {
let mut cte = vec![];
loop {
let alias = self.parse_identifier()?;
// TODO: Optional `( <column list> )`
self.expect_keyword("AS")?;
self.expect_token(&Token::LParen)?;
cte.push(Cte {
alias,
query: self.parse_query()?,
});
self.expect_token(&Token::RParen)?;
if !self.consume_token(&Token::Comma) {
break;
}
}
return Ok(cte);
}
/// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`),
/// assuming the initial `SELECT` was already consumed
pub fn parse_select(&mut self) -> Result<SQLSelect, ParserError> { pub fn parse_select(&mut self) -> Result<SQLSelect, ParserError> {
let projection = self.parse_expr_list()?; let projection = self.parse_expr_list()?;
@ -1145,8 +1204,7 @@ impl Parser {
}; };
let selection = if self.parse_keyword("WHERE") { let selection = if self.parse_keyword("WHERE") {
let expr = self.parse_expr()?; Some(self.parse_expr()?)
Some(expr)
} else { } else {
None None
}; };
@ -1163,25 +1221,11 @@ impl Parser {
None None
}; };
let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) {
Some(self.parse_order_by_expr_list()?)
} else {
None
};
let limit = if self.parse_keyword("LIMIT") {
self.parse_limit()?
} else {
None
};
Ok(SQLSelect { Ok(SQLSelect {
projection, projection,
selection, selection,
relation, relation,
joins, joins,
limit,
order_by,
group_by, group_by,
having, having,
}) })
@ -1190,18 +1234,14 @@ impl Parser {
/// A table name or a parenthesized subquery, followed by optional `[AS] alias` /// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> { pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
if self.consume_token(&Token::LParen) { if self.consume_token(&Token::LParen) {
self.expect_keyword("SELECT")?; let subquery = Box::new(self.parse_query()?);
let subquery = self.parse_select()?;
self.expect_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(TableFactor::Derived { let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
subquery: Box::new(subquery), Ok(TableFactor::Derived { subquery, alias })
alias: self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?,
})
} else { } else {
Ok(TableFactor::Table { let name = self.parse_object_name()?;
name: self.parse_object_name()?, let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
alias: self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?, Ok(TableFactor::Table { name, alias })
})
} }
} }

View file

@ -11,7 +11,10 @@ fn parse_simple_select() {
let ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap(); let ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap();
assert_eq!(1, ast.len()); assert_eq!(1, ast.len());
match ast.first().unwrap() { match ast.first().unwrap() {
SQLStatement::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLQuery {
body: SQLSelect { projection, .. },
..
}) => {
assert_eq!(3, projection.len()); assert_eq!(3, projection.len());
} }
_ => assert!(false), _ => assert!(false),

View file

@ -57,33 +57,23 @@ fn parse_simple_select() {
#[test] #[test]
fn parse_select_wildcard() { fn parse_select_wildcard() {
let sql = String::from("SELECT * FROM customer"); let sql = "SELECT * FROM foo";
match verified_stmt(&sql) { let select = verified_only_select(sql);
SQLStatement::SQLSelect(SQLSelect { projection, .. }) => { assert_eq!(&ASTNode::SQLWildcard, only(&select.projection));
assert_eq!(1, projection.len());
assert_eq!(ASTNode::SQLWildcard, projection[0]);
}
_ => assert!(false),
}
} }
#[test] #[test]
fn parse_select_count_wildcard() { fn parse_select_count_wildcard() {
let sql = String::from("SELECT COUNT(*) FROM customer"); let sql = "SELECT COUNT(*) FROM customer";
match verified_stmt(&sql) { let select = verified_only_select(sql);
SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(1, projection.len());
assert_eq!( assert_eq!(
ASTNode::SQLFunction { &ASTNode::SQLFunction {
id: "COUNT".to_string(), id: "COUNT".to_string(),
args: vec![ASTNode::SQLWildcard], args: vec![ASTNode::SQLWildcard],
}, },
projection[0] expr_from_projection(only(&select.projection))
); );
} }
_ => assert!(false),
}
}
#[test] #[test]
fn parse_not() { fn parse_not() {
@ -652,6 +642,59 @@ fn parse_join_syntax_variants() {
); );
} }
#[test]
fn parse_ctes() {
// To be valid SQL this needs aliases for the derived columns, but
// we don't support them yet in the context of a SELECT's projection.
let cte_sqls = vec!["SELECT 1", "SELECT 2"];
let with = &format!(
"WITH a AS ({}), b AS ({}) SELECT foo + bar FROM a, b",
cte_sqls[0], cte_sqls[1]
);
fn assert_ctes_in_select(expected: &Vec<&str>, sel: &SQLQuery) {
for i in 0..1 {
let Cte {
ref query,
ref alias,
} = sel.ctes[i];
assert_eq!(expected[i], query.to_string());
assert_eq!(if i == 0 { "a" } else { "b" }, alias);
}
}
// Top-level CTE
assert_ctes_in_select(&cte_sqls, &verified_query(with));
// CTE in a subquery
let sql = &format!("SELECT ({})", with);
let select = verified_only_select(sql);
match expr_from_projection(only(&select.projection)) {
&ASTNode::SQLSubquery(ref subquery) => {
assert_ctes_in_select(&cte_sqls, subquery.as_ref());
}
_ => panic!("Expected subquery"),
}
// CTE in a derived table
let sql = &format!("SELECT * FROM ({})", with);
let select = verified_only_select(sql);
match select.relation {
Some(TableFactor::Derived { subquery, .. }) => {
assert_ctes_in_select(&cte_sqls, subquery.as_ref())
}
_ => panic!("Expected derived table"),
}
// CTE in a view
let sql = &format!("CREATE VIEW v AS {}", with);
match verified_stmt(sql) {
SQLStatement::SQLCreateView { query, .. } => assert_ctes_in_select(&cte_sqls, &query),
_ => panic!("Expected CREATE VIEW"),
}
// CTE in a CTE...
let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with);
let select = verified_query(sql);
assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query);
}
#[test] #[test]
fn parse_derived_tables() { fn parse_derived_tables() {
let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b"; let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";
@ -730,7 +773,7 @@ fn only<'a, T>(v: &'a Vec<T>) -> &'a T {
v.first().unwrap() v.first().unwrap()
} }
fn verified_query(query: &str) -> SQLSelect { fn verified_query(query: &str) -> SQLQuery {
match verified_stmt(query) { match verified_stmt(query) {
SQLStatement::SQLSelect(select) => select, SQLStatement::SQLSelect(select) => select,
_ => panic!("Expected SELECT"), _ => panic!("Expected SELECT"),
@ -742,7 +785,7 @@ fn expr_from_projection(item: &ASTNode) -> &ASTNode {
} }
fn verified_only_select(query: &str) -> SQLSelect { fn verified_only_select(query: &str) -> SQLSelect {
verified_query(query) verified_query(query).body
} }
fn verified_stmt(query: &str) -> SQLStatement { fn verified_stmt(query: &str) -> SQLStatement {