Merge pull request #35 from nickolay/consume-token

Clean up consume_token() and parse/expect_keyword()
This commit is contained in:
Andy Grove 2019-01-13 09:16:13 -07:00 committed by GitHub
commit 47e00af15b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 57 deletions

View file

@ -12,8 +12,7 @@ impl Dialect for GenericSqlDialect {
CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT,
REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC,
BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL,
CROSS, OUTER, INNER, NATURAL, ON, USING, CROSS, OUTER, INNER, NATURAL, ON, USING, LIKE,
BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, LIKE,
]; ];
} }

View file

@ -14,7 +14,8 @@ impl Dialect for PostgreSqlDialect {
DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN,
DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY,
STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN, STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN,
THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, LIKE THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING,
LIKE,
]; ];
} }

View file

@ -90,7 +90,6 @@ impl Parser {
let mut expr = self.parse_prefix()?; let mut expr = self.parse_prefix()?;
debug!("prefix: {:?}", expr); debug!("prefix: {:?}", expr);
loop { loop {
// stop parsing on `NULL` | `NOT NULL` // stop parsing on `NULL` | `NOT NULL`
match self.peek_token() { match self.peek_token() {
Some(Token::Keyword(ref k)) if k == "NOT" || k == "NULL" => break, Some(Token::Keyword(ref k)) if k == "NOT" || k == "NULL" => break,
@ -142,7 +141,7 @@ impl Parser {
Some(Token::Period) => { Some(Token::Period) => {
let mut id_parts: Vec<String> = vec![id]; let mut id_parts: Vec<String> = vec![id];
while self.peek_token() == Some(Token::Period) { while self.peek_token() == Some(Token::Period) {
self.consume_token(&Token::Period)?; self.expect_token(&Token::Period)?;
match self.next_token() { match self.next_token() {
Some(Token::Identifier(id)) => id_parts.push(id), Some(Token::Identifier(id)) => id_parts.push(id),
_ => { _ => {
@ -167,9 +166,7 @@ impl Parser {
} }
Token::LParen => { Token::LParen => {
let expr = self.parse(); let expr = self.parse();
if !self.consume_token(&Token::RParen)? { self.expect_token(&Token::RParen)?;
return parser_err!(format!("expected token RParen"));
}
expr expr
} }
_ => parser_err!(format!( _ => parser_err!(format!(
@ -182,15 +179,15 @@ impl Parser {
} }
pub fn parse_function(&mut self, id: &str) -> Result<ASTNode, ParserError> { pub fn parse_function(&mut self, id: &str) -> Result<ASTNode, ParserError> {
self.consume_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
if let Ok(true) = self.consume_token(&Token::RParen) { if self.consume_token(&Token::RParen) {
Ok(ASTNode::SQLFunction { Ok(ASTNode::SQLFunction {
id: id.to_string(), id: id.to_string(),
args: vec![], args: vec![],
}) })
} else { } else {
let args = self.parse_expr_list()?; let args = self.parse_expr_list()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(ASTNode::SQLFunction { Ok(ASTNode::SQLFunction {
id: id.to_string(), id: id.to_string(),
args, args,
@ -205,7 +202,7 @@ impl Parser {
let mut else_result = None; let mut else_result = None;
loop { loop {
conditions.push(self.parse_expr(0)?); conditions.push(self.parse_expr(0)?);
self.consume_token(&Token::Keyword("THEN".to_string()))?; self.expect_keyword("THEN")?;
results.push(self.parse_expr(0)?); results.push(self.parse_expr(0)?);
if self.parse_keywords(vec!["ELSE"]) { if self.parse_keywords(vec!["ELSE"]) {
else_result = Some(Box::new(self.parse_expr(0)?)); else_result = Some(Box::new(self.parse_expr(0)?));
@ -218,7 +215,7 @@ impl Parser {
if self.parse_keywords(vec!["END"]) { if self.parse_keywords(vec!["END"]) {
break; break;
} }
self.consume_token(&Token::Keyword("WHEN".to_string()))?; self.expect_keyword("WHEN")?;
} }
Ok(ASTNode::SQLCase { Ok(ASTNode::SQLCase {
conditions, conditions,
@ -234,11 +231,11 @@ impl Parser {
/// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)`
pub fn parse_cast_expression(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_cast_expression(&mut self) -> Result<ASTNode, ParserError> {
self.consume_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
let expr = self.parse_expr(0)?; let expr = self.parse_expr(0)?;
self.consume_token(&Token::Keyword("AS".to_string()))?; self.expect_keyword("AS")?;
let data_type = self.parse_data_type()?; let data_type = self.parse_data_type()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(ASTNode::SQLCast { Ok(ASTNode::SQLCast {
expr: Box::new(expr), expr: Box::new(expr),
data_type, data_type,
@ -247,7 +244,6 @@ impl Parser {
/// Parse a postgresql casting style which is in the form of `expr::datatype` /// Parse a postgresql casting style which is in the form of `expr::datatype`
pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> { pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
let _ = self.consume_token(&Token::DoubleColon)?;
Ok(ASTNode::SQLCast { Ok(ASTNode::SQLCast {
expr: Box::new(expr), expr: Box::new(expr),
data_type: self.parse_data_type()?, data_type: self.parse_data_type()?,
@ -449,6 +445,7 @@ impl Parser {
} }
/// Look for an expected keyword and consume it if it exists /// Look for an expected keyword and consume it if it exists
#[must_use]
pub fn parse_keyword(&mut self, expected: &'static str) -> bool { pub fn parse_keyword(&mut self, expected: &'static str) -> bool {
match self.peek_token() { match self.peek_token() {
Some(Token::Keyword(k)) => { Some(Token::Keyword(k)) => {
@ -464,6 +461,7 @@ impl Parser {
} }
/// Look for an expected sequence of keywords and consume them if they exist /// Look for an expected sequence of keywords and consume them if they exist
#[must_use]
pub fn parse_keywords(&mut self, keywords: Vec<&'static str>) -> bool { pub fn parse_keywords(&mut self, keywords: Vec<&'static str>) -> bool {
let index = self.index; let index = self.index;
for keyword in keywords { for keyword in keywords {
@ -477,6 +475,7 @@ impl Parser {
true true
} }
/// Bail out if the current token is not an expected keyword, or consume it if it is
pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> { pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> {
if self.parse_keyword(expected) { if self.parse_keyword(expected) {
Ok(()) Ok(())
@ -489,20 +488,32 @@ impl Parser {
} }
} }
//TODO: this function is inconsistent and sometimes returns bool and sometimes fails /// Consume the next token if it matches the expected token, otherwise return false
#[must_use]
/// Consume the next token if it matches the expected token, otherwise return an error pub fn consume_token(&mut self, expected: &Token) -> bool {
pub fn consume_token(&mut self, expected: &Token) -> Result<bool, ParserError> {
match self.peek_token() { match self.peek_token() {
Some(ref t) => { Some(ref t) => {
if *t == *expected { if *t == *expected {
self.next_token(); self.next_token();
Ok(true) true
} else { } else {
Ok(false) false
} }
} }
other => parser_err!(format!("expected token {:?} but was {:?}", expected, other,)), _ => false,
}
}
/// Bail out if the current token is not the expected token, or consume it if it is
///
/// Delegates the match-and-consume step to `consume_token`. When that returns
/// `false`, the error message reports the expected token and the result of
/// `peek_token()`.
pub fn expect_token(&mut self, expected: &Token) -> Result<(), ParserError> {
if self.consume_token(expected) {
Ok(())
} else {
parser_err!(format!(
"Expected token {:?}, found {:?}",
expected,
self.peek_token()
))
} }
} }
@ -512,7 +523,7 @@ impl Parser {
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
// parse optional column list (schema) // parse optional column list (schema)
let mut columns = vec![]; let mut columns = vec![];
if self.consume_token(&Token::LParen)? { if self.consume_token(&Token::LParen) {
loop { loop {
if let Some(Token::Identifier(column_name)) = self.next_token() { if let Some(Token::Identifier(column_name)) = self.next_token() {
if let Ok(data_type) = self.parse_data_type() { if let Ok(data_type) = self.parse_data_type() {
@ -590,9 +601,9 @@ impl Parser {
let is_primary_key = self.parse_keywords(vec!["PRIMARY", "KEY"]); let is_primary_key = self.parse_keywords(vec!["PRIMARY", "KEY"]);
let is_unique_key = self.parse_keywords(vec!["UNIQUE", "KEY"]); let is_unique_key = self.parse_keywords(vec!["UNIQUE", "KEY"]);
let is_foreign_key = self.parse_keywords(vec!["FOREIGN", "KEY"]); let is_foreign_key = self.parse_keywords(vec!["FOREIGN", "KEY"]);
self.consume_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
let column_names = self.parse_column_names()?; let column_names = self.parse_column_names()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
let key = Key { let key = Key {
name: constraint_name.to_string(), name: constraint_name.to_string(),
columns: column_names, columns: column_names,
@ -604,9 +615,9 @@ impl Parser {
} else if is_foreign_key { } else if is_foreign_key {
if self.parse_keyword("REFERENCES") { if self.parse_keyword("REFERENCES") {
let foreign_table = self.parse_tablename()?; let foreign_table = self.parse_tablename()?;
self.consume_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
let referred_columns = self.parse_column_names()?; let referred_columns = self.parse_column_names()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(TableKey::ForeignKey { Ok(TableKey::ForeignKey {
key, key,
foreign_table, foreign_table,
@ -662,16 +673,16 @@ impl Parser {
/// Parse a copy statement /// Parse a copy statement
pub fn parse_copy(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_copy(&mut self) -> Result<ASTNode, ParserError> {
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
let columns = if self.consume_token(&Token::LParen)? { let columns = if self.consume_token(&Token::LParen) {
let column_names = self.parse_column_names()?; let column_names = self.parse_column_names()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
column_names column_names
} else { } else {
vec![] vec![]
}; };
self.parse_keyword("FROM"); self.expect_keyword("FROM")?;
self.parse_keyword("STDIN"); self.expect_keyword("STDIN")?;
self.consume_token(&Token::SemiColon)?; self.expect_token(&Token::SemiColon)?;
let values = self.parse_tsv()?; let values = self.parse_tsv()?;
Ok(ASTNode::SQLCopy { Ok(ASTNode::SQLCopy {
table_name, table_name,
@ -705,7 +716,7 @@ impl Parser {
content.clear(); content.clear();
} }
Token::Backslash => { Token::Backslash => {
if let Ok(true) = self.consume_token(&Token::Period) { if self.consume_token(&Token::Period) {
return Ok(values); return Ok(values);
} }
if let Some(token) = self.next_token() { if let Some(token) = self.next_token() {
@ -830,9 +841,9 @@ impl Parser {
} }
pub fn parse_date(&mut self, year: i64) -> Result<NaiveDate, ParserError> { pub fn parse_date(&mut self, year: i64) -> Result<NaiveDate, ParserError> {
if let Ok(true) = self.consume_token(&Token::Minus) { if self.consume_token(&Token::Minus) {
let month = self.parse_literal_int()?; let month = self.parse_literal_int()?;
if let Ok(true) = self.consume_token(&Token::Minus) { if self.consume_token(&Token::Minus) {
let day = self.parse_literal_int()?; let day = self.parse_literal_int()?;
let date = NaiveDate::from_ymd(year as i32, month as u32, day as u32); let date = NaiveDate::from_ymd(year as i32, month as u32, day as u32);
Ok(date) Ok(date)
@ -852,9 +863,9 @@ impl Parser {
pub fn parse_time(&mut self) -> Result<NaiveTime, ParserError> { pub fn parse_time(&mut self) -> Result<NaiveTime, ParserError> {
let hour = self.parse_literal_int()?; let hour = self.parse_literal_int()?;
self.consume_token(&Token::Colon)?; self.expect_token(&Token::Colon)?;
let min = self.parse_literal_int()?; let min = self.parse_literal_int()?;
self.consume_token(&Token::Colon)?; self.expect_token(&Token::Colon)?;
// On one hand, the SQL specs defines <seconds fraction> ::= <unsigned integer>, // On one hand, the SQL specs defines <seconds fraction> ::= <unsigned integer>,
// so it would be more correct to parse it as such // so it would be more correct to parse it as such
let sec = self.parse_literal_double()?; let sec = self.parse_literal_double()?;
@ -943,8 +954,9 @@ impl Parser {
} }
"REGCLASS" => Ok(SQLType::Regclass), "REGCLASS" => Ok(SQLType::Regclass),
"TEXT" => { "TEXT" => {
if let Ok(true) = self.consume_token(&Token::LBracket) { if self.consume_token(&Token::LBracket) {
self.consume_token(&Token::RBracket)?; // Note: this is postgresql-specific
self.expect_token(&Token::RBracket)?;
Ok(SQLType::Array(Box::new(SQLType::Text))) Ok(SQLType::Array(Box::new(SQLType::Text)))
} else { } else {
Ok(SQLType::Text) Ok(SQLType::Text)
@ -1026,10 +1038,10 @@ impl Parser {
} }
pub fn parse_optional_precision(&mut self) -> Result<Option<usize>, ParserError> { pub fn parse_optional_precision(&mut self) -> Result<Option<usize>, ParserError> {
if self.consume_token(&Token::LParen)? { if self.consume_token(&Token::LParen) {
let n = self.parse_literal_int()?; let n = self.parse_literal_int()?;
//TODO: check return value of reading rparen //TODO: check return value of reading rparen
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(Some(n as usize)) Ok(Some(n as usize))
} else { } else {
Ok(None) Ok(None)
@ -1039,14 +1051,14 @@ impl Parser {
pub fn parse_optional_precision_scale( pub fn parse_optional_precision_scale(
&mut self, &mut self,
) -> Result<(usize, Option<usize>), ParserError> { ) -> Result<(usize, Option<usize>), ParserError> {
if self.consume_token(&Token::LParen)? { if self.consume_token(&Token::LParen) {
let n = self.parse_literal_int()?; let n = self.parse_literal_int()?;
let scale = if let Ok(true) = self.consume_token(&Token::Comma) { let scale = if self.consume_token(&Token::Comma) {
Some(self.parse_literal_int()? as usize) Some(self.parse_literal_int()? as usize)
} else { } else {
None None
}; };
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok((n as usize, scale)) Ok((n as usize, scale))
} else { } else {
parser_err!("Expecting `(`") parser_err!("Expecting `(`")
@ -1153,7 +1165,7 @@ impl Parser {
let constraint = self.parse_expr(0)?; let constraint = self.parse_expr(0)?;
Ok(JoinConstraint::On(constraint)) Ok(JoinConstraint::On(constraint))
} else if self.parse_keyword("USING") { } else if self.parse_keyword("USING") {
if self.consume_token(&Token::LParen)? { if self.consume_token(&Token::LParen) {
let attributes = self let attributes = self
.parse_expr_list()? .parse_expr_list()?
.into_iter() .into_iter()
@ -1165,7 +1177,7 @@ impl Parser {
}) })
.collect::<Result<Vec<String>, ParserError>>()?; .collect::<Result<Vec<String>, ParserError>>()?;
if self.consume_token(&Token::RParen)? { if self.consume_token(&Token::RParen) {
Ok(JoinConstraint::Using(attributes)) Ok(JoinConstraint::Using(attributes))
} else { } else {
parser_err!(format!("Expected token ')', found {:?}", self.peek_token())) parser_err!(format!("Expected token ')', found {:?}", self.peek_token()))
@ -1232,7 +1244,7 @@ impl Parser {
} }
Some(Token::Keyword(kw)) if kw == "LEFT" => { Some(Token::Keyword(kw)) if kw == "LEFT" => {
self.next_token(); self.next_token();
self.parse_keyword("OUTER"); let _ = self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?; self.expect_keyword("JOIN")?;
Join { Join {
relation: self.parse_expr(0)?, relation: self.parse_expr(0)?,
@ -1243,7 +1255,7 @@ impl Parser {
} }
Some(Token::Keyword(kw)) if kw == "RIGHT" => { Some(Token::Keyword(kw)) if kw == "RIGHT" => {
self.next_token(); self.next_token();
self.parse_keyword("OUTER"); let _ = self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?; self.expect_keyword("JOIN")?;
Join { Join {
relation: self.parse_expr(0)?, relation: self.parse_expr(0)?,
@ -1254,7 +1266,7 @@ impl Parser {
} }
Some(Token::Keyword(kw)) if kw == "FULL" => { Some(Token::Keyword(kw)) if kw == "FULL" => {
self.next_token(); self.next_token();
self.parse_keyword("OUTER"); let _ = self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?; self.expect_keyword("JOIN")?;
Join { Join {
relation: self.parse_expr(0)?, relation: self.parse_expr(0)?,
@ -1273,19 +1285,19 @@ impl Parser {
/// Parse an INSERT statement /// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> {
self.parse_keyword("INTO"); self.expect_keyword("INTO")?;
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
let columns = if self.consume_token(&Token::LParen)? { let columns = if self.consume_token(&Token::LParen) {
let column_names = self.parse_column_names()?; let column_names = self.parse_column_names()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
column_names column_names
} else { } else {
vec![] vec![]
}; };
self.parse_keyword("VALUES"); self.expect_keyword("VALUES")?;
self.consume_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
let values = self.parse_expr_list()?; let values = self.parse_expr_list()?;
self.consume_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(ASTNode::SQLInsert { Ok(ASTNode::SQLInsert {
table_name, table_name,
columns, columns,

View file

@ -221,6 +221,16 @@ fn parse_insert_with_columns() {
} }
} }
/// An INSERT statement that omits the `INTO` keyword must be rejected:
/// the parser is expected to return an error rather than an AST.
#[test]
fn parse_insert_invalid() {
    let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)");
    let mut parser = parser(&sql);
    // Assert on the result directly so a failure carries a message,
    // instead of tripping a bare `assert!(false)`.
    assert!(
        parser.parse().is_err(),
        "INSERT without INTO should fail to parse"
    );
}
#[test] #[test]
fn parse_select_wildcard() { fn parse_select_wildcard() {
let sql = String::from("SELECT * FROM customer"); let sql = String::from("SELECT * FROM customer");