Add support for Recursive CTEs (#278)

i.e. `WITH RECURSIVE ... AS ( ... ) SELECT` - see https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#with-clause

Fixes #277
This commit is contained in:
rhanqtl 2020-10-11 14:43:51 +08:00 committed by GitHub
parent 54be3912a9
commit 9f772f03b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 19 deletions

View file

@ -30,7 +30,7 @@ pub use self::ddl::{
pub use self::operator::{BinaryOperator, UnaryOperator};
pub use self::query::{
Cte, Fetch, Join, JoinConstraint, JoinOperator, Offset, OffsetRows, OrderByExpr, Query, Select,
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values,
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, With,
};
pub use self::value::{DateTimeField, Value};

View file

@ -20,8 +20,8 @@ use serde::{Deserialize, Serialize};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub ctes: Vec<Cte>,
/// SELECT or UNION / EXCEPT / INTECEPT
pub with: Option<With>,
/// SELECT or UNION / EXCEPT / INTERSECT
pub body: SetExpr,
/// ORDER BY
pub order_by: Vec<OrderByExpr>,
@ -35,8 +35,8 @@ pub struct Query {
impl fmt::Display for Query {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.ctes.is_empty() {
write!(f, "WITH {} ", display_comma_separated(&self.ctes))?;
if let Some(ref with) = self.with {
write!(f, "{} ", with)?;
}
write!(f, "{}", self.body)?;
if !self.order_by.is_empty() {
@ -157,6 +157,24 @@ impl fmt::Display for Select {
}
}
/// A `WITH` clause: an optional `RECURSIVE` marker followed by one or more
/// common table expressions (CTEs). Displayed as `WITH [RECURSIVE ]cte, cte, ...`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct With {
/// `true` when the `RECURSIVE` keyword directly follows `WITH`
pub recursive: bool,
/// The comma-separated CTEs listed in the clause
pub cte_tables: Vec<Cte>,
}
/// Renders the clause back to SQL: the `WITH` keyword, `RECURSIVE ` when the
/// flag is set, then the CTEs separated by commas.
impl fmt::Display for With {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("WITH ")?;
        // The keyword carries its own trailing space so the non-recursive
        // form does not end up with a doubled blank.
        if self.recursive {
            f.write_str("RECURSIVE ")?;
        }
        write!(f, "{}", display_comma_separated(&self.cte_tables))
    }
}
/// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )`
/// The names in the column list before `AS`, when specified, replace the names
/// of the columns returned by the query. The parser does not validate that the

View file

@ -1795,11 +1795,13 @@ impl<'a> Parser<'a> {
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
let ctes = if self.parse_keyword(Keyword::WITH) {
// TODO: optional RECURSIVE
self.parse_comma_separated(Parser::parse_cte)?
let with = if self.parse_keyword(Keyword::WITH) {
Some(With {
recursive: self.parse_keyword(Keyword::RECURSIVE),
cte_tables: self.parse_comma_separated(Parser::parse_cte)?,
})
} else {
vec![]
None
};
let body = self.parse_query_body(0)?;
@ -1829,7 +1831,7 @@ impl<'a> Parser<'a> {
};
Ok(Query {
ctes,
with,
body,
limit,
order_by,

View file

@ -382,10 +382,7 @@ impl<'a> Tokenizer<'a> {
// numbers
'0'..='9' => {
// TODO: https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#unsigned-numeric-literal
let s = peeking_take_while(chars, |ch| match ch {
'0'..='9' | '.' => true,
_ => false,
});
let s = peeking_take_while(chars, |ch| matches!(ch, '0'..='9' | '.'));
Ok(Some(Token::Number(s)))
}
// punctuation

View file

@ -2389,7 +2389,7 @@ fn parse_ctes() {
fn assert_ctes_in_select(expected: &[&str], sel: &Query) {
for (i, exp) in expected.iter().enumerate() {
let Cte { alias, query } = &sel.ctes[i];
let Cte { alias, query } = &sel.with.as_ref().unwrap().cte_tables[i];
assert_eq!(*exp, query.to_string());
assert_eq!(
if i == 0 {
@ -2432,7 +2432,7 @@ fn parse_ctes() {
// CTE in a CTE...
let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with);
let select = verified_query(sql);
assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query);
assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query);
}
#[test]
@ -2441,10 +2441,47 @@ fn parse_cte_renamed_columns() {
let query = all_dialects().verified_query(sql);
assert_eq!(
vec![Ident::new("col1"), Ident::new("col2")],
query.ctes.first().unwrap().alias.columns
query
.with
.unwrap()
.cte_tables
.first()
.unwrap()
.alias
.columns
);
}
/// Checks that `WITH RECURSIVE` is parsed into a `With` with the `recursive`
/// flag set and a single CTE carrying the expected alias, column list and query.
#[test]
fn parse_recursive_cte() {
    // Body of the CTE; parsed on its own below to build the expected `Cte::query`.
    let inner_sql = "SELECT 1 UNION ALL SELECT val + 1 FROM nums WHERE val < 10";
    let outer_sql = format!(
        "WITH RECURSIVE nums (val) AS ({}) SELECT * FROM nums",
        inner_sql
    );

    let inner = verified_query(inner_sql);
    let outer = verified_query(&outer_sql);

    let with = outer.with.as_ref().unwrap();
    assert!(with.recursive);
    assert_eq!(with.cte_tables.len(), 1);

    // Unquoted identifiers, so no quote style is expected on either name.
    let expected = Cte {
        alias: TableAlias {
            name: Ident::new("nums"),
            columns: vec![Ident::new("val")],
        },
        query: inner,
    };
    assert_eq!(with.cte_tables.first().unwrap(), &expected);
}
#[test]
fn parse_derived_tables() {
let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";
@ -3266,8 +3303,8 @@ fn parse_drop_index() {
fn all_keywords_sorted() {
// assert!(ALL_KEYWORDS.is_sorted())
let mut copy = Vec::from(ALL_KEYWORDS);
copy.sort();
assert!(copy == ALL_KEYWORDS)
copy.sort_unstable();
assert_eq!(copy, ALL_KEYWORDS)
}
fn parse_sql_statements(sql: &str) -> Result<Vec<Statement>, ParserError> {