diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 71997966..f87bf406 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -27,7 +27,7 @@ pub use self::ddl::{ }; pub use self::query::{ Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, - SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, + SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins, }; pub use self::sqltype::SQLType; pub use self::value::{SQLDateTimeField, Value}; diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index df07d48a..150aeaa7 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -113,9 +113,7 @@ pub struct SQLSelect { /// projection expressions pub projection: Vec, /// FROM - pub relation: Option, - /// JOIN - pub joins: Vec, + pub from: Vec, /// WHERE pub selection: Option, /// GROUP BY @@ -131,11 +129,8 @@ impl ToString for SQLSelect { if self.distinct { " DISTINCT" } else { "" }, comma_separated_string(&self.projection) ); - if let Some(ref relation) = self.relation { - s += &format!(" FROM {}", relation.to_string()); - } - for join in &self.joins { - s += &join.to_string(); + if !self.from.is_empty() { + s += &format!(" FROM {}", comma_separated_string(&self.from)); } if let Some(ref selection) = self.selection { s += &format!(" WHERE {}", selection.to_string()); @@ -197,6 +192,22 @@ impl ToString for SQLSelectItem { } } +#[derive(Debug, Clone, PartialEq, Hash)] +pub struct TableWithJoins { + pub relation: TableFactor, + pub joins: Vec, +} + +impl ToString for TableWithJoins { + fn to_string(&self) -> String { + let mut s = self.relation.to_string(); + for join in &self.joins { + s += &join.to_string(); + } + s + } +} + /// A table name or a parenthesized subquery with an optional alias #[derive(Debug, Clone, PartialEq, Hash)] pub enum TableFactor { @@ -215,10 +226,7 @@ pub enum TableFactor { subquery: Box, alias: Option, }, - NestedJoin { - base: Box, - joins: Vec, - }, + NestedJoin(Box), } impl ToString for TableFactor { @@ -257,12 +265,8 @@ impl ToString for TableFactor { } s } - TableFactor::NestedJoin { base, joins } => { - let mut s = base.to_string(); - for join in joins { - s += &join.to_string(); - } - format!("({})", s) + TableFactor::NestedJoin(table_reference) => { + format!("({})", table_reference.to_string()) } } } @@ -313,7 +317,6 @@ impl ToString for Join { suffix(constraint) ), JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), - JoinOperator::Implicit => format!(", {}", self.relation.to_string()), JoinOperator::LeftOuter(constraint) => format!( " {}LEFT JOIN {}{}", prefix(constraint), @@ -342,7 +345,6 @@ pub enum JoinOperator { LeftOuter(JoinConstraint), RightOuter(JoinConstraint), FullOuter(JoinConstraint), - Implicit, Cross, } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index be26265c..4694a094 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -1570,13 +1570,15 @@ impl Parser { } let projection = self.parse_select_list()?; - let (relation, joins) = if self.parse_keyword("FROM") { - let relation = Some(self.parse_table_factor()?); - let joins = self.parse_joins()?; - (relation, joins) - } else { - (None, vec![]) - }; + let mut from = vec![]; + if self.parse_keyword("FROM") { + loop { + from.push(self.parse_table_and_joins()?); + if !self.consume_token(&Token::Comma) { + break; + } + } + } let selection = if self.parse_keyword("WHERE") { Some(self.parse_expr()?) @@ -1599,95 +1601,18 @@ impl Parser { Ok(SQLSelect { distinct, projection, + from, selection, - relation, - joins, group_by, having, }) } - /// A table name or a parenthesized subquery, followed by optional `[AS] alias` - pub fn parse_table_factor(&mut self) -> Result { - let lateral = self.parse_keyword("LATERAL"); - if self.consume_token(&Token::LParen) { - if self.parse_keyword("SELECT") - || self.parse_keyword("WITH") - || self.parse_keyword("VALUES") - { - self.prev_token(); - let subquery = Box::new(self.parse_query()?); - self.expect_token(&Token::RParen)?; - let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; - Ok(TableFactor::Derived { - lateral, - subquery, - alias, - }) - } else if lateral { - parser_err!("Expected subquery after LATERAL, found nested join".to_string()) - } else { - let base = Box::new(self.parse_table_factor()?); - let joins = self.parse_joins()?; - self.expect_token(&Token::RParen)?; - Ok(TableFactor::NestedJoin { base, joins }) - } - } else if lateral { - self.expected("subquery after LATERAL", self.peek_token()) - } else { - let name = self.parse_object_name()?; - // Postgres, MSSQL: table-valued functions: - let args = if self.consume_token(&Token::LParen) { - self.parse_optional_args()? - } else { - vec![] - }; - let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; - // MSSQL-specific table hints: - let mut with_hints = vec![]; - if self.parse_keyword("WITH") { - if self.consume_token(&Token::LParen) { - with_hints = self.parse_expr_list()?; - self.expect_token(&Token::RParen)?; - } else { - // rewind, as WITH may belong to the next statement's CTE - self.prev_token(); - } - }; - Ok(TableFactor::Table { - name, - alias, - args, - with_hints, - }) - } - } - - fn parse_join_constraint(&mut self, natural: bool) -> Result { - if natural { - Ok(JoinConstraint::Natural) - } else if self.parse_keyword("ON") { - let constraint = self.parse_expr()?; - Ok(JoinConstraint::On(constraint)) - } else if self.parse_keyword("USING") { - let columns = self.parse_parenthesized_column_list(Mandatory)?; - Ok(JoinConstraint::Using(columns)) - } else { - self.expected("ON, or USING after JOIN", self.peek_token()) - } - } - - fn parse_joins(&mut self) -> Result, ParserError> { + pub fn parse_table_and_joins(&mut self) -> Result { + let relation = self.parse_table_factor()?; let mut joins = vec![]; loop { let join = match &self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - Join { - relation: self.parse_table_factor()?, - join_operator: JoinOperator::Implicit, - } - } Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { self.next_token(); self.expect_keyword("JOIN")?; @@ -1736,7 +1661,76 @@ impl Parser { }; joins.push(join); } - Ok(joins) + Ok(TableWithJoins { relation, joins }) + } + + /// A table name or a parenthesized subquery, followed by optional `[AS] alias` + pub fn parse_table_factor(&mut self) -> Result { + let lateral = self.parse_keyword("LATERAL"); + if self.consume_token(&Token::LParen) { + if self.parse_keyword("SELECT") + || self.parse_keyword("WITH") + || self.parse_keyword("VALUES") + { + self.prev_token(); + let subquery = Box::new(self.parse_query()?); + self.expect_token(&Token::RParen)?; + let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Derived { + lateral, + subquery, + alias, + }) + } else if lateral { + parser_err!("Expected subquery after LATERAL, found nested join".to_string()) + } else { + let table_reference = self.parse_table_and_joins()?; + self.expect_token(&Token::RParen)?; + Ok(TableFactor::NestedJoin(Box::new(table_reference))) + } + } else if lateral { + self.expected("subquery after LATERAL", self.peek_token()) + } else { + let name = self.parse_object_name()?; + // Postgres, MSSQL: table-valued functions: + let args = if self.consume_token(&Token::LParen) { + self.parse_optional_args()? + } else { + vec![] + }; + let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + // MSSQL-specific table hints: + let mut with_hints = vec![]; + if self.parse_keyword("WITH") { + if self.consume_token(&Token::LParen) { + with_hints = self.parse_expr_list()?; + self.expect_token(&Token::RParen)?; + } else { + // rewind, as WITH may belong to the next statement's CTE + self.prev_token(); + } + }; + Ok(TableFactor::Table { + name, + alias, + args, + with_hints, + }) + } + } + + fn parse_join_constraint(&mut self, natural: bool) -> Result { + if natural { + Ok(JoinConstraint::Natural) + } else if self.parse_keyword("ON") { + let constraint = self.parse_expr()?; + Ok(JoinConstraint::On(constraint)) + } else if self.parse_keyword("USING") { + let columns = self.parse_parenthesized_column_list(Mandatory)?; + Ok(JoinConstraint::Using(columns)) + } else { + self.expected("ON, or USING after JOIN", self.peek_token()) + } } /// Parse an INSERT statement diff --git a/src/test_utils.rs b/src/test_utils.rs index 16216bfe..7a1ce5e2 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -109,9 +109,13 @@ pub fn all_dialects() -> TestedDialects { } } -pub fn only(v: &[T]) -> &T { - assert_eq!(1, v.len()); - v.first().unwrap() +pub fn only(v: impl IntoIterator) -> T { + let mut iter = v.into_iter(); + if let (Some(item), None) = (iter.next(), iter.next()) { + item + } else { + panic!("only called on collection without exactly one item") + } } pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index f1b9ed07..4318f4dc 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1300,7 +1300,7 @@ fn parse_delimited_identifiers() { r#"SELECT "alias"."bar baz", "myfun"(), "simple id" AS "column alias" FROM "a table" AS "alias""# ); // check FROM - match select.relation.unwrap() { + match only(select.from).relation { TableFactor::Table { name, alias, @@ -1430,16 +1430,69 @@ fn parse_implicit_join() { let sql = "SELECT * FROM t1, t2"; let select = verified_only_select(sql); assert_eq!( - &Join { - relation: TableFactor::Table { - name: SQLObjectName(vec!["t2".to_string()]), - alias: None, - args: vec![], - with_hints: vec![], + vec![ + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], }, - join_operator: JoinOperator::Implicit - }, - only(&select.joins), + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], + } + ], + select.from, + ); + + let sql = "SELECT * FROM t1a NATURAL JOIN t1b, t2a NATURAL JOIN t2b"; + let select = verified_only_select(sql); + assert_eq!( + vec![ + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1a".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t1b".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + }] + }, + TableWithJoins { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2a".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2b".into()]), + alias: None, + args: vec![], + with_hints: vec![], + }, + join_operator: JoinOperator::Inner(JoinConstraint::Natural), + }] + } + ], + select.from, ); } @@ -1448,7 +1501,7 @@ fn parse_cross_join() { let sql = "SELECT * FROM t1 CROSS JOIN t2"; let select = verified_only_select(sql); assert_eq!( - &Join { + Join { relation: TableFactor::Table { name: SQLObjectName(vec!["t2".to_string()]), alias: None, @@ -1457,7 +1510,7 @@ fn parse_cross_join() { }, join_operator: JoinOperator::Cross }, - only(&select.joins), + only(only(select.from).joins), ); } @@ -1491,7 +1544,7 @@ fn parse_joins_on() { } // Test parsing of aliases assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").from).joins, vec![join_with_constraint( "t2", table_alias("foo"), @@ -1504,19 +1557,19 @@ fn parse_joins_on() { ); // Test parsing of different join operators assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").joins, + only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").from).joins, vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } @@ -1540,7 +1593,7 @@ fn parse_joins_using() { } // Test parsing of aliases assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").from).joins, vec![join_with_constraint( "t2", table_alias("foo"), @@ -1553,19 +1606,19 @@ fn parse_joins_using() { ); // Test parsing of different join operators assert_eq!( - verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").joins, + only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").from).joins, vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } @@ -1584,19 +1637,19 @@ fn parse_natural_join() { } } assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").from).joins, vec![natural_join(JoinOperator::Inner)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").from).joins, vec![natural_join(JoinOperator::LeftOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").from).joins, vec![natural_join(JoinOperator::RightOuter)] ); assert_eq!( - verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").joins, + only(&verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").from).joins, vec![natural_join(JoinOperator::FullOuter)] ); @@ -1633,17 +1686,17 @@ fn parse_join_nesting() { macro_rules! nest { ($base:expr $(, $join:expr)*) => { - TableFactor::NestedJoin { - base: Box::new($base), + TableFactor::NestedJoin(Box::new(TableWithJoins { + relation: $base, joins: vec![$(join($join)),*] - } + })) }; } let sql = "SELECT * FROM a NATURAL JOIN (b NATURAL JOIN (c NATURAL JOIN d NATURAL JOIN e)) \ NATURAL JOIN (f NATURAL JOIN (g NATURAL JOIN h))"; assert_eq!( - verified_only_select(sql).joins, + only(&verified_only_select(sql).from).joins, vec![ join(nest!(table("b"), nest!(table("c"), table("d"), table("e")))), join(nest!(table("f"), nest!(table("g"), table("h")))) @@ -1652,22 +1705,22 @@ fn parse_join_nesting() { let sql = "SELECT * FROM (a NATURAL JOIN b) NATURAL JOIN c"; let select = verified_only_select(sql); - assert_eq!(select.relation.unwrap(), nest!(table("a"), table("b")),); - assert_eq!(select.joins, vec![join(table("c"))]); + let from = only(select.from); + assert_eq!(from.relation, nest!(table("a"), table("b"))); + assert_eq!(from.joins, vec![join(table("c"))]); let sql = "SELECT * FROM (((a NATURAL JOIN b)))"; let select = verified_only_select(sql); - assert_eq!( - select.relation.unwrap(), - nest!(nest!(nest!(table("a"), table("b")))) - ); - assert_eq!(select.joins, vec![]); + let from = only(select.from); + assert_eq!(from.relation, nest!(nest!(nest!(table("a"), table("b"))))); + assert_eq!(from.joins, vec![]); let sql = "SELECT * FROM a NATURAL JOIN (((b NATURAL JOIN c)))"; let select = verified_only_select(sql); - assert_eq!(select.relation.unwrap(), table("a")); + let from = only(select.from); + assert_eq!(from.relation, table("a")); assert_eq!( - select.joins, + from.joins, vec![join(nest!(nest!(nest!(table("b"), table("c")))))] ); } @@ -1729,8 +1782,8 @@ fn parse_ctes() { // CTE in a derived table let sql = &format!("SELECT * FROM ({})", with); let select = verified_only_select(sql); - match select.relation { - Some(TableFactor::Derived { subquery, .. }) => { + match only(select.from).relation { + TableFactor::Derived { subquery, .. } => { assert_ctes_in_select(&cte_sqls, subquery.as_ref()) } _ => panic!("Expected derived table"), @@ -2072,8 +2125,8 @@ fn parse_offset() { let ast = verified_query("SELECT foo FROM (SELECT * FROM bar OFFSET 2 ROWS) OFFSET 2 ROWS"); assert_eq!(ast.offset, Some(ASTNode::SQLValue(Value::Long(2)))); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2)))); } _ => panic!("Test broke"), @@ -2172,8 +2225,8 @@ fn parse_fetch() { }) ); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!( subquery.fetch, Some(Fetch { @@ -2198,8 +2251,8 @@ fn parse_fetch() { }) ); match ast.body { - SQLSetExpr::Select(s) => match s.relation { - Some(TableFactor::Derived { subquery, .. }) => { + SQLSetExpr::Select(s) => match only(s.from).relation { + TableFactor::Derived { subquery, .. } => { assert_eq!(subquery.offset, Some(ASTNode::SQLValue(Value::Long(2)))); assert_eq!( subquery.fetch, @@ -2250,16 +2303,18 @@ fn lateral_derived() { lateral_str ); let select = verified_only_select(&sql); - assert_eq!(select.joins.len(), 1); + let from = only(select.from); + assert_eq!(from.joins.len(), 1); + let join = &from.joins[0]; assert_eq!( - select.joins[0].join_operator, + join.join_operator, JoinOperator::LeftOuter(JoinConstraint::On(ASTNode::SQLValue(Value::Boolean(true)))) ); if let TableFactor::Derived { lateral, ref subquery, alias: Some(ref alias), - } = select.joins[0].relation + } = join.relation { assert_eq!(lateral_in, lateral); assert_eq!("order".to_string(), alias.name); diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index b49cfd78..d0b1f7ec 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -19,8 +19,8 @@ fn parse_mssql_identifiers() { expr_from_projection(&select.projection[1]), ); assert_eq!(2, select.projection.len()); - match select.relation { - Some(TableFactor::Table { name, .. }) => { + match &only(&select.from).relation { + TableFactor::Table { name, .. } => { assert_eq!("##temp".to_string(), name.to_string()); } _ => unreachable!(),