From 7624095738bd3b46397d9d1221ba75dbc73a8c97 Mon Sep 17 00:00:00 2001 From: Fredrik Roos Date: Sat, 17 Nov 2018 15:40:24 +0100 Subject: [PATCH 1/2] Support for joins --- src/dialect/generic_sql.rs | 3 +- src/sqlast/mod.rs | 87 +++++++++++++++- src/sqlparser.rs | 200 ++++++++++++++++++++++++++++++------- tests/sqlparser_generic.rs | 183 ++++++++++++++++++++++++++++----- 4 files changed, 409 insertions(+), 64 deletions(-) diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index 0d358277..5a5d73cd 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -11,7 +11,8 @@ impl Dialect for GenericSqlDialect { STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, - BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, + BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, + CROSS, OUTER, INNER, NATURAL, ON, USING, ]; } diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 6b28f245..a2848e7d 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -75,6 +75,8 @@ pub enum ASTNode { projection: Vec, /// FROM relation: Option>, + // JOIN + joins: Vec, /// WHERE selection: Option>, /// ORDER BY @@ -167,10 +169,16 @@ impl ToString for ASTNode { .collect::>() .join(", ") ), - ASTNode::SQLCase { conditions, results, else_result } => { + ASTNode::SQLCase { + conditions, + results, + else_result, + } => { let mut s = format!( "CASE {}", - conditions.iter().zip(results) + conditions + .iter() + .zip(results) .map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string())) .collect::>() .join(" ") @@ -179,10 +187,11 @@ impl ToString for ASTNode { s += &format!(" ELSE {}", else_result.to_string()) } s + " END" - }, + } ASTNode::SQLSelect { projection, relation, + joins, selection, order_by, group_by, @@ -200,6 +209,9 @@ impl ToString for ASTNode { if let Some(relation) = relation { s += &format!(" FROM {}", relation.as_ref().to_string()); } + for join in joins { + s += &join.to_string(); + } if let Some(selection) = selection { s += &format!(" WHERE {}", selection.as_ref().to_string()); } @@ -402,3 +414,72 @@ impl ToString for SQLColumnDef { s } } + +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub relation: ASTNode, + pub join_operator: JoinOperator, +} + +impl ToString for Join { + fn to_string(&self) -> String { + fn prefix(constraint: &JoinConstraint) -> String { + match constraint { + JoinConstraint::Natural => "NATURAL ".to_string(), + _ => "".to_string(), + } + } + fn suffix(constraint: &JoinConstraint) -> String { + match constraint { + JoinConstraint::On(expr) => format!(" ON({})", expr.to_string()), + JoinConstraint::Using(attrs) => format!(" USING({})", attrs.join(", ")), + _ => "".to_string(), + } + } + match &self.join_operator { + JoinOperator::Inner(constraint) => format!( + "{}INNER JOIN {}{}", + prefix(constraint), + self.relation.to_string(), + prefix(constraint) + ), + JoinOperator::Cross => format!("CROSS JOIN {}", self.relation.to_string()), + JoinOperator::Implicit => format!(", {}", self.relation.to_string()), + JoinOperator::LeftOuter(constraint) => format!( + "{}LEFT OUTER JOIN {}{}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::RightOuter(constraint) => format!( + "{}RIGHT OUTER JOIN {}{}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::FullOuter(constraint) => format!( + "{}FULL OUTER JOIN {}{}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinOperator { + Inner(JoinConstraint), + LeftOuter(JoinConstraint), + RightOuter(JoinConstraint), + FullOuter(JoinConstraint), + Implicit, + Cross, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinConstraint { + On(ASTNode), + Using(Vec), + Natural, +} diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 33d4487d..71486789 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -17,10 +17,7 @@ use super::dialect::Dialect; use super::sqlast::*; use super::sqltokenizer::*; -use chrono::{ - offset::{FixedOffset}, - DateTime, NaiveDate, NaiveDateTime, NaiveTime, -}; +use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[derive(Debug, Clone)] pub enum ParserError { @@ -109,10 +106,8 @@ impl Parser { "NULL" => { self.prev_token(); self.parse_sql_value() - }, - "CASE" => { - self.parse_case_expression() } + "CASE" => self.parse_case_expression(), _ => return parser_err!(format!("No prefix parser for keyword {}", k)), }, Token::Mult => Ok(ASTNode::SQLWildcard), @@ -156,14 +151,14 @@ impl Parser { Token::DoubleQuotedString(_) => { self.prev_token(); self.parse_sql_value() - }, + } Token::LParen => { let expr = self.parse(); if !self.consume_token(&Token::RParen)? { return parser_err!(format!("expected token RParen")); } expr - }, + } _ => parser_err!(format!( "Prefix parser expected a keyword but found {:?}", t @@ -211,20 +206,20 @@ impl Parser { if self.parse_keywords(vec!["ELSE"]) { else_result = Some(Box::new(self.parse_expr(0)?)); if self.parse_keywords(vec!["END"]) { - break + break; } else { return parser_err!("Expecting END after a CASE..ELSE"); } } if self.parse_keywords(vec!["END"]) { - break + break; } self.consume_token(&Token::Keyword("WHEN".to_string()))?; } Ok(ASTNode::SQLCase { conditions, results, - else_result + else_result, }) } else { // TODO: implement "simple" case @@ -489,6 +484,18 @@ impl Parser { true } + pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> { + if self.parse_keyword(expected) { + Ok(()) + } else { + parser_err!(format!( + "Expected keyword {}, found {:?}", + expected, + self.peek_token() + )) + } + } + //TODO: this function is inconsistent and sometimes returns bool and sometimes fails /// Consume the next token if it matches the expected token, otherwise return an error @@ -751,12 +758,12 @@ impl Parser { }, Token::Number(ref n) => match n.parse::() { Ok(n) => { -// if let Some(Token::Minus) = self.peek_token() { -// self.prev_token(); -// self.parse_timestamp_value() -// } else { - Ok(Value::Long(n)) -// } + // if let Some(Token::Minus) = self.peek_token() { + // self.prev_token(); + // self.parse_timestamp_value() + // } else { + Ok(Value::Long(n)) + // } } Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), }, @@ -1108,11 +1115,12 @@ impl Parser { pub fn parse_select(&mut self) -> Result { let projection = self.parse_expr_list()?; - let relation: Option> = if self.parse_keyword("FROM") { - //TODO: add support for JOIN - Some(Box::new(self.parse_expr(0)?)) + let (relation, joins): (Option>, Vec) = if self.parse_keyword("FROM") { + let relation = Some(Box::new(self.parse_expr(0)?)); + let joins = self.parse_joins()?; + (relation, joins) } else { - None + (None, vec![]) }; let selection = if self.parse_keyword("WHERE") { @@ -1158,6 +1166,7 @@ impl Parser { projection, selection, relation, + joins, limit, order_by, group_by, @@ -1166,6 +1175,131 @@ impl Parser { } } + fn parse_join_constraint(&mut self, natural: bool) -> Result { + if natural { + Ok(JoinConstraint::Natural) + } else if self.parse_keyword("ON") { + let constraint = self.parse_expr(0)?; + Ok(JoinConstraint::On(constraint)) + } else if self.parse_keyword("USING") { + if self.consume_token(&Token::LParen)? { + let attributes = self + .parse_expr_list()? + .into_iter() + .map(|ast_node| match ast_node { + ASTNode::SQLIdentifier(ident) => Ok(ident), + unexpected => { + parser_err!(format!("Expected identifier, found {:?}", unexpected)) + } + }) + .collect::, ParserError>>()?; + + if self.consume_token(&Token::RParen)? { + Ok(JoinConstraint::Using(attributes)) + } else { + parser_err!(format!("Expected token ')', found {:?}", self.peek_token())) + } + } else { + parser_err!(format!("Expected token '(', found {:?}", self.peek_token())) + } + } else { + parser_err!(format!( + "Unexpected token after JOIN: {:?}", + self.peek_token() + )) + } + } + + fn parse_joins(&mut self) -> Result, ParserError> { + let mut joins = vec![]; + loop { + let natural = match &self.peek_token() { + Some(Token::Comma) => { + self.next_token(); + let relation = self.parse_expr(0)?; + let join = Join { + relation, + join_operator: JoinOperator::Implicit, + }; + joins.push(join); + continue; + } + Some(Token::Keyword(kw)) if kw == "CROSS" => { + self.next_token(); + self.expect_keyword("JOIN")?; + let relation = self.parse_expr(0)?; + let join = Join { + relation, + join_operator: JoinOperator::Cross, + }; + joins.push(join); + continue; + } + Some(Token::Keyword(kw)) if kw == "NATURAL" => { + self.next_token(); + true + } + Some(_) => false, + None => return Ok(joins), + }; + + let join = match &self.peek_token() { + Some(Token::Keyword(kw)) if kw == "INNER" => { + self.next_token(); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), + } + } + Some(Token::Keyword(kw)) if kw == "JOIN" => { + self.next_token(); + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), + } + } + Some(Token::Keyword(kw)) if kw == "LEFT" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::LeftOuter( + self.parse_join_constraint(natural)?, + ), + } + } + Some(Token::Keyword(kw)) if kw == "RIGHT" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::RightOuter( + self.parse_join_constraint(natural)?, + ), + } + } + Some(Token::Keyword(kw)) if kw == "FULL" => { + self.next_token(); + self.parse_keyword("OUTER"); + self.expect_keyword("JOIN")?; + Join { + relation: self.parse_expr(0)?, + join_operator: JoinOperator::FullOuter( + self.parse_join_constraint(natural)?, + ), + } + } + _ => break, + }; + joins.push(join); + } + + Ok(joins) + } + /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { self.parse_keyword("INTO"); @@ -1215,19 +1349,17 @@ impl Parser { // look for optional ASC / DESC specifier let asc = match self.peek_token() { - Some(Token::Keyword(k)) => { - match k.to_uppercase().as_ref() { - "ASC" => { - self.next_token(); - true - }, - "DESC" => { - self.next_token(); - false - }, - _ => true + Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() { + "ASC" => { + self.next_token(); + true } - } + "DESC" => { + self.next_token(); + false + } + _ => true, + }, Some(Token::Comma) => true, _ => true, }; diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 09eeadf7..572f2c60 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -13,7 +13,9 @@ fn parse_delete_statement() { match parse_sql(&sql) { ASTNode::SQLDelete { relation, .. } => { assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), + Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( + "table".to_string() + )))), relation ); } @@ -36,7 +38,9 @@ fn parse_where_delete_statement() { .. } => { assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), + Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( + "table".to_string() + )))), relation ); @@ -207,7 +211,9 @@ fn parse_select_order_by_limit() { ); let ast = parse_sql(&sql); match ast { - ASTNode::SQLSelect { order_by, limit, .. } => { + ASTNode::SQLSelect { + order_by, limit, .. + } => { assert_eq!( Some(vec![ SQLOrderByExpr { @@ -341,7 +347,10 @@ fn parse_literal_string() { let sql = "SELECT 'one'"; match parse_sql(&sql) { ASTNode::SQLSelect { ref projection, .. } => { - assert_eq!(projection[0], ASTNode::SQLValue(Value::SingleQuotedString("one".to_string()))); + assert_eq!( + projection[0], + ASTNode::SQLValue(Value::SingleQuotedString("one".to_string())) + ); } _ => panic!(), } @@ -380,20 +389,21 @@ fn parse_parens() { let sql = "(a + b) - (c + d)"; let ast = parse_sql(&sql); assert_eq!( - SQLBinaryExpr { + SQLBinaryExpr { left: Box::new(SQLBinaryExpr { - left: Box::new(SQLIdentifier("a".to_string())), - op: Plus, - right: Box::new(SQLIdentifier("b".to_string())) - }), - op: Minus, - right: Box::new(SQLBinaryExpr { + left: Box::new(SQLIdentifier("a".to_string())), + op: Plus, + right: Box::new(SQLIdentifier("b".to_string())) + }), + op: Minus, + right: Box::new(SQLBinaryExpr { left: Box::new(SQLIdentifier("c".to_string())), - op: Plus, + op: Plus, right: Box::new(SQLIdentifier("d".to_string())) }) - } - , ast); + }, + ast + ); } #[test] @@ -410,17 +420,28 @@ fn parse_case_expression() { SQLCase { conditions: vec![ SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), - SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), - op: Eq, right: Box::new(SQLValue(Value::Long(0))) }, - SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), - op: GtEq, right: Box::new(SQLValue(Value::Long(0))) } + SQLBinaryExpr { + left: Box::new(SQLIdentifier("bar".to_string())), + op: Eq, + right: Box::new(SQLValue(Value::Long(0))) + }, + SQLBinaryExpr { + left: Box::new(SQLIdentifier("bar".to_string())), + op: GtEq, + right: Box::new(SQLValue(Value::Long(0))) + } ], - results: vec![SQLValue(Value::SingleQuotedString("null".to_string())), - SQLValue(Value::SingleQuotedString("=0".to_string())), - SQLValue(Value::SingleQuotedString(">=0".to_string()))], - else_result: Some(Box::new(SQLValue(Value::SingleQuotedString("<0".to_string())))) + results: vec![ + SQLValue(Value::SingleQuotedString("null".to_string())), + SQLValue(Value::SingleQuotedString("=0".to_string())), + SQLValue(Value::SingleQuotedString(">=0".to_string())) + ], + else_result: Some(Box::new(SQLValue(Value::SingleQuotedString( + "<0".to_string() + )))) }, - projection[0]); + projection[0] + ); } _ => assert!(false), } @@ -445,7 +466,9 @@ fn parse_delete_with_semi_colon() { match parse_sql(&sql) { ASTNode::SQLDelete { relation, .. } => { assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), + Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( + "table".to_string() + )))), relation ); } @@ -453,12 +476,120 @@ fn parse_delete_with_semi_colon() { } } +#[test] +fn parse_implicit_join() { + let sql = "SELECT * FROM t1,t2"; + + match parse_sql(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Implicit + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_cross_join() { + let sql = "SELECT * FROM t1 CROSS JOIN t2"; + + match parse_sql(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Cross + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_joins_on() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("c1".into())), + op: SQLOperator::Eq, + right: Box::new(ASTNode::SQLIdentifier("c2".into())), + })), + } + } + + assert_eq!( + joins_from("SELECT * FROM t1 JOIN t2 ON c1 = c2"), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2"), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2"), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 ON c1 = c2"), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_joins_using() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::Using(vec!["c1".into()])), + } + } + + assert_eq!( + joins_from("SELECT * FROM t1 JOIN t2 USING(c1)"), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 LEFT JOIN t2 USING(c1)"), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)"), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 USING(c1)"), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +fn joins_from(sql: &str) -> Vec { + match parse_sql(sql) { + ASTNode::SQLSelect { joins, .. } => joins, + _ => panic!("Expected SELECT"), + } +} + fn parse_sql(sql: &str) -> ASTNode { let dialect = GenericSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect,&sql, ); + let mut tokenizer = Tokenizer::new(&dialect, &sql); let tokens = tokenizer.tokenize().unwrap(); let mut parser = Parser::new(tokens); let ast = parser.parse().unwrap(); ast } - From 72024661a911b0904c9f4d49db2d4450c0c5accc Mon Sep 17 00:00:00 2001 From: Fredrik Roos Date: Sun, 18 Nov 2018 00:36:58 +0100 Subject: [PATCH 2/2] More tests and some small bugfixes --- src/dialect/postgresql.rs | 4 +- src/sqlast/mod.rs | 16 ++-- tests/sqlparser_ansi.rs | 8 +- tests/sqlparser_generic.rs | 57 +++++++++++--- tests/sqlparser_postgres.rs | 144 ++++++++++++++++++++++++++++++++++++ 5 files changed, 202 insertions(+), 27 deletions(-) diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index e8682819..20da4bef 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -13,8 +13,8 @@ impl Dialect for PostgreSqlDialect { CHAR, CHARACTER, VARYING, LARGE, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, - STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, - CASE, WHEN, THEN, ELSE, END, + STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN, + THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, ]; } diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index a2848e7d..1d8328d5 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -431,34 +431,34 @@ impl ToString for Join { } fn suffix(constraint: &JoinConstraint) -> String { match constraint { - JoinConstraint::On(expr) => format!(" ON({})", expr.to_string()), - JoinConstraint::Using(attrs) => format!(" USING({})", attrs.join(", ")), + JoinConstraint::On(expr) => format!("ON {}", expr.to_string()), + JoinConstraint::Using(attrs) => format!("USING({})", attrs.join(", ")), _ => "".to_string(), } } match &self.join_operator { JoinOperator::Inner(constraint) => format!( - "{}INNER JOIN {}{}", + " {}JOIN {} {}", prefix(constraint), self.relation.to_string(), - prefix(constraint) + suffix(constraint) ), - JoinOperator::Cross => format!("CROSS JOIN {}", self.relation.to_string()), + JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), JoinOperator::Implicit => format!(", {}", self.relation.to_string()), JoinOperator::LeftOuter(constraint) => format!( - "{}LEFT OUTER JOIN {}{}", + " {}LEFT JOIN {} {}", prefix(constraint), self.relation.to_string(), suffix(constraint) ), JoinOperator::RightOuter(constraint) => format!( - "{}RIGHT OUTER JOIN {}{}", + " {}RIGHT JOIN {} {}", prefix(constraint), self.relation.to_string(), suffix(constraint) ), JoinOperator::FullOuter(constraint) => format!( - "{}FULL OUTER JOIN {}{}", + " {}FULL JOIN {} {}", prefix(constraint), self.relation.to_string(), suffix(constraint) diff --git a/tests/sqlparser_ansi.rs b/tests/sqlparser_ansi.rs index aea3c843..4fec4f49 100644 --- a/tests/sqlparser_ansi.rs +++ b/tests/sqlparser_ansi.rs @@ -11,22 +11,18 @@ fn parse_simple_select() { let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1"); let ast = parse_sql(&sql); match ast { - ASTNode::SQLSelect { - projection, .. - } => { + ASTNode::SQLSelect { projection, .. } => { assert_eq!(3, projection.len()); } _ => assert!(false), } } - fn parse_sql(sql: &str) -> ASTNode { let dialect = AnsiSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect,&sql, ); + let mut tokenizer = Tokenizer::new(&dialect, &sql); let tokens = tokenizer.tokenize().unwrap(); let mut parser = Parser::new(tokens); let ast = parser.parse().unwrap(); ast } - diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 572f2c60..737caaf5 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -529,21 +529,20 @@ fn parse_joins_on() { })), } } - assert_eq!( - joins_from("SELECT * FROM t1 JOIN t2 ON c1 = c2"), + joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), vec![join_with_constraint("t2", JoinOperator::Inner)] ); assert_eq!( - joins_from("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2"), + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), vec![join_with_constraint("t2", JoinOperator::LeftOuter)] ); assert_eq!( - joins_from("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2"), + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), vec![join_with_constraint("t2", JoinOperator::RightOuter)] ); assert_eq!( - joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 ON c1 = c2"), + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), vec![join_with_constraint("t2", JoinOperator::FullOuter)] ); } @@ -561,25 +560,61 @@ fn parse_joins_using() { } assert_eq!( - joins_from("SELECT * FROM t1 JOIN t2 USING(c1)"), + joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), vec![join_with_constraint("t2", JoinOperator::Inner)] ); assert_eq!( - joins_from("SELECT * FROM t1 LEFT JOIN t2 USING(c1)"), + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), vec![join_with_constraint("t2", JoinOperator::LeftOuter)] ); assert_eq!( - joins_from("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)"), + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), vec![join_with_constraint("t2", JoinOperator::RightOuter)] ); assert_eq!( - joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 USING(c1)"), + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), vec![join_with_constraint("t2", JoinOperator::FullOuter)] ); } -fn joins_from(sql: &str) -> Vec { - match parse_sql(sql) { +#[test] +fn parse_complex_join() { + let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; + assert_eq!(sql, parse_sql(sql).to_string()); +} + +#[test] +fn parse_join_syntax_variants() { + fn parses_to(from: &str, to: &str) { + assert_eq!(to, &parse_sql(from).to_string()) + } + + parses_to( + "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", + ); +} + +fn verified(query: &str) -> ASTNode { + let ast = parse_sql(query); + assert_eq!(query, &ast.to_string()); + ast +} + +fn joins_from(ast: ASTNode) -> Vec { + match ast { ASTNode::SQLSelect { joins, .. } => joins, _ => panic!("Expected SELECT"), } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index adc575fc..4fccf328 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -716,6 +716,150 @@ fn parse_function_now() { assert_eq!(sql, ast.to_string()); } +#[test] +fn parse_implicit_join() { + let sql = "SELECT * FROM t1, t2"; + + match verified(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Implicit + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_cross_join() { + let sql = "SELECT * FROM t1 CROSS JOIN t2"; + + match verified(sql) { + ASTNode::SQLSelect { joins, .. } => { + assert_eq!(joins.len(), 1); + assert_eq!( + joins[0], + Join { + relation: ASTNode::SQLIdentifier("t2".to_string()), + join_operator: JoinOperator::Cross + } + ) + } + _ => assert!(false), + } +} + +#[test] +fn parse_joins_on() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("c1".into())), + op: SQLOperator::Eq, + right: Box::new(ASTNode::SQLIdentifier("c2".into())), + })), + } + } + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_joins_using() { + fn join_with_constraint( + relation: impl Into, + f: impl Fn(JoinConstraint) -> JoinOperator, + ) -> Join { + Join { + relation: ASTNode::SQLIdentifier(relation.into()), + join_operator: f(JoinConstraint::Using(vec!["c1".into()])), + } + } + + assert_eq!( + joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::Inner)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::RightOuter)] + ); + assert_eq!( + joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), + vec![join_with_constraint("t2", JoinOperator::FullOuter)] + ); +} + +#[test] +fn parse_join_syntax_variants() { + fn parses_to(from: &str, to: &str) { + assert_eq!(to, &parse_sql(from).to_string()) + } + + parses_to( + "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", + ); + parses_to( + "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", + "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", + ); +} + +#[test] +fn parse_complex_join() { + let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; + assert_eq!(sql, parse_sql(sql).to_string()); +} + +fn verified(query: &str) -> ASTNode { + let ast = parse_sql(query); + assert_eq!(query, &ast.to_string()); + ast +} + +fn joins_from(ast: ASTNode) -> Vec { + match ast { + ASTNode::SQLSelect { joins, .. } => joins, + _ => panic!("Expected SELECT"), + } +} + fn parse_sql(sql: &str) -> ASTNode { debug!("sql: {}", sql); let mut parser = parser(sql);