diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index 3788a749..0f18b723 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -12,8 +12,7 @@ impl Dialect for GenericSqlDialect { CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, - CROSS, OUTER, INNER, NATURAL, ON, USING, - BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, LIKE, + CROSS, OUTER, INNER, NATURAL, ON, USING, LIKE, ]; } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 0c5f4fd6..66cb51c1 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -14,7 +14,8 @@ impl Dialect for PostgreSqlDialect { DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN, - THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, LIKE + THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, + LIKE, ]; } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 67e0d790..25043087 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -90,7 +90,6 @@ impl Parser { let mut expr = self.parse_prefix()?; debug!("prefix: {:?}", expr); loop { - // stop parsing on `NULL` | `NOT NULL` match self.peek_token() { Some(Token::Keyword(ref k)) if k == "NOT" || k == "NULL" => break, @@ -142,7 +141,7 @@ impl Parser { Some(Token::Period) => { let mut id_parts: Vec = vec![id]; while self.peek_token() == Some(Token::Period) { - self.consume_token(&Token::Period)?; + self.expect_token(&Token::Period)?; match self.next_token() { Some(Token::Identifier(id)) => id_parts.push(id), _ => { @@ -167,9 +166,7 @@ impl Parser { } Token::LParen => { let expr = self.parse(); - if !self.consume_token(&Token::RParen)? { - return parser_err!(format!("expected token RParen")); - } + self.expect_token(&Token::RParen)?; expr } _ => parser_err!(format!( @@ -182,15 +179,15 @@ impl Parser { } pub fn parse_function(&mut self, id: &str) -> Result { - self.consume_token(&Token::LParen)?; - if let Ok(true) = self.consume_token(&Token::RParen) { + self.expect_token(&Token::LParen)?; + if self.consume_token(&Token::RParen) { Ok(ASTNode::SQLFunction { id: id.to_string(), args: vec![], }) } else { let args = self.parse_expr_list()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok(ASTNode::SQLFunction { id: id.to_string(), args, @@ -205,7 +202,7 @@ impl Parser { let mut else_result = None; loop { conditions.push(self.parse_expr(0)?); - self.consume_token(&Token::Keyword("THEN".to_string()))?; + self.expect_keyword("THEN")?; results.push(self.parse_expr(0)?); if self.parse_keywords(vec!["ELSE"]) { else_result = Some(Box::new(self.parse_expr(0)?)); @@ -218,7 +215,7 @@ impl Parser { if self.parse_keywords(vec!["END"]) { break; } - self.consume_token(&Token::Keyword("WHEN".to_string()))?; + self.expect_keyword("WHEN")?; } Ok(ASTNode::SQLCase { conditions, @@ -234,11 +231,11 @@ impl Parser { /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expression(&mut self) -> Result { - self.consume_token(&Token::LParen)?; + self.expect_token(&Token::LParen)?; let expr = self.parse_expr(0)?; - self.consume_token(&Token::Keyword("AS".to_string()))?; + self.expect_keyword("AS")?; let data_type = self.parse_data_type()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok(ASTNode::SQLCast { expr: Box::new(expr), data_type, @@ -247,7 +244,6 @@ impl Parser { /// Parse a postgresql casting style which is in the form of `expr::datatype` pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result { - let _ = self.consume_token(&Token::DoubleColon)?; Ok(ASTNode::SQLCast { expr: Box::new(expr), data_type: self.parse_data_type()?, @@ -449,6 +445,7 @@ impl Parser { } /// Look for an expected keyword and consume it if it exists + #[must_use] pub fn parse_keyword(&mut self, expected: &'static str) -> bool { match self.peek_token() { Some(Token::Keyword(k)) => { @@ -464,6 +461,7 @@ impl Parser { } /// Look for an expected sequence of keywords and consume them if they exist + #[must_use] pub fn parse_keywords(&mut self, keywords: Vec<&'static str>) -> bool { let index = self.index; for keyword in keywords { @@ -477,6 +475,7 @@ impl Parser { true } + /// Bail out if the current token is not an expected keyword, or consume it if it is pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> { if self.parse_keyword(expected) { Ok(()) @@ -489,20 +488,32 @@ impl Parser { } } - //TODO: this function is inconsistent and sometimes returns bool and sometimes fails - - /// Consume the next token if it matches the expected token, otherwise return an error - pub fn consume_token(&mut self, expected: &Token) -> Result { + /// Consume the next token if it matches the expected token, otherwise return false + #[must_use] + pub fn consume_token(&mut self, expected: &Token) -> bool { match self.peek_token() { Some(ref t) => { if *t == *expected { self.next_token(); - Ok(true) + true } else { - Ok(false) + false } } - other => parser_err!(format!("expected token {:?} but was {:?}", expected, other,)), + _ => false, + } + } + + /// Bail out if the current token is not an expected keyword, or consume it if it is + pub fn expect_token(&mut self, expected: &Token) -> Result<(), ParserError> { + if self.consume_token(expected) { + Ok(()) + } else { + parser_err!(format!( + "Expected token {:?}, found {:?}", + expected, + self.peek_token() + )) } } @@ -512,7 +523,7 @@ impl Parser { let table_name = self.parse_tablename()?; // parse optional column list (schema) let mut columns = vec![]; - if self.consume_token(&Token::LParen)? { + if self.consume_token(&Token::LParen) { loop { if let Some(Token::Identifier(column_name)) = self.next_token() { if let Ok(data_type) = self.parse_data_type() { @@ -590,9 +601,9 @@ impl Parser { let is_primary_key = self.parse_keywords(vec!["PRIMARY", "KEY"]); let is_unique_key = self.parse_keywords(vec!["UNIQUE", "KEY"]); let is_foreign_key = self.parse_keywords(vec!["FOREIGN", "KEY"]); - self.consume_token(&Token::LParen)?; + self.expect_token(&Token::LParen)?; let column_names = self.parse_column_names()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; let key = Key { name: constraint_name.to_string(), columns: column_names, @@ -604,9 +615,9 @@ impl Parser { } else if is_foreign_key { if self.parse_keyword("REFERENCES") { let foreign_table = self.parse_tablename()?; - self.consume_token(&Token::LParen)?; + self.expect_token(&Token::LParen)?; let referred_columns = self.parse_column_names()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok(TableKey::ForeignKey { key, foreign_table, @@ -662,16 +673,16 @@ impl Parser { /// Parse a copy statement pub fn parse_copy(&mut self) -> Result { let table_name = self.parse_tablename()?; - let columns = if self.consume_token(&Token::LParen)? { + let columns = if self.consume_token(&Token::LParen) { let column_names = self.parse_column_names()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; column_names } else { vec![] }; - self.parse_keyword("FROM"); - self.parse_keyword("STDIN"); - self.consume_token(&Token::SemiColon)?; + self.expect_keyword("FROM")?; + self.expect_keyword("STDIN")?; + self.expect_token(&Token::SemiColon)?; let values = self.parse_tsv()?; Ok(ASTNode::SQLCopy { table_name, @@ -705,7 +716,7 @@ impl Parser { content.clear(); } Token::Backslash => { - if let Ok(true) = self.consume_token(&Token::Period) { + if self.consume_token(&Token::Period) { return Ok(values); } if let Some(token) = self.next_token() { @@ -830,9 +841,9 @@ impl Parser { } pub fn parse_date(&mut self, year: i64) -> Result { - if let Ok(true) = self.consume_token(&Token::Minus) { + if self.consume_token(&Token::Minus) { let month = self.parse_literal_int()?; - if let Ok(true) = self.consume_token(&Token::Minus) { + if self.consume_token(&Token::Minus) { let day = self.parse_literal_int()?; let date = NaiveDate::from_ymd(year as i32, month as u32, day as u32); Ok(date) @@ -852,9 +863,9 @@ impl Parser { pub fn parse_time(&mut self) -> Result { let hour = self.parse_literal_int()?; - self.consume_token(&Token::Colon)?; + self.expect_token(&Token::Colon)?; let min = self.parse_literal_int()?; - self.consume_token(&Token::Colon)?; + self.expect_token(&Token::Colon)?; // On one hand, the SQL specs defines ::= , // so it would be more correct to parse it as such let sec = self.parse_literal_double()?; @@ -943,8 +954,9 @@ impl Parser { } "REGCLASS" => Ok(SQLType::Regclass), "TEXT" => { - if let Ok(true) = self.consume_token(&Token::LBracket) { - self.consume_token(&Token::RBracket)?; + if self.consume_token(&Token::LBracket) { + // Note: this is postgresql-specific + self.expect_token(&Token::RBracket)?; Ok(SQLType::Array(Box::new(SQLType::Text))) } else { Ok(SQLType::Text) @@ -1026,10 +1038,10 @@ impl Parser { } pub fn parse_optional_precision(&mut self) -> Result, ParserError> { - if self.consume_token(&Token::LParen)? { + if self.consume_token(&Token::LParen) { let n = self.parse_literal_int()?; //TODO: check return value of reading rparen - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok(Some(n as usize)) } else { Ok(None) @@ -1039,14 +1051,14 @@ impl Parser { pub fn parse_optional_precision_scale( &mut self, ) -> Result<(usize, Option), ParserError> { - if self.consume_token(&Token::LParen)? { + if self.consume_token(&Token::LParen) { let n = self.parse_literal_int()?; - let scale = if let Ok(true) = self.consume_token(&Token::Comma) { + let scale = if self.consume_token(&Token::Comma) { Some(self.parse_literal_int()? as usize) } else { None }; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok((n as usize, scale)) } else { parser_err!("Expecting `(`") @@ -1153,7 +1165,7 @@ impl Parser { let constraint = self.parse_expr(0)?; Ok(JoinConstraint::On(constraint)) } else if self.parse_keyword("USING") { - if self.consume_token(&Token::LParen)? { + if self.consume_token(&Token::LParen) { let attributes = self .parse_expr_list()? .into_iter() @@ -1165,7 +1177,7 @@ impl Parser { }) .collect::, ParserError>>()?; - if self.consume_token(&Token::RParen)? { + if self.consume_token(&Token::RParen) { Ok(JoinConstraint::Using(attributes)) } else { parser_err!(format!("Expected token ')', found {:?}", self.peek_token())) @@ -1232,7 +1244,7 @@ impl Parser { } Some(Token::Keyword(kw)) if kw == "LEFT" => { self.next_token(); - self.parse_keyword("OUTER"); + let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { relation: self.parse_expr(0)?, @@ -1243,7 +1255,7 @@ impl Parser { } Some(Token::Keyword(kw)) if kw == "RIGHT" => { self.next_token(); - self.parse_keyword("OUTER"); + let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { relation: self.parse_expr(0)?, @@ -1254,7 +1266,7 @@ impl Parser { } Some(Token::Keyword(kw)) if kw == "FULL" => { self.next_token(); - self.parse_keyword("OUTER"); + let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { relation: self.parse_expr(0)?, @@ -1273,19 +1285,19 @@ impl Parser { /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { - self.parse_keyword("INTO"); + self.expect_keyword("INTO")?; let table_name = self.parse_tablename()?; - let columns = if self.consume_token(&Token::LParen)? { + let columns = if self.consume_token(&Token::LParen) { let column_names = self.parse_column_names()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; column_names } else { vec![] }; - self.parse_keyword("VALUES"); - self.consume_token(&Token::LParen)?; + self.expect_keyword("VALUES")?; + self.expect_token(&Token::LParen)?; let values = self.parse_expr_list()?; - self.consume_token(&Token::RParen)?; + self.expect_token(&Token::RParen)?; Ok(ASTNode::SQLInsert { table_name, columns, diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index e3412b2c..5c354a3b 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -221,6 +221,16 @@ fn parse_insert_with_columns() { } } +#[test] +fn parse_insert_invalid() { + let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)"); + let mut parser = parser(&sql); + match parser.parse() { + Err(_) => {} + _ => assert!(false), + } +} + #[test] fn parse_select_wildcard() { let sql = String::from("SELECT * FROM customer");