diff --git a/Cargo.toml b/Cargo.toml index 53463a9e..b24348b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ path = "src/lib.rs" [dependencies] log = "0.4.5" chrono = "0.4.6" -uuid = "0.7.1" [dev-dependencies] simple_logger = "1.0.1" diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 99e734f7..0a5bb6ba 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -106,8 +106,11 @@ pub enum ASTNode { over: Option, }, /// CASE [] WHEN THEN ... [ELSE ] END + /// Note we only recognize a complete single expression as , not + /// `< 0` nor `1, 2, 3` as allowed in a per + /// https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause SQLCase { - // TODO: support optional operand for "simple case" + operand: Option>, conditions: Vec, results: Vec, else_result: Option>, @@ -182,19 +185,21 @@ impl ToString for ASTNode { s } ASTNode::SQLCase { + operand, conditions, results, else_result, } => { - let mut s = format!( - "CASE {}", - conditions - .iter() - .zip(results) - .map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string())) - .collect::>() - .join(" ") - ); + let mut s = "CASE".to_string(); + if let Some(operand) = operand { + s += &format!(" {}", operand.to_string()); + } + s += &conditions + .iter() + .zip(results) + .map(|(c, r)| format!(" WHEN {} THEN {}", c.to_string(), r.to_string())) + .collect::>() + .join(""); if let Some(else_result) = else_result { s += &format!(" ELSE {}", else_result.to_string()) } diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index 50b81841..6668b785 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -183,6 +183,12 @@ pub enum TableFactor { Table { name: SQLObjectName, alias: Option, + /// Arguments of a table-valued function, as supported by Postgres + /// and MSSQL. Note that deprecated MSSQL `FROM foo (NOLOCK)` syntax + /// will also be parsed as `args`. + args: Option>, + /// MSSQL-specific `WITH (...)` hints such as NOLOCK. + with_hints: Vec, }, Derived { subquery: Box, @@ -192,16 +198,32 @@ pub enum TableFactor { impl ToString for TableFactor { fn to_string(&self) -> String { - let (base, alias) = match self { - TableFactor::Table { name, alias } => (name.to_string(), alias), - TableFactor::Derived { subquery, alias } => { - (format!("({})", subquery.to_string()), alias) + match self { + TableFactor::Table { + name, + alias, + args, + with_hints, + } => { + let mut s = name.to_string(); + if let Some(args) = args { + s += &format!("({})", comma_separated_string(args)) + }; + if let Some(alias) = alias { + s += &format!(" AS {}", alias); + } + if !with_hints.is_empty() { + s += &format!(" WITH ({})", comma_separated_string(with_hints)); + } + s + } + TableFactor::Derived { subquery, alias } => { + let mut s = format!("({})", subquery.to_string()); + if let Some(alias) = alias { + s += &format!(" AS {}", alias); + } + s } - }; - if let Some(alias) = alias { - format!("{} AS {}", base, alias) - } else { - base } } } diff --git a/src/sqlast/value.rs b/src/sqlast/value.rs index eedb14a9..a36f8d27 100644 --- a/src/sqlast/value.rs +++ b/src/sqlast/value.rs @@ -1,7 +1,5 @@ use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime}; -use uuid::Uuid; - /// SQL values such as int, double, string, timestamp #[derive(Debug, Clone, PartialEq)] pub enum Value { @@ -9,8 +7,6 @@ pub enum Value { Long(i64), /// Literal floating point value Double(f64), - /// Uuid value - Uuid(Uuid), /// 'string value' SingleQuotedString(String), /// N'string value' @@ -34,7 +30,6 @@ impl ToString for Value { match self { Value::Long(v) => v.to_string(), Value::Double(v) => v.to_string(), - Value::Uuid(v) => v.to_string(), Value::SingleQuotedString(v) => format!("'{}'", escape_single_quote_string(v)), Value::NationalStringLiteral(v) => format!("N'{}'", v), Value::Boolean(v) => v.to_string(), diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 8ab0bee8..3b6f9c28 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -28,6 +28,7 @@ pub enum ParserError { ParserError(String), } +// Use `Parser::expected` instead, if possible macro_rules! parser_err { ($MSG:expr) => { Err(ParserError::ParserError($MSG.to_string())) @@ -69,10 +70,7 @@ impl Parser { if parser.peek_token().is_none() { break; } else if expecting_statement_delimiter { - return parser_err!(format!( - "Expected end of statement, found: {}", - parser.peek_token().unwrap().to_string() - )); + return parser.expected("end of statement", parser.peek_token()); } let statement = parser.parse_statement()?; @@ -102,12 +100,12 @@ impl Parser { w.to_string() )), }, - unexpected => parser_err!(format!( - "Unexpected {:?} at the beginning of a statement", - unexpected - )), + unexpected => self.expected( + "a keyword at the beginning of a statement", + Some(unexpected), + ), }, - _ => parser_err!("Unexpected end of file"), + None => self.expected("SQL statement", None), } } @@ -188,10 +186,10 @@ impl Parser { break; } unexpected => { - return parser_err!(format!( - "Expected an identifier or a '*' after '.', got: {:?}", - unexpected - )); + return self.expected( + "an identifier or a '*' after '.'", + unexpected, + ); } } } @@ -210,8 +208,13 @@ impl Parser { Token::Mult => Ok(ASTNode::SQLWildcard), tok @ Token::Minus | tok @ Token::Plus => { let p = self.get_precedence(&tok)?; + let operator = if tok == Token::Plus { + SQLOperator::Plus + } else { + SQLOperator::Minus + }; Ok(ASTNode::SQLUnary { - operator: self.to_sql_operator(&tok)?, + operator, expr: Box::new(self.parse_subexpr(p)?), }) } @@ -231,10 +234,7 @@ impl Parser { self.expect_token(&Token::RParen)?; Ok(expr) } - _ => parser_err!(format!( - "Did not expect {:?} at the beginning of an expression", - t - )), + _ => self.expected("an expression", Some(t)), }, None => parser_err!("Prefix parser expected a keyword but hit EOF"), } @@ -242,13 +242,7 @@ impl Parser { pub fn parse_function(&mut self, name: SQLObjectName) -> Result { self.expect_token(&Token::LParen)?; - let args = if self.consume_token(&Token::RParen) { - vec![] - } else { - let args = self.parse_expr_list()?; - self.expect_token(&Token::RParen)?; - args - }; + let args = self.parse_optional_args()?; let over = if self.parse_keyword("OVER") { // TBD: support window names (`OVER mywin`) in place of inline specification self.expect_token(&Token::LParen)?; @@ -302,12 +296,7 @@ impl Parser { } } Some(Token::RParen) => None, - unexpected => { - return parser_err!(format!( - "Expected 'ROWS', 'RANGE', 'GROUPS', or ')', got {:?}", - unexpected - )); - } + unexpected => return self.expected("'ROWS', 'RANGE', 'GROUPS', or ')'", unexpected), }; self.expect_token(&Token::RParen)?; Ok(window_frame) @@ -335,46 +324,39 @@ impl Parser { } else if self.parse_keyword("FOLLOWING") { Ok(SQLWindowFrameBound::Following(rows)) } else { - parser_err!(format!( - "Expected PRECEDING or FOLLOWING, found {:?}", - self.peek_token() - )) + self.expected("PRECEDING or FOLLOWING", self.peek_token()) } } } pub fn parse_case_expression(&mut self) -> Result { - if self.parse_keywords(vec!["WHEN"]) { - let mut conditions = vec![]; - let mut results = vec![]; - let mut else_result = None; - loop { - conditions.push(self.parse_expr()?); - self.expect_keyword("THEN")?; - results.push(self.parse_expr()?); - if self.parse_keywords(vec!["ELSE"]) { - else_result = Some(Box::new(self.parse_expr()?)); - if self.parse_keywords(vec!["END"]) { - break; - } else { - return parser_err!("Expecting END after a CASE..ELSE"); - } - } - if self.parse_keywords(vec!["END"]) { - break; - } - self.expect_keyword("WHEN")?; - } - Ok(ASTNode::SQLCase { - conditions, - results, - else_result, - }) - } else { - // TODO: implement "simple" case - // https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-case - parser_err!("Simple case not implemented") + let mut operand = None; + if !self.parse_keyword("WHEN") { + operand = Some(Box::new(self.parse_expr()?)); + self.expect_keyword("WHEN")?; } + let mut conditions = vec![]; + let mut results = vec![]; + loop { + conditions.push(self.parse_expr()?); + self.expect_keyword("THEN")?; + results.push(self.parse_expr()?); + if !self.parse_keyword("WHEN") { + break; + } + } + let else_result = if self.parse_keyword("ELSE") { + Some(Box::new(self.parse_expr()?)) + } else { + None + }; + self.expect_keyword("END")?; + Ok(ASTNode::SQLCase { + operand, + conditions, + results, + else_result, + }) } /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` @@ -390,65 +372,75 @@ impl Parser { }) } - /// Parse an expression infix (typically an operator) + /// Parse an operator following an expression pub fn parse_infix(&mut self, expr: ASTNode, precedence: u8) -> Result { debug!("parsing infix"); - match self.next_token() { - Some(tok) => match tok { - Token::SQLWord(ref k) if k.keyword == "IS" => { - if self.parse_keywords(vec!["NULL"]) { + let tok = self.next_token().unwrap(); // safe as EOF's precedence is the lowest + + let regular_binary_operator = match tok { + Token::Eq => Some(SQLOperator::Eq), + Token::Neq => Some(SQLOperator::NotEq), + Token::Gt => Some(SQLOperator::Gt), + Token::GtEq => Some(SQLOperator::GtEq), + Token::Lt => Some(SQLOperator::Lt), + Token::LtEq => Some(SQLOperator::LtEq), + Token::Plus => Some(SQLOperator::Plus), + Token::Minus => Some(SQLOperator::Minus), + Token::Mult => Some(SQLOperator::Multiply), + Token::Mod => Some(SQLOperator::Modulus), + Token::Div => Some(SQLOperator::Divide), + Token::SQLWord(ref k) => match k.keyword.as_ref() { + "AND" => Some(SQLOperator::And), + "OR" => Some(SQLOperator::Or), + "LIKE" => Some(SQLOperator::Like), + "NOT" => { + if self.parse_keyword("LIKE") { + Some(SQLOperator::NotLike) + } else { + None + } + } + _ => None, + }, + _ => None, + }; + + if let Some(op) = regular_binary_operator { + Ok(ASTNode::SQLBinaryExpr { + left: Box::new(expr), + op, + right: Box::new(self.parse_subexpr(precedence)?), + }) + } else if let Token::SQLWord(ref k) = tok { + match k.keyword.as_ref() { + "IS" => { + if self.parse_keyword("NULL") { Ok(ASTNode::SQLIsNull(Box::new(expr))) } else if self.parse_keywords(vec!["NOT", "NULL"]) { Ok(ASTNode::SQLIsNotNull(Box::new(expr))) } else { - parser_err!(format!( - "Expected NULL or NOT NULL after IS, found {:?}", - self.peek_token() - )) + self.expected("NULL or NOT NULL after IS", self.peek_token()) } } - Token::SQLWord(ref k) if k.keyword == "NOT" => { + "NOT" | "IN" | "BETWEEN" => { + self.prev_token(); + let negated = self.parse_keyword("NOT"); if self.parse_keyword("IN") { - self.parse_in(expr, true) + self.parse_in(expr, negated) } else if self.parse_keyword("BETWEEN") { - self.parse_between(expr, true) - } else if self.parse_keyword("LIKE") { - Ok(ASTNode::SQLBinaryExpr { - left: Box::new(expr), - op: SQLOperator::NotLike, - right: Box::new(self.parse_subexpr(precedence)?), - }) + self.parse_between(expr, negated) } else { - parser_err!(format!( - "Expected BETWEEN, IN or LIKE after NOT, found {:?}", - self.peek_token() - )) + panic!() } } - Token::SQLWord(ref k) if k.keyword == "IN" => self.parse_in(expr, false), - Token::SQLWord(ref k) if k.keyword == "BETWEEN" => self.parse_between(expr, false), - Token::DoubleColon => self.parse_pg_cast(expr), - Token::SQLWord(_) - | Token::Eq - | Token::Neq - | Token::Gt - | Token::GtEq - | Token::Lt - | Token::LtEq - | Token::Plus - | Token::Minus - | Token::Mult - | Token::Mod - | Token::Div => Ok(ASTNode::SQLBinaryExpr { - left: Box::new(expr), - op: self.to_sql_operator(&tok)?, - right: Box::new(self.parse_subexpr(precedence)?), - }), - _ => parser_err!(format!("No infix parser for token {:?}", tok)), - }, - // This is not supposed to happen, because of the precedence check - // in parse_subexpr. - None => parser_err!("Unexpected EOF in parse_infix"), + // Can only happen if `get_precedence` got out of sync with this function + _ => panic!("No infix parser for token {:?}", tok), + } + } else if Token::DoubleColon == tok { + self.parse_pg_cast(expr) + } else { + // Can only happen if `get_precedence` got out of sync with this function + panic!("No infix parser for token {:?}", tok) } } @@ -494,28 +486,6 @@ impl Parser { }) } - /// Convert a token operator to an AST operator - pub fn to_sql_operator(&self, tok: &Token) -> Result { - match tok { - Token::Eq => Ok(SQLOperator::Eq), - Token::Neq => Ok(SQLOperator::NotEq), - Token::Lt => Ok(SQLOperator::Lt), - Token::LtEq => Ok(SQLOperator::LtEq), - Token::Gt => Ok(SQLOperator::Gt), - Token::GtEq => Ok(SQLOperator::GtEq), - Token::Plus => Ok(SQLOperator::Plus), - Token::Minus => Ok(SQLOperator::Minus), - Token::Mult => Ok(SQLOperator::Multiply), - Token::Div => Ok(SQLOperator::Divide), - Token::Mod => Ok(SQLOperator::Modulus), - Token::SQLWord(ref k) if k.keyword == "AND" => Ok(SQLOperator::And), - Token::SQLWord(ref k) if k.keyword == "OR" => Ok(SQLOperator::Or), - //Token::SQLWord(ref k) if k.keyword == "NOT" => Ok(SQLOperator::Not), - Token::SQLWord(ref k) if k.keyword == "LIKE" => Ok(SQLOperator::Like), - _ => parser_err!(format!("Unsupported SQL operator {:?}", tok)), - } - } - /// Get the precedence of the next token pub fn get_next_precedence(&self) -> Result { if let Some(token) = self.peek_token() { @@ -629,6 +599,15 @@ impl Parser { } } + /// Report unexpected token + fn expected(&self, expected: &str, found: Option) -> Result { + parser_err!(format!( + "Expected {}, found: {}", + expected, + found.map_or("EOF".to_string(), |t| t.to_string()) + )) + } + /// Look for an expected keyword and consume it if it exists #[must_use] pub fn parse_keyword(&mut self, expected: &'static str) -> bool { @@ -666,11 +645,7 @@ impl Parser { if self.parse_keyword(expected) { Ok(()) } else { - parser_err!(format!( - "Expected keyword {}, found {:?}", - expected, - self.peek_token() - )) + self.expected(expected, self.peek_token()) } } @@ -695,11 +670,7 @@ impl Parser { if self.consume_token(expected) { Ok(()) } else { - parser_err!(format!( - "Expected token {:?}, found {:?}", - expected, - self.peek_token() - )) + self.expected(&expected.to_string(), self.peek_token()) } } @@ -713,10 +684,7 @@ impl Parser { } else if self.parse_keyword("EXTERNAL") { self.parse_create_external_table() } else { - parser_err!(format!( - "Unexpected token after CREATE: {:?}", - self.peek_token() - )) + self.expected("TABLE or VIEW after CREATE", self.peek_token()) } } @@ -867,20 +835,20 @@ impl Parser { self.expect_keyword("TABLE")?; let _ = self.parse_keyword("ONLY"); let table_name = self.parse_object_name()?; - let operation: Result = - if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) { + let operation = if self.parse_keyword("ADD") { + if self.parse_keyword("CONSTRAINT") { let constraint_name = self.parse_identifier()?; let table_key = self.parse_table_key(constraint_name)?; - Ok(AlterOperation::AddConstraint(table_key)) + AlterOperation::AddConstraint(table_key) } else { - return parser_err!(format!( - "Expecting ADD CONSTRAINT, found :{:?}", - self.peek_token() - )); - }; + return self.expected("CONSTRAINT after ADD", self.peek_token()); + } + } else { + return self.expected("ADD after ALTER TABLE", self.peek_token()); + }; Ok(SQLStatement::SQLAlterTable { name: table_name, - operation: operation?, + operation, }) } @@ -1158,7 +1126,7 @@ impl Parser { Ok(SQLType::Custom(type_name)) } }, - other => parser_err!(format!("Invalid data type: '{:?}'", other)), + other => self.expected("a data type name", other), } } @@ -1218,10 +1186,7 @@ impl Parser { } } if expect_identifier { - parser_err!(format!( - "Expecting identifier, found {:?}", - self.peek_token() - )) + self.expected("identifier", self.peek_token()) } else { Ok(idents) } @@ -1237,7 +1202,7 @@ impl Parser { pub fn parse_identifier(&mut self) -> Result { match self.next_token() { Some(Token::SQLWord(w)) => Ok(w.as_sql_ident()), - unexpected => parser_err!(format!("Expected identifier, found {:?}", unexpected)), + unexpected => self.expected("identifier", unexpected), } } @@ -1368,7 +1333,7 @@ impl Parser { self.expect_token(&Token::RParen)?; SQLSetExpr::Query(Box::new(subquery)) } else { - parser_err!("Expected SELECT or a subquery in the query body!")? + return self.expected("SELECT or a subquery in the query body", self.peek_token()); }; loop { @@ -1459,8 +1424,30 @@ impl Parser { Ok(TableFactor::Derived { subquery, alias }) } else { let name = self.parse_object_name()?; + // Postgres, MSSQL: table-valued functions: + let args = if self.consume_token(&Token::LParen) { + Some(self.parse_optional_args()?) + } else { + None + }; let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; - Ok(TableFactor::Table { name, alias }) + // MSSQL-specific table hints: + let mut with_hints = vec![]; + if self.parse_keyword("WITH") { + if self.consume_token(&Token::LParen) { + with_hints = self.parse_expr_list()?; + self.expect_token(&Token::RParen)?; + } else { + // rewind, as WITH may belong to the next statement's CTE + self.prev_token(); + } + }; + Ok(TableFactor::Table { + name, + alias, + args, + with_hints, + }) } } @@ -1476,10 +1463,7 @@ impl Parser { self.expect_token(&Token::RParen)?; Ok(JoinConstraint::Using(attributes)) } else { - parser_err!(format!( - "Unexpected token after JOIN: {:?}", - self.peek_token() - )) + self.expected("ON, or USING after JOIN", self.peek_token()) } } @@ -1608,6 +1592,16 @@ impl Parser { Ok(expr_list) } + pub fn parse_optional_args(&mut self) -> Result, ParserError> { + if self.consume_token(&Token::RParen) { + Ok(vec![]) + } else { + let args = self.parse_expr_list()?; + self.expect_token(&Token::RParen)?; + Ok(args) + } + } + /// Parse a comma-delimited list of projections after SELECT pub fn parse_select_list(&mut self) -> Result, ParserError> { let mut projections: Vec = vec![]; diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 3cef398e..e1e96972 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -641,9 +641,16 @@ fn parse_delimited_identifiers() { ); // check FROM match select.relation.unwrap() { - TableFactor::Table { name, alias } => { + TableFactor::Table { + name, + alias, + args, + with_hints, + } => { assert_eq!(vec![r#""a table""#.to_string()], name.0); assert_eq!(r#""alias""#, alias.unwrap()); + assert!(args.is_none()); + assert!(with_hints.is_empty()); } _ => panic!("Expecting TableFactor::Table"), } @@ -698,13 +705,14 @@ fn parse_parens() { } #[test] -fn parse_case_expression() { +fn parse_searched_case_expression() { let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; use self::ASTNode::{SQLBinaryExpr, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; use self::SQLOperator::*; let select = verified_only_select(sql); assert_eq!( &SQLCase { + operand: None, conditions: vec![ SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), SQLBinaryExpr { @@ -731,6 +739,31 @@ fn parse_case_expression() { ); } +#[test] +fn parse_simple_case_expression() { + // ANSI calls a CASE expression with an operand "" + let sql = "SELECT CASE foo WHEN 1 THEN 'Y' ELSE 'N' END"; + let select = verified_only_select(sql); + use self::ASTNode::{SQLCase, SQLIdentifier, SQLValue}; + assert_eq!( + &SQLCase { + operand: Some(Box::new(SQLIdentifier("foo".to_string()))), + conditions: vec![SQLValue(Value::Long(1))], + results: vec![SQLValue(Value::SingleQuotedString("Y".to_string())),], + else_result: Some(Box::new(SQLValue(Value::SingleQuotedString( + "N".to_string() + )))) + }, + expr_from_projection(only(&select.projection)), + ); +} + +#[test] +fn parse_from_advanced() { + let sql = "SELECT * FROM fn(1, 2) AS foo, schema.bar AS bar WITH (NOLOCK)"; + let _select = verified_only_select(sql); +} + #[test] fn parse_implicit_join() { let sql = "SELECT * FROM t1, t2"; @@ -740,6 +773,8 @@ fn parse_implicit_join() { relation: TableFactor::Table { name: SQLObjectName(vec!["t2".to_string()]), alias: None, + args: None, + with_hints: vec![], }, join_operator: JoinOperator::Implicit }, @@ -756,6 +791,8 @@ fn parse_cross_join() { relation: TableFactor::Table { name: SQLObjectName(vec!["t2".to_string()]), alias: None, + args: None, + with_hints: vec![], }, join_operator: JoinOperator::Cross }, @@ -774,6 +811,8 @@ fn parse_joins_on() { relation: TableFactor::Table { name: SQLObjectName(vec![relation.into()]), alias, + args: None, + with_hints: vec![], }, join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { left: Box::new(ASTNode::SQLIdentifier("c1".into())), @@ -825,6 +864,8 @@ fn parse_joins_using() { relation: TableFactor::Table { name: SQLObjectName(vec![relation.into()]), alias, + args: None, + with_hints: vec![], }, join_operator: f(JoinConstraint::Using(vec!["c1".into()])), }