Properly handle mixed implicit and explicit joins

Parse a query like

    SELECT * FROM a NATURAL JOIN b, c NATURAL JOIN d

as the SQL specification requires, i.e.:

    from: [
        TableReference {
            relation: TableFactor::Table("a"),
            joins: [Join {
                relation: TableFactor::Table("b"),
                join_operator: JoinOperator::Natural,
            }]
        },
        TableReference {
            relation: TableFactor::Table("c"),
            joins: [Join {
                relation: TableFactor::Table("d"),
                join_operator: JoinOperator::Natural,
            }]
        }
    ]

Previously we were parsing such queries as

    relation: TableFactor::Table("a"),
    joins: [
        Join {
            relation: TableFactor::Table("b"),
            join_operator: JoinOperator::Natural,
        },
        Join {
            relation: TableFactor::Table("c"),
            join_operator: JoinOperator::Implicit,
        },
        Join {
            relation: TableFactor::Table("d"),
            join_operator: JoinOperator::Natural,
        },
    ]

which did not make the join hierarchy clear.
This commit is contained in:
Nikhil Benesch 2019-06-07 23:45:26 -04:00
parent 518c8833d2
commit b841dccc2c
No known key found for this signature in database
GPG key ID: FCF98542083C5A69
6 changed files with 219 additions and 164 deletions

View file

@ -27,7 +27,7 @@ pub use self::ddl::{
};
pub use self::query::{
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor,
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins,
};
pub use self::sqltype::SQLType;
pub use self::value::{SQLDateTimeField, Value};

View file

@ -113,9 +113,7 @@ pub struct SQLSelect {
/// projection expressions
pub projection: Vec<SQLSelectItem>,
/// FROM
pub relation: Option<TableFactor>,
/// JOIN
pub joins: Vec<Join>,
pub from: Vec<TableWithJoins>,
/// WHERE
pub selection: Option<ASTNode>,
/// GROUP BY
@ -131,11 +129,8 @@ impl ToString for SQLSelect {
if self.distinct { " DISTINCT" } else { "" },
comma_separated_string(&self.projection)
);
if let Some(ref relation) = self.relation {
s += &format!(" FROM {}", relation.to_string());
}
for join in &self.joins {
s += &join.to_string();
if !self.from.is_empty() {
s += &format!(" FROM {}", comma_separated_string(&self.from));
}
if let Some(ref selection) = self.selection {
s += &format!(" WHERE {}", selection.to_string());
@ -197,6 +192,22 @@ impl ToString for SQLSelectItem {
}
}
#[derive(Debug, Clone, PartialEq, Hash)]
pub struct TableWithJoins {
pub relation: TableFactor,
pub joins: Vec<Join>,
}
impl ToString for TableWithJoins {
fn to_string(&self) -> String {
let mut s = self.relation.to_string();
for join in &self.joins {
s += &join.to_string();
}
s
}
}
/// A table name or a parenthesized subquery with an optional alias
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum TableFactor {
@ -215,10 +226,7 @@ pub enum TableFactor {
subquery: Box<SQLQuery>,
alias: Option<TableAlias>,
},
NestedJoin {
base: Box<TableFactor>,
joins: Vec<Join>,
},
NestedJoin(Box<TableWithJoins>),
}
impl ToString for TableFactor {
@ -257,12 +265,8 @@ impl ToString for TableFactor {
}
s
}
TableFactor::NestedJoin { base, joins } => {
let mut s = base.to_string();
for join in joins {
s += &join.to_string();
}
format!("({})", s)
TableFactor::NestedJoin(table_reference) => {
format!("({})", table_reference.to_string())
}
}
}
@ -313,7 +317,6 @@ impl ToString for Join {
suffix(constraint)
),
JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()),
JoinOperator::Implicit => format!(", {}", self.relation.to_string()),
JoinOperator::LeftOuter(constraint) => format!(
" {}LEFT JOIN {}{}",
prefix(constraint),
@ -342,7 +345,6 @@ pub enum JoinOperator {
LeftOuter(JoinConstraint),
RightOuter(JoinConstraint),
FullOuter(JoinConstraint),
Implicit,
Cross,
}

View file

@ -1570,13 +1570,15 @@ impl Parser {
}
let projection = self.parse_select_list()?;
let (relation, joins) = if self.parse_keyword("FROM") {
let relation = Some(self.parse_table_factor()?);
let joins = self.parse_joins()?;
(relation, joins)
} else {
(None, vec![])
};
let mut from = vec![];
if self.parse_keyword("FROM") {
loop {
from.push(self.parse_table_and_joins()?);
if !self.consume_token(&Token::Comma) {
break;
}
}
}
let selection = if self.parse_keyword("WHERE") {
Some(self.parse_expr()?)
@ -1599,95 +1601,18 @@ impl Parser {
Ok(SQLSelect {
distinct,
projection,
from,
selection,
relation,
joins,
group_by,
having,
})
}
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
let lateral = self.parse_keyword("LATERAL");
if self.consume_token(&Token::LParen) {
if self.parse_keyword("SELECT")
|| self.parse_keyword("WITH")
|| self.parse_keyword("VALUES")
{
self.prev_token();
let subquery = Box::new(self.parse_query()?);
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Derived {
lateral,
subquery,
alias,
})
} else if lateral {
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
} else {
let base = Box::new(self.parse_table_factor()?);
let joins = self.parse_joins()?;
self.expect_token(&Token::RParen)?;
Ok(TableFactor::NestedJoin { base, joins })
}
} else if lateral {
self.expected("subquery after LATERAL", self.peek_token())
} else {
let name = self.parse_object_name()?;
// Postgres, MSSQL: table-valued functions:
let args = if self.consume_token(&Token::LParen) {
self.parse_optional_args()?
} else {
vec![]
};
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
// MSSQL-specific table hints:
let mut with_hints = vec![];
if self.parse_keyword("WITH") {
if self.consume_token(&Token::LParen) {
with_hints = self.parse_expr_list()?;
self.expect_token(&Token::RParen)?;
} else {
// rewind, as WITH may belong to the next statement's CTE
self.prev_token();
}
};
Ok(TableFactor::Table {
name,
alias,
args,
with_hints,
})
}
}
fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
} else if self.parse_keyword("ON") {
let constraint = self.parse_expr()?;
Ok(JoinConstraint::On(constraint))
} else if self.parse_keyword("USING") {
let columns = self.parse_parenthesized_column_list(Mandatory)?;
Ok(JoinConstraint::Using(columns))
} else {
self.expected("ON, or USING after JOIN", self.peek_token())
}
}
fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {
let relation = self.parse_table_factor()?;
let mut joins = vec![];
loop {
let join = match &self.peek_token() {
Some(Token::Comma) => {
self.next_token();
Join {
relation: self.parse_table_factor()?,
join_operator: JoinOperator::Implicit,
}
}
Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => {
self.next_token();
self.expect_keyword("JOIN")?;
@ -1736,7 +1661,76 @@ impl Parser {
};
joins.push(join);
}
Ok(joins)
Ok(TableWithJoins { relation, joins })
}
/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
let lateral = self.parse_keyword("LATERAL");
if self.consume_token(&Token::LParen) {
if self.parse_keyword("SELECT")
|| self.parse_keyword("WITH")
|| self.parse_keyword("VALUES")
{
self.prev_token();
let subquery = Box::new(self.parse_query()?);
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Derived {
lateral,
subquery,
alias,
})
} else if lateral {
parser_err!("Expected subquery after LATERAL, found nested join".to_string())
} else {
let table_reference = self.parse_table_and_joins()?;
self.expect_token(&Token::RParen)?;
Ok(TableFactor::NestedJoin(Box::new(table_reference)))
}
} else if lateral {
self.expected("subquery after LATERAL", self.peek_token())
} else {
let name = self.parse_object_name()?;
// Postgres, MSSQL: table-valued functions:
let args = if self.consume_token(&Token::LParen) {
self.parse_optional_args()?
} else {
vec![]
};
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
// MSSQL-specific table hints:
let mut with_hints = vec![];
if self.parse_keyword("WITH") {
if self.consume_token(&Token::LParen) {
with_hints = self.parse_expr_list()?;
self.expect_token(&Token::RParen)?;
} else {
// rewind, as WITH may belong to the next statement's CTE
self.prev_token();
}
};
Ok(TableFactor::Table {
name,
alias,
args,
with_hints,
})
}
}
fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
} else if self.parse_keyword("ON") {
let constraint = self.parse_expr()?;
Ok(JoinConstraint::On(constraint))
} else if self.parse_keyword("USING") {
let columns = self.parse_parenthesized_column_list(Mandatory)?;
Ok(JoinConstraint::Using(columns))
} else {
self.expected("ON, or USING after JOIN", self.peek_token())
}
}
/// Parse an INSERT statement

View file

@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects {
}
}
pub fn only<T>(v: &[T]) -> &T {
assert_eq!(1, v.len());
v.first().unwrap()
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
let mut iter = v.into_iter();
if let (Some(item), None) = (iter.next(), iter.next()) {
item
} else {
panic!("only called on collection without exactly one item")
}
}
pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode {

View file

@ -1300,7 +1300,7 @@ fn parse_delimited_identifiers() {
r#"SELECT "alias"."bar baz", "myfun"(), "simple id" AS "column alias" FROM "a table" AS "alias""#
);
// check FROM
match select.relation.unwrap() {
match only(select.from).relation {
TableFactor::Table {
name,
alias,
@ -1430,16 +1430,69 @@ fn parse_implicit_join() {
let sql = "SELECT * FROM t1, t2";
let select = verified_only_select(sql);
assert_eq!(
&Join {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2".to_string()]),
alias: None,
args: vec![],
with_hints: vec![],
vec![
TableWithJoins {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t1".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
joins: vec![],
},
join_operator: JoinOperator::Implicit
},
only(&select.joins),
TableWithJoins {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
joins: vec![],
}
],
select.from,
);
let sql = "SELECT * FROM t1a NATURAL JOIN t1b, t2a NATURAL JOIN t2b";
let select = verified_only_select(sql);
assert_eq!(
vec![
TableWithJoins {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t1a".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
joins: vec![Join {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t1b".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
join_operator: JoinOperator::Inner(JoinConstraint::Natural),
}]
},
TableWithJoins {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2a".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
joins: vec![Join {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2b".into()]),
alias: None,
args: vec![],
with_hints: vec![],
},
join_operator: JoinOperator::Inner(JoinConstraint::Natural),
}]
}
],
select.from,
);
}
@ -1448,7 +1501,7 @@ fn parse_cross_join() {
let sql = "SELECT * FROM t1 CROSS JOIN t2";
let select = verified_only_select(sql);
assert_eq!(
&Join {
Join {
relation: TableFactor::Table {
name: SQLObjectName(vec!["t2".to_string()]),
alias: None,
@ -1457,7 +1510,7 @@ fn parse_cross_join() {
},
join_operator: JoinOperator::Cross
},
only(&select.joins),
only(only(select.from).joins),
);
}
@ -1491,7 +1544,7 @@ fn parse_joins_on() {
}
// Test parsing of aliases
assert_eq!(
verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").joins,
only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").from).joins,
vec![join_with_constraint(
"t2",
table_alias("foo"),
@ -1504,19 +1557,19 @@ fn parse_joins_on() {
);
// Test parsing of different join operators
assert_eq!(
verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").joins,
only(&verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::Inner)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").joins,
only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").joins,
only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::RightOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").joins,
only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::FullOuter)]
);
}
@ -1540,7 +1593,7 @@ fn parse_joins_using() {
}
// Test parsing of aliases
assert_eq!(
verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").joins,
only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").from).joins,
vec![join_with_constraint(
"t2",
table_alias("foo"),
@ -1553,19 +1606,19 @@ fn parse_joins_using() {
);
// Test parsing of different join operators
assert_eq!(
verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").joins,
only(&verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::Inner)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").joins,
only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").joins,
only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::RightOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").joins,
only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").from).joins,
vec![join_with_constraint("t2", None, JoinOperator::FullOuter)]
);
}
@ -1584,19 +1637,19 @@ fn parse_natural_join() {
}
}
assert_eq!(
verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").joins,
only(&verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").from).joins,
vec![natural_join(JoinOperator::Inner)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").joins,
only(&verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").from).joins,
vec![natural_join(JoinOperator::LeftOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").joins,
only(&verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").from).joins,
vec![natural_join(JoinOperator::RightOuter)]
);
assert_eq!(
verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").joins,
only(&verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").from).joins,
vec![natural_join(JoinOperator::FullOuter)]
);
@ -1633,17 +1686,17 @@ fn parse_join_nesting() {
macro_rules! nest {
($base:expr $(, $join:expr)*) => {
TableFactor::NestedJoin {
base: Box::new($base),
TableFactor::NestedJoin(Box::new(TableWithJoins {
relation: $base,
joins: vec![$(join($join)),*]
}
}))
};
}
let sql = "SELECT * FROM a NATURAL JOIN (b NATURAL JOIN (c NATURAL JOIN d NATURAL JOIN e)) \
NATURAL JOIN (f NATURAL JOIN (g NATURAL JOIN h))";
assert_eq!(
verified_only_select(sql).joins,
only(&verified_only_select(sql).from).joins,
vec![
join(nest!(table("b"), nest!(table("c"), table("d"), table("e")))),
join(nest!(table("f"), nest!(table("g"), table("h"))))
@ -1652,22 +1705,22 @@ fn parse_join_nesting() {
let sql = "SELECT * FROM (a NATURAL JOIN b) NATURAL JOIN c";
let select = verified_only_select(sql);
assert_eq!(select.relation.unwrap(), nest!(table("a"), table("b")),);
assert_eq!(select.joins, vec![join(table("c"))]);
let from = only(select.from);
assert_eq!(from.relation, nest!(table("a"), table("b")));
assert_eq!(from.joins, vec![join(table("c"))]);
let sql = "SELECT * FROM (((a NATURAL JOIN b)))";
let select = verified_only_select(sql);
assert_eq!(
select.relation.unwrap(),
nest!(nest!(nest!(table("a"), table("b"))))
);
assert_eq!(select.joins, vec![]);
let from = only(select.from);
assert_eq!(from.relation, nest!(nest!(nest!(table("a"), table("b")))));
assert_eq!(from.joins, vec![]);
let sql = "SELECT * FROM a NATURAL JOIN (((b NATURAL JOIN c)))";
let select = verified_only_select(sql);
assert_eq!(select.relation.unwrap(), table("a"));
let from = only(select.from);
assert_eq!(from.relation, table("a"));
assert_eq!(
select.joins,
from.joins,
vec![join(nest!(nest!(nest!(table("b"), table("c")))))]
);
}
@ -1729,8 +1782,8 @@ fn parse_ctes() {
// CTE in a derived table
let sql = &format!("SELECT * FROM ({})", with);
let select = verified_only_select(sql);
match select.relation {
Some(TableFactor::Derived { subquery, .. }) => {
match only(select.from).relation {
TableFactor::Derived { subquery, .. } => {
assert_ctes_in_select(&cte_sqls, subquery.as_ref())
}
_ => panic!("Expected derived table"),
@ -2072,8 +2125,8 @@ fn parse_offset() {
let ast = verified_query("SELECT foo FROM (SELECT * FROM bar OFFSET 2 ROWS) OFFSET 2 ROWS");
assert_eq!(ast.offset, Some(ASTNode::SQLValue(Value::Long(2))));
match ast.body {
SQLSetExpr::Select(s) => match s.relation {
Some(TableFactor::Derived { subquery, .. }) => {
SQLSetExpr::Select(s) => match only(s.from).relation {
TableFactor::Derived { subquery, .. } => {
assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2))));
}
_ => panic!("Test broke"),
@ -2172,8 +2225,8 @@ fn parse_fetch() {
})
);
match ast.body {
SQLSetExpr::Select(s) => match s.relation {
Some(TableFactor::Derived { subquery, .. }) => {
SQLSetExpr::Select(s) => match only(s.from).relation {
TableFactor::Derived { subquery, .. } => {
assert_eq!(
subquery.fetch,
Some(Fetch {
@ -2198,8 +2251,8 @@ fn parse_fetch() {
})
);
match ast.body {
SQLSetExpr::Select(s) => match s.relation {
Some(TableFactor::Derived { subquery, .. }) => {
SQLSetExpr::Select(s) => match only(s.from).relation {
TableFactor::Derived { subquery, .. } => {
assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2))));
assert_eq!(
subquery.fetch,
@ -2250,16 +2303,18 @@ fn lateral_derived() {
lateral_str
);
let select = verified_only_select(&sql);
assert_eq!(select.joins.len(), 1);
let from = only(select.from);
assert_eq!(from.joins.len(), 1);
let join = &from.joins[0];
assert_eq!(
select.joins[0].join_operator,
join.join_operator,
JoinOperator::LeftOuter(JoinConstraint::On(ASTNode::SQLValue(Value::Boolean(true))))
);
if let TableFactor::Derived {
lateral,
ref subquery,
alias: Some(ref alias),
} = select.joins[0].relation
} = join.relation
{
assert_eq!(lateral_in, lateral);
assert_eq!("order".to_string(), alias.name);

View file

@ -19,8 +19,8 @@ fn parse_mssql_identifiers() {
expr_from_projection(&select.projection[1]),
);
assert_eq!(2, select.projection.len());
match select.relation {
Some(TableFactor::Table { name, .. }) => {
match &only(&select.from).relation {
TableFactor::Table { name, .. } => {
assert_eq!("##temp".to_string(), name.to_string());
}
_ => unreachable!(),