Support parsing of multiple statements (5/5)

Parser::parse_sql() can now parse a semicolon-separated list of
statements, returning them in a Vec<SQLStatement>.

To support this we:

  - Move handling of inter-statement tokens from the end of individual
    statement parsers (`parse_select` and `parse_delete`; this was not
    implemented for other top-level statements) to the common
    statement-list parsing code (`parse_sql`);

  - Replace the "Unexpected token at end of ..." error, which had no
    tests and prevented parsing of successive statements, with
    "Expected end of statement", i.e. after each statement we now expect
    a delimiter (currently only ";") or EOF;

  - Derive PartialEq on ParserError so that tests can assert_eq!() that
    parsing improperly terminated statements returns the expected error
    (see the usage sketch after this list).
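
A minimal usage sketch of the new behavior (illustration only, not part of
the commit; the module paths in the `use` lines are assumptions and may not
match the crate layout at this revision):

    // Hypothetical example: parse two statements, then check the new error.
    use sqlparser::dialect::GenericSqlDialect; // assumed path
    use sqlparser::sqlparser::{Parser, ParserError}; // assumed path

    fn main() {
        // Two statements separated by ";" now come back as a Vec of two ASTs.
        let stmts = Parser::parse_sql(
            &GenericSqlDialect {},
            "SELECT foo; DELETE FROM bar".to_string(),
        )
        .unwrap();
        assert_eq!(stmts.len(), 2);

        // A missing delimiter is reported uniformly for all statement kinds,
        // replacing the old per-statement "Unexpected token at end of ..." error.
        let err = Parser::parse_sql(
            &GenericSqlDialect {},
            "SELECT foo SELECT bar".to_string(),
        )
        .unwrap_err();
        assert_eq!(
            err,
            ParserError::ParserError("Expected end of statement, found: SELECT".to_string())
        );
    }
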
Author: Nickolay Ponomarev
Date:   2019-01-30 00:58:55 +03:00
Parent: 5a0e0ec928
Commit: 707c58ad57
3 changed files with 91 additions and 71 deletions


@@ -20,7 +20,7 @@ use super::sqlast::*;
 use super::sqltokenizer::*;
 use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime};
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub enum ParserError {
     TokenizerError(String),
     ParserError(String),
@@ -54,14 +54,36 @@ impl Parser {
     }
 
     /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
-    pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<SQLStatement, ParserError> {
+    pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<Vec<SQLStatement>, ParserError> {
         let mut tokenizer = Tokenizer::new(dialect, &sql);
         let tokens = tokenizer.tokenize()?;
         let mut parser = Parser::new(tokens);
-        parser.parse_statement()
+        let mut stmts = Vec::new();
+        let mut expecting_statement_delimiter = false;
+        loop {
+            // ignore empty statements (between successive statement delimiters)
+            while parser.consume_token(&Token::SemiColon) {
+                expecting_statement_delimiter = false;
+            }
+
+            if parser.peek_token().is_none() {
+                break;
+            } else if expecting_statement_delimiter {
+                return parser_err!(format!(
+                    "Expected end of statement, found: {}",
+                    parser.peek_token().unwrap().to_string()
+                ));
+            }
+
+            let statement = parser.parse_statement()?;
+            stmts.push(statement);
+            expecting_statement_delimiter = true;
+        }
+        Ok(stmts)
     }
 
-    /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.)
+    /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
+    /// stopping before the statement separator, if any.
     pub fn parse_statement(&mut self) -> Result<SQLStatement, ParserError> {
         match self.next_token() {
             Some(t) => match t {
@@ -1095,20 +1117,10 @@ impl Parser {
             None
         };
 
-        let _ = self.consume_token(&Token::SemiColon);
-
-        // parse next token
-        if let Some(next_token) = self.peek_token() {
-            parser_err!(format!(
-                "Unexpected token at end of DELETE: {:?}",
-                next_token
-            ))
-        } else {
-            Ok(SQLStatement::SQLDelete {
-                relation,
-                selection,
-            })
-        }
+        Ok(SQLStatement::SQLDelete {
+            relation,
+            selection,
+        })
     }
 
     /// Parse a SELECT statement
@@ -1154,25 +1166,16 @@ impl Parser {
             None
         };
 
-        let _ = self.consume_token(&Token::SemiColon);
-
-        if let Some(next_token) = self.peek_token() {
-            parser_err!(format!(
-                "Unexpected token at end of SELECT: {:?}",
-                next_token
-            ))
-        } else {
-            Ok(SQLSelect {
-                projection,
-                selection,
-                relation,
-                joins,
-                limit,
-                order_by,
-                group_by,
-                having,
-            })
-        }
+        Ok(SQLSelect {
+            projection,
+            selection,
+            relation,
+            joins,
+            limit,
+            order_by,
+            group_by,
+            having,
+        })
     }
 
     /// A table name or a parenthesized subquery, followed by optional `[AS] alias`


@@ -473,34 +473,6 @@ fn parse_case_expression() {
     );
 }
 
-#[test]
-fn parse_select_with_semi_colon() {
-    let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1;");
-    match one_statement_parses_to(&sql, "") {
-        SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
-            assert_eq!(3, projection.len());
-        }
-        _ => assert!(false),
-    }
-}
-
-#[test]
-fn parse_delete_with_semi_colon() {
-    let sql: &str = "DELETE FROM 'table';";
-    match one_statement_parses_to(&sql, "") {
-        SQLStatement::SQLDelete { relation, .. } => {
-            assert_eq!(
-                Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
-                    "table".to_string()
-                )))),
-                relation
-            );
-        }
-        _ => assert!(false),
-    }
-}
-
 #[test]
 fn parse_implicit_join() {
     let sql = "SELECT * FROM t1, t2";
@@ -669,6 +641,37 @@ fn parse_join_syntax_variants() {
     );
 }
 
+#[test]
+fn parse_multiple_statements() {
+    fn test_with(sql1: &str, sql2_kw: &str, sql2_rest: &str) {
+        // Check that a string consisting of two statements delimited by a semicolon
+        // parses the same as both statements individually:
+        let res = parse_sql_statements(&(sql1.to_owned() + ";" + sql2_kw + sql2_rest));
+        assert_eq!(
+            vec![
+                one_statement_parses_to(&sql1, ""),
+                one_statement_parses_to(&(sql2_kw.to_owned() + sql2_rest), ""),
+            ],
+            res.unwrap()
+        );
+        // Check that extra semicolon at the end is stripped by normalization:
+        one_statement_parses_to(&(sql1.to_owned() + ";"), sql1);
+        // Check that forgetting the semicolon results in an error:
+        let res = parse_sql_statements(&(sql1.to_owned() + " " + sql2_kw + sql2_rest));
+        assert_eq!(
+            ParserError::ParserError("Expected end of statement, found: ".to_string() + sql2_kw),
+            res.unwrap_err()
+        );
+    }
+    test_with("SELECT foo", "SELECT", " bar");
+    test_with("DELETE FROM foo", "SELECT", " bar");
+    test_with("INSERT INTO foo VALUES(1)", "SELECT", " bar");
+    test_with("CREATE TABLE foo (baz int)", "SELECT", " bar");
+    // Make sure that empty statements do not cause an error:
+    let res = parse_sql_statements(";;");
+    assert_eq!(0, res.unwrap().len());
+}
+
 fn only<'a, T>(v: &'a Vec<T>) -> &'a T {
     assert_eq!(1, v.len());
     v.first().unwrap()
@@ -699,17 +702,24 @@ fn verified_expr(query: &str) -> ASTNode {
     ast
 }
 
-/// Ensures that `sql` parses as a statement, optionally checking that
+/// Ensures that `sql` parses as a single statement, optionally checking that
 /// converting AST back to string equals to `canonical` (unless an empty string
 /// is provided).
 fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
-    let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string()).unwrap();
-    let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
-    assert_eq!(generic_ast, pg_ast);
+    let mut statements = parse_sql_statements(&sql).unwrap();
+    assert_eq!(statements.len(), 1);
+    let only_statement = statements.pop().unwrap();
     if !canonical.is_empty() {
-        assert_eq!(canonical, generic_ast.to_string())
+        assert_eq!(canonical, only_statement.to_string())
     }
-    generic_ast
+    only_statement
 }
 
+fn parse_sql_statements(sql: &str) -> Result<Vec<SQLStatement>, ParserError> {
+    let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string());
+    let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string());
+    assert_eq!(generic_ast, pg_ast);
+    generic_ast
+}


@@ -372,13 +372,20 @@ fn verified_stmt(query: &str) -> SQLStatement {
 /// converting AST back to string equals to `canonical` (unless an empty string
 /// is provided).
 fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
-    let only_statement = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
+    let mut statements = parse_sql_statements(&sql).unwrap();
+    assert_eq!(statements.len(), 1);
+    let only_statement = statements.pop().unwrap();
     if !canonical.is_empty() {
         assert_eq!(canonical, only_statement.to_string())
     }
     only_statement
 }
 
+fn parse_sql_statements(sql: &str) -> Result<Vec<SQLStatement>, ParserError> {
+    Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string())
+}
+
 fn parse_sql_expr(sql: &str) -> ASTNode {
     debug!("sql: {}", sql);
     let mut parser = parser(sql);