Separate statement from expr parsing (4/5)

Continuing from https://github.com/andygrove/sqlparser-rs/pull/33#issuecomment-453060427

This stops the parser from accepting (and the AST from being able to
represent) SQL look-alike code that makes no sense, e.g.

    SELECT ... FROM (CREATE TABLE ...) foo
    SELECT ... FROM (1+CAST(...)) foo

Generally this makes the AST less "partially typed": meaning certain
parts are strongly typed (e.g. SELECT can only contain projections,
relations, etc.), while everything that didn't get its own type is
dumped into ASTNode, effectively untyped. After a few more fixes (yet
to be implemented), `ASTNode` could become an `SQLExpression`. The
Pratt-style expression parser (returning an SQLExpression) would be
invoked from the top-down parser in places where a generic expression
is expected (e.g. after SELECT <...>, WHERE <...>, etc.), while things
like select's `projection` and `relation` could be more appropriately
(narrowly) typed.


Since the diff is quite large due to necessarily large number of
mechanical changes, here's an overview:

1) Interface changes:

   - A new AST enum - `SQLStatement` - is split out of ASTNode:

     - The variants of the ASTNode enum, which _only_ make sense as a top
       level statement (INSERT, UPDATE, DELETE, CREATE, ALTER, COPY) are
       _moved_ to the new enum, with no other changes.
     - SQLSelect is _duplicated_: now available both as a variant in
       SQLStatement::SQLSelect (top-level SELECT) and ASTNode:: (subquery).

   - The main entry point (Parser::parse_sql) now expects an SQL statement
     as input, and returns an `SQLStatement`.

2) Parser changes: instead of detecting the top-level constructs deep
down in the precedence parser (`parse_prefix`) we are able to do it
just right after setting up the parser in the `parse_sql` entry point

(SELECT, again, is kept in the expression parser to demonstrate how
subqueries could be implemented).

The rest of parser changes are mechanical ASTNode -> SQLStatement
replacements resulting from the AST change.

3) Testing changes: for every test - depending on whether the input was
a complete statement or an expresssion -  I used an appropriate helper
function:

   - `verified` (parses SQL, checks that it round-trips, and returns
     the AST) - was replaced by `verified_stmt` or `verified_expr`.

   - `parse_sql` (which returned AST without checking it round-tripped)
     was replaced by:

     - `parse_sql_expr` (same function, for expressions)

     - `one_statement_parses_to` (formerly `parses_to`), extended to
       deal with statements that are not expected to round-trip.
       The weird name is to reduce further churn when implementing
       multi-statement parsing.

     - `verified_stmt` (in 4 testcases that actually round-tripped)
This commit is contained in:
Nickolay Ponomarev 2019-01-31 04:56:20 +03:00
parent 7b86f5c842
commit 2dec65fdb4
5 changed files with 245 additions and 193 deletions

View file

@ -79,52 +79,6 @@ pub enum ASTNode {
}, },
/// SELECT /// SELECT
SQLSelect(SQLSelect), SQLSelect(SQLSelect),
/// INSERT
SQLInsert {
/// TABLE
table_name: String,
/// COLUMNS
columns: Vec<SQLIdent>,
/// VALUES (vector of rows to insert)
values: Vec<Vec<ASTNode>>,
},
SQLCopy {
/// TABLE
table_name: String,
/// COLUMNS
columns: Vec<SQLIdent>,
/// VALUES a vector of values to be copied
values: Vec<Option<String>>,
},
/// UPDATE
SQLUpdate {
/// TABLE
table_name: String,
/// Column assignments
assignments: Vec<SQLAssignment>,
/// WHERE
selection: Option<Box<ASTNode>>,
},
/// DELETE
SQLDelete {
/// FROM
relation: Option<Box<ASTNode>>,
/// WHERE
selection: Option<Box<ASTNode>>,
},
/// CREATE TABLE
SQLCreateTable {
/// Table name
name: String,
/// Optional schema
columns: Vec<SQLColumnDef>,
},
/// ALTER TABLE
SQLAlterTable {
/// Table name
name: String,
operation: AlterOperation,
},
} }
impl ToString for ASTNode { impl ToString for ASTNode {
@ -186,7 +140,68 @@ impl ToString for ASTNode {
} }
} }
ASTNode::SQLSelect(s) => s.to_string(), ASTNode::SQLSelect(s) => s.to_string(),
ASTNode::SQLInsert { }
}
}
/// A top-level statement (SELECT, INSERT, CREATE, etc.)
#[derive(Debug, Clone, PartialEq)]
pub enum SQLStatement {
/// SELECT
SQLSelect(SQLSelect),
/// INSERT
SQLInsert {
/// TABLE
table_name: String,
/// COLUMNS
columns: Vec<SQLIdent>,
/// VALUES (vector of rows to insert)
values: Vec<Vec<ASTNode>>,
},
SQLCopy {
/// TABLE
table_name: String,
/// COLUMNS
columns: Vec<SQLIdent>,
/// VALUES a vector of values to be copied
values: Vec<Option<String>>,
},
/// UPDATE
SQLUpdate {
/// TABLE
table_name: String,
/// Column assignments
assignments: Vec<SQLAssignment>,
/// WHERE
selection: Option<Box<ASTNode>>,
},
/// DELETE
SQLDelete {
/// FROM
relation: Option<Box<ASTNode>>,
/// WHERE
selection: Option<Box<ASTNode>>,
},
/// CREATE TABLE
SQLCreateTable {
/// Table name
name: String,
/// Optional schema
columns: Vec<SQLColumnDef>,
},
/// ALTER TABLE
SQLAlterTable {
/// Table name
name: String,
operation: AlterOperation,
},
}
impl ToString for SQLStatement {
fn to_string(&self) -> String {
match self {
SQLStatement::SQLSelect(s) => s.to_string(),
SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values, values,
@ -211,7 +226,7 @@ impl ToString for ASTNode {
} }
s s
} }
ASTNode::SQLCopy { SQLStatement::SQLCopy {
table_name, table_name,
columns, columns,
values, values,
@ -241,7 +256,7 @@ impl ToString for ASTNode {
s += "\n\\."; s += "\n\\.";
s s
} }
ASTNode::SQLUpdate { SQLStatement::SQLUpdate {
table_name, table_name,
assignments, assignments,
selection, selection,
@ -262,7 +277,7 @@ impl ToString for ASTNode {
} }
s s
} }
ASTNode::SQLDelete { SQLStatement::SQLDelete {
relation, relation,
selection, selection,
} => { } => {
@ -275,7 +290,7 @@ impl ToString for ASTNode {
} }
s s
} }
ASTNode::SQLCreateTable { name, columns } => format!( SQLStatement::SQLCreateTable { name, columns } => format!(
"CREATE TABLE {} ({})", "CREATE TABLE {} ({})",
name, name,
columns columns
@ -284,7 +299,7 @@ impl ToString for ASTNode {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", ") .join(", ")
), ),
ASTNode::SQLAlterTable { name, operation } => { SQLStatement::SQLAlterTable { name, operation } => {
format!("ALTER TABLE {} {}", name, operation.to_string()) format!("ALTER TABLE {} {}", name, operation.to_string())
} }
} }

View file

@ -54,11 +54,36 @@ impl Parser {
} }
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST) /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<ASTNode, ParserError> { pub fn parse_sql(dialect: &Dialect, sql: String) -> Result<SQLStatement, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, &sql); let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize()?; let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens); let mut parser = Parser::new(tokens);
parser.parse() parser.parse_statement()
}
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.)
pub fn parse_statement(&mut self) -> Result<SQLStatement, ParserError> {
match self.next_token() {
Some(t) => match t {
Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() {
"SELECT" => Ok(SQLStatement::SQLSelect(self.parse_select()?)),
"CREATE" => Ok(self.parse_create()?),
"DELETE" => Ok(self.parse_delete()?),
"INSERT" => Ok(self.parse_insert()?),
"ALTER" => Ok(self.parse_alter()?),
"COPY" => Ok(self.parse_copy()?),
_ => parser_err!(format!(
"Unexpected keyword {:?} at the beginning of a statement",
w.to_string()
)),
},
unexpected => parser_err!(format!(
"Unexpected {:?} at the beginning of a statement",
unexpected
)),
},
_ => parser_err!("Unexpected end of file"),
}
} }
/// Parse a new expression /// Parse a new expression
@ -111,12 +136,7 @@ impl Parser {
match self.next_token() { match self.next_token() {
Some(t) => match t { Some(t) => match t {
Token::SQLWord(w) => match w.keyword.as_ref() { Token::SQLWord(w) => match w.keyword.as_ref() {
"SELECT" => Ok(self.parse_select()?), "SELECT" => Ok(ASTNode::SQLSelect(self.parse_select()?)),
"CREATE" => Ok(self.parse_create()?),
"DELETE" => Ok(self.parse_delete()?),
"INSERT" => Ok(self.parse_insert()?),
"ALTER" => Ok(self.parse_alter()?),
"COPY" => Ok(self.parse_copy()?),
"TRUE" | "FALSE" | "NULL" => { "TRUE" | "FALSE" | "NULL" => {
self.prev_token(); self.prev_token();
self.parse_sql_value() self.parse_sql_value()
@ -495,7 +515,7 @@ impl Parser {
} }
/// Parse a SQL CREATE statement /// Parse a SQL CREATE statement
pub fn parse_create(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_create(&mut self) -> Result<SQLStatement, ParserError> {
if self.parse_keywords(vec!["TABLE"]) { if self.parse_keywords(vec!["TABLE"]) {
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
// parse optional column list (schema) // parse optional column list (schema)
@ -562,7 +582,7 @@ impl Parser {
} }
} }
} }
Ok(ASTNode::SQLCreateTable { Ok(SQLStatement::SQLCreateTable {
name: table_name, name: table_name,
columns, columns,
}) })
@ -608,7 +628,7 @@ impl Parser {
} }
} }
pub fn parse_alter(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_alter(&mut self) -> Result<SQLStatement, ParserError> {
self.expect_keyword("TABLE")?; self.expect_keyword("TABLE")?;
let _ = self.parse_keyword("ONLY"); let _ = self.parse_keyword("ONLY");
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
@ -632,14 +652,14 @@ impl Parser {
self.peek_token() self.peek_token()
)); ));
}; };
Ok(ASTNode::SQLAlterTable { Ok(SQLStatement::SQLAlterTable {
name: table_name, name: table_name,
operation: operation?, operation: operation?,
}) })
} }
/// Parse a copy statement /// Parse a copy statement
pub fn parse_copy(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_copy(&mut self) -> Result<SQLStatement, ParserError> {
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
let columns = if self.consume_token(&Token::LParen) { let columns = if self.consume_token(&Token::LParen) {
let column_names = self.parse_column_names()?; let column_names = self.parse_column_names()?;
@ -652,7 +672,7 @@ impl Parser {
self.expect_keyword("STDIN")?; self.expect_keyword("STDIN")?;
self.expect_token(&Token::SemiColon)?; self.expect_token(&Token::SemiColon)?;
let values = self.parse_tsv()?; let values = self.parse_tsv()?;
Ok(ASTNode::SQLCopy { Ok(SQLStatement::SQLCopy {
table_name, table_name,
columns, columns,
values, values,
@ -1062,7 +1082,7 @@ impl Parser {
} }
} }
pub fn parse_delete(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_delete(&mut self) -> Result<SQLStatement, ParserError> {
let relation: Option<Box<ASTNode>> = if self.parse_keyword("FROM") { let relation: Option<Box<ASTNode>> = if self.parse_keyword("FROM") {
Some(Box::new(self.parse_expr(0)?)) Some(Box::new(self.parse_expr(0)?))
} else { } else {
@ -1084,7 +1104,7 @@ impl Parser {
next_token next_token
)) ))
} else { } else {
Ok(ASTNode::SQLDelete { Ok(SQLStatement::SQLDelete {
relation, relation,
selection, selection,
}) })
@ -1092,7 +1112,7 @@ impl Parser {
} }
/// Parse a SELECT statement /// Parse a SELECT statement
pub fn parse_select(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_select(&mut self) -> Result<SQLSelect, ParserError> {
let projection = self.parse_expr_list()?; let projection = self.parse_expr_list()?;
let (relation, joins): (Option<Box<ASTNode>>, Vec<Join>) = if self.parse_keyword("FROM") { let (relation, joins): (Option<Box<ASTNode>>, Vec<Join>) = if self.parse_keyword("FROM") {
@ -1142,7 +1162,7 @@ impl Parser {
next_token next_token
)) ))
} else { } else {
Ok(ASTNode::SQLSelect(SQLSelect { Ok(SQLSelect {
projection, projection,
selection, selection,
relation, relation,
@ -1151,7 +1171,7 @@ impl Parser {
order_by, order_by,
group_by, group_by,
having, having,
})) })
} }
} }
@ -1290,7 +1310,7 @@ impl Parser {
} }
/// Parse an INSERT statement /// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_insert(&mut self) -> Result<SQLStatement, ParserError> {
self.expect_keyword("INTO")?; self.expect_keyword("INTO")?;
let table_name = self.parse_tablename()?; let table_name = self.parse_tablename()?;
let columns = if self.consume_token(&Token::LParen) { let columns = if self.consume_token(&Token::LParen) {
@ -1304,7 +1324,7 @@ impl Parser {
self.expect_token(&Token::LParen)?; self.expect_token(&Token::LParen)?;
let values = self.parse_expr_list()?; let values = self.parse_expr_list()?;
self.expect_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
Ok(ASTNode::SQLInsert { Ok(SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values: vec![values], values: vec![values],

View file

@ -9,7 +9,7 @@ use sqlparser::sqltokenizer::*;
#[test] #[test]
fn parse_simple_select() { fn parse_simple_select() {
let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1"); let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1");
let ast = parse_sql(&sql); let ast = parse_sql_expr(&sql);
match ast { match ast {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { ASTNode::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(3, projection.len()); assert_eq!(3, projection.len());
@ -18,7 +18,7 @@ fn parse_simple_select() {
} }
} }
fn parse_sql(sql: &str) -> ASTNode { fn parse_sql_expr(sql: &str) -> ASTNode {
let dialect = AnsiSqlDialect {}; let dialect = AnsiSqlDialect {};
let mut tokenizer = Tokenizer::new(&dialect, &sql); let mut tokenizer = Tokenizer::new(&dialect, &sql);
let tokens = tokenizer.tokenize().unwrap(); let tokens = tokenizer.tokenize().unwrap();

View file

@ -10,8 +10,8 @@ use sqlparser::sqltokenizer::*;
fn parse_delete_statement() { fn parse_delete_statement() {
let sql: &str = "DELETE FROM 'table'"; let sql: &str = "DELETE FROM 'table'";
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLDelete { relation, .. } => { SQLStatement::SQLDelete { relation, .. } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string() "table".to_string()
@ -31,8 +31,8 @@ fn parse_where_delete_statement() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLOperator::*;
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLDelete { SQLStatement::SQLDelete {
relation, relation,
selection, selection,
.. ..
@ -61,8 +61,8 @@ fn parse_where_delete_statement() {
#[test] #[test]
fn parse_simple_select() { fn parse_simple_select() {
let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT 5"); let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT 5");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { SQLStatement::SQLSelect(SQLSelect {
projection, limit, .. projection, limit, ..
}) => { }) => {
assert_eq!(3, projection.len()); assert_eq!(3, projection.len());
@ -75,8 +75,8 @@ fn parse_simple_select() {
#[test] #[test]
fn parse_select_wildcard() { fn parse_select_wildcard() {
let sql = String::from("SELECT * FROM customer"); let sql = String::from("SELECT * FROM customer");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(1, projection.len()); assert_eq!(1, projection.len());
assert_eq!(ASTNode::SQLWildcard, projection[0]); assert_eq!(ASTNode::SQLWildcard, projection[0]);
} }
@ -87,8 +87,8 @@ fn parse_select_wildcard() {
#[test] #[test]
fn parse_select_count_wildcard() { fn parse_select_count_wildcard() {
let sql = String::from("SELECT COUNT(*) FROM customer"); let sql = String::from("SELECT COUNT(*) FROM customer");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(1, projection.len()); assert_eq!(1, projection.len());
assert_eq!( assert_eq!(
ASTNode::SQLFunction { ASTNode::SQLFunction {
@ -108,7 +108,7 @@ fn parse_not() {
"SELECT id FROM customer \ "SELECT id FROM customer \
WHERE NOT salary = ''", WHERE NOT salary = ''",
); );
let _ast = verified(&sql); let _ast = verified_stmt(&sql);
//TODO: add assertions //TODO: add assertions
} }
@ -118,14 +118,14 @@ fn parse_select_string_predicate() {
"SELECT id, fname, lname FROM customer \ "SELECT id, fname, lname FROM customer \
WHERE salary != 'Not Provided' AND salary != ''", WHERE salary != 'Not Provided' AND salary != ''",
); );
let _ast = verified(&sql); let _ast = verified_stmt(&sql);
//TODO: add assertions //TODO: add assertions
} }
#[test] #[test]
fn parse_projection_nested_type() { fn parse_projection_nested_type() {
let sql = String::from("SELECT customer.address.state FROM foo"); let sql = String::from("SELECT customer.address.state FROM foo");
let _ast = verified(&sql); let _ast = verified_stmt(&sql);
//TODO: add assertions //TODO: add assertions
} }
@ -144,7 +144,7 @@ fn parse_compound_expr_1() {
right: Box::new(SQLIdentifier("c".to_string())) right: Box::new(SQLIdentifier("c".to_string()))
}) })
}, },
verified(&sql) verified_expr(&sql)
); );
} }
@ -163,7 +163,7 @@ fn parse_compound_expr_2() {
op: Plus, op: Plus,
right: Box::new(SQLIdentifier("c".to_string())) right: Box::new(SQLIdentifier("c".to_string()))
}, },
verified(&sql) verified_expr(&sql)
); );
} }
@ -173,7 +173,7 @@ fn parse_is_null() {
let sql = String::from("a IS NULL"); let sql = String::from("a IS NULL");
assert_eq!( assert_eq!(
SQLIsNull(Box::new(SQLIdentifier("a".to_string()))), SQLIsNull(Box::new(SQLIdentifier("a".to_string()))),
verified(&sql) verified_expr(&sql)
); );
} }
@ -183,15 +183,15 @@ fn parse_is_not_null() {
let sql = String::from("a IS NOT NULL"); let sql = String::from("a IS NOT NULL");
assert_eq!( assert_eq!(
SQLIsNotNull(Box::new(SQLIdentifier("a".to_string()))), SQLIsNotNull(Box::new(SQLIdentifier("a".to_string()))),
verified(&sql) verified_expr(&sql)
); );
} }
#[test] #[test]
fn parse_like() { fn parse_like() {
let sql = String::from("SELECT * FROM customers WHERE name LIKE '%a'"); let sql = String::from("SELECT * FROM customers WHERE name LIKE '%a'");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { selection, .. }) => { SQLStatement::SQLSelect(SQLSelect { selection, .. }) => {
assert_eq!( assert_eq!(
ASTNode::SQLBinaryExpr { ASTNode::SQLBinaryExpr {
left: Box::new(ASTNode::SQLIdentifier("name".to_string())), left: Box::new(ASTNode::SQLIdentifier("name".to_string())),
@ -210,8 +210,8 @@ fn parse_like() {
#[test] #[test]
fn parse_not_like() { fn parse_not_like() {
let sql = String::from("SELECT * FROM customers WHERE name NOT LIKE '%a'"); let sql = String::from("SELECT * FROM customers WHERE name NOT LIKE '%a'");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { selection, .. }) => { SQLStatement::SQLSelect(SQLSelect { selection, .. }) => {
assert_eq!( assert_eq!(
ASTNode::SQLBinaryExpr { ASTNode::SQLBinaryExpr {
left: Box::new(ASTNode::SQLIdentifier("name".to_string())), left: Box::new(ASTNode::SQLIdentifier("name".to_string())),
@ -230,8 +230,8 @@ fn parse_not_like() {
#[test] #[test]
fn parse_select_order_by() { fn parse_select_order_by() {
fn chk(sql: &str) { fn chk(sql: &str) {
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { order_by, .. }) => { SQLStatement::SQLSelect(SQLSelect { order_by, .. }) => {
assert_eq!( assert_eq!(
Some(vec![ Some(vec![
SQLOrderByExpr { SQLOrderByExpr {
@ -263,9 +263,8 @@ fn parse_select_order_by_limit() {
let sql = String::from( let sql = String::from(
"SELECT id, fname, lname FROM customer WHERE id < 5 ORDER BY lname ASC, fname DESC LIMIT 2", "SELECT id, fname, lname FROM customer WHERE id < 5 ORDER BY lname ASC, fname DESC LIMIT 2",
); );
let ast = parse_sql(&sql); match verified_stmt(&sql) {
match ast { SQLStatement::SQLSelect(SQLSelect {
ASTNode::SQLSelect(SQLSelect {
order_by, limit, .. order_by, limit, ..
}) => { }) => {
assert_eq!( assert_eq!(
@ -290,8 +289,8 @@ fn parse_select_order_by_limit() {
#[test] #[test]
fn parse_select_group_by() { fn parse_select_group_by() {
let sql = String::from("SELECT id, fname, lname FROM customer GROUP BY lname, fname"); let sql = String::from("SELECT id, fname, lname FROM customer GROUP BY lname, fname");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { group_by, .. }) => { SQLStatement::SQLSelect(SQLSelect { group_by, .. }) => {
assert_eq!( assert_eq!(
Some(vec![ Some(vec![
ASTNode::SQLIdentifier("lname".to_string()), ASTNode::SQLIdentifier("lname".to_string()),
@ -306,7 +305,7 @@ fn parse_select_group_by() {
#[test] #[test]
fn parse_limit_accepts_all() { fn parse_limit_accepts_all() {
parses_to( one_statement_parses_to(
"SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT ALL", "SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT ALL",
"SELECT id, fname, lname FROM customer WHERE id = 1", "SELECT id, fname, lname FROM customer WHERE id = 1",
); );
@ -315,8 +314,8 @@ fn parse_limit_accepts_all() {
#[test] #[test]
fn parse_cast() { fn parse_cast() {
let sql = String::from("SELECT CAST(id AS bigint) FROM customer"); let sql = String::from("SELECT CAST(id AS bigint) FROM customer");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(1, projection.len()); assert_eq!(1, projection.len());
assert_eq!( assert_eq!(
ASTNode::SQLCast { ASTNode::SQLCast {
@ -328,7 +327,7 @@ fn parse_cast() {
} }
_ => assert!(false), _ => assert!(false),
} }
parses_to( one_statement_parses_to(
"SELECT CAST(id AS BIGINT) FROM customer", "SELECT CAST(id AS BIGINT) FROM customer",
"SELECT CAST(id AS bigint) FROM customer", "SELECT CAST(id AS bigint) FROM customer",
); );
@ -342,15 +341,15 @@ fn parse_create_table() {
lat DOUBLE NULL,\ lat DOUBLE NULL,\
lng DOUBLE NULL)", lng DOUBLE NULL)",
); );
parses_to( let ast = one_statement_parses_to(
&sql, &sql,
"CREATE TABLE uk_cities (\ "CREATE TABLE uk_cities (\
name character varying(100) NOT NULL, \ name character varying(100) NOT NULL, \
lat double, \ lat double, \
lng double)", lng double)",
); );
match parse_sql(&sql) { match ast {
ASTNode::SQLCreateTable { name, columns } => { SQLStatement::SQLCreateTable { name, columns } => {
assert_eq!("uk_cities", name); assert_eq!("uk_cities", name);
assert_eq!(3, columns.len()); assert_eq!(3, columns.len());
@ -376,8 +375,8 @@ fn parse_create_table() {
#[test] #[test]
fn parse_scalar_function_in_projection() { fn parse_scalar_function_in_projection() {
let sql = String::from("SELECT sqrt(id) FROM foo"); let sql = String::from("SELECT sqrt(id) FROM foo");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!( assert_eq!(
vec![ASTNode::SQLFunction { vec![ASTNode::SQLFunction {
id: String::from("sqrt"), id: String::from("sqrt"),
@ -393,15 +392,15 @@ fn parse_scalar_function_in_projection() {
#[test] #[test]
fn parse_aggregate_with_group_by() { fn parse_aggregate_with_group_by() {
let sql = String::from("SELECT a, COUNT(1), MIN(b), MAX(b) FROM foo GROUP BY a"); let sql = String::from("SELECT a, COUNT(1), MIN(b), MAX(b) FROM foo GROUP BY a");
let _ast = verified(&sql); let _ast = verified_stmt(&sql);
//TODO: assertions //TODO: assertions
} }
#[test] #[test]
fn parse_literal_string() { fn parse_literal_string() {
let sql = "SELECT 'one'"; let sql = "SELECT 'one'";
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { ref projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { ref projection, .. }) => {
assert_eq!( assert_eq!(
projection[0], projection[0],
ASTNode::SQLValue(Value::SingleQuotedString("one".to_string())) ASTNode::SQLValue(Value::SingleQuotedString("one".to_string()))
@ -414,20 +413,20 @@ fn parse_literal_string() {
#[test] #[test]
fn parse_simple_math_expr_plus() { fn parse_simple_math_expr_plus() {
let sql = "SELECT a + b, 2 + a, 2.5 + a, a_f + b_f, 2 + a_f, 2.5 + a_f FROM c"; let sql = "SELECT a + b, 2 + a, 2.5 + a, a_f + b_f, 2 + a_f, 2.5 + a_f FROM c";
parse_sql(&sql); verified_stmt(&sql);
} }
#[test] #[test]
fn parse_simple_math_expr_minus() { fn parse_simple_math_expr_minus() {
let sql = "SELECT a - b, 2 - a, 2.5 - a, a_f - b_f, 2 - a_f, 2.5 - a_f FROM c"; let sql = "SELECT a - b, 2 - a, 2.5 - a, a_f - b_f, 2 - a_f, 2.5 - a_f FROM c";
parse_sql(&sql); verified_stmt(&sql);
} }
#[test] #[test]
fn parse_select_version() { fn parse_select_version() {
let sql = "SELECT @@version"; let sql = "SELECT @@version";
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { ref projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { ref projection, .. }) => {
assert_eq!( assert_eq!(
projection[0], projection[0],
ASTNode::SQLIdentifier("@@version".to_string()) ASTNode::SQLIdentifier("@@version".to_string())
@ -442,7 +441,7 @@ fn parse_parens() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLOperator::*;
let sql = "(a + b) - (c + d)"; let sql = "(a + b) - (c + d)";
let ast = parse_sql(&sql); let ast = parse_sql_expr(&sql);
assert_eq!( assert_eq!(
SQLBinaryExpr { SQLBinaryExpr {
left: Box::new(SQLBinaryExpr { left: Box::new(SQLBinaryExpr {
@ -464,12 +463,10 @@ fn parse_parens() {
#[test] #[test]
fn parse_case_expression() { fn parse_case_expression() {
let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo";
let ast = parse_sql(&sql);
assert_eq!(sql, ast.to_string());
use self::ASTNode::{SQLBinaryExpr, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; use self::ASTNode::{SQLBinaryExpr, SQLCase, SQLIdentifier, SQLIsNull, SQLValue};
use self::SQLOperator::*; use self::SQLOperator::*;
match ast { match verified_stmt(&sql) {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(1, projection.len()); assert_eq!(1, projection.len());
assert_eq!( assert_eq!(
SQLCase { SQLCase {
@ -505,9 +502,8 @@ fn parse_case_expression() {
#[test] #[test]
fn parse_select_with_semi_colon() { fn parse_select_with_semi_colon() {
let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1;"); let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1;");
let ast = parse_sql(&sql); match one_statement_parses_to(&sql, "") {
match ast { SQLStatement::SQLSelect(SQLSelect { projection, .. }) => {
ASTNode::SQLSelect(SQLSelect { projection, .. }) => {
assert_eq!(3, projection.len()); assert_eq!(3, projection.len());
} }
_ => assert!(false), _ => assert!(false),
@ -518,8 +514,8 @@ fn parse_select_with_semi_colon() {
fn parse_delete_with_semi_colon() { fn parse_delete_with_semi_colon() {
let sql: &str = "DELETE FROM 'table';"; let sql: &str = "DELETE FROM 'table';";
match parse_sql(&sql) { match one_statement_parses_to(&sql, "") {
ASTNode::SQLDelete { relation, .. } => { SQLStatement::SQLDelete { relation, .. } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string() "table".to_string()
@ -535,8 +531,8 @@ fn parse_delete_with_semi_colon() {
fn parse_implicit_join() { fn parse_implicit_join() {
let sql = "SELECT * FROM t1, t2"; let sql = "SELECT * FROM t1, t2";
match verified(sql) { match verified_stmt(sql) {
ASTNode::SQLSelect(SQLSelect { joins, .. }) => { SQLStatement::SQLSelect(SQLSelect { joins, .. }) => {
assert_eq!(joins.len(), 1); assert_eq!(joins.len(), 1);
assert_eq!( assert_eq!(
joins[0], joins[0],
@ -557,8 +553,8 @@ fn parse_implicit_join() {
fn parse_cross_join() { fn parse_cross_join() {
let sql = "SELECT * FROM t1 CROSS JOIN t2"; let sql = "SELECT * FROM t1 CROSS JOIN t2";
match verified(sql) { match verified_stmt(sql) {
ASTNode::SQLSelect(SQLSelect { joins, .. }) => { SQLStatement::SQLSelect(SQLSelect { joins, .. }) => {
assert_eq!(joins.len(), 1); assert_eq!(joins.len(), 1);
assert_eq!( assert_eq!(
joins[0], joins[0],
@ -596,32 +592,32 @@ fn parse_joins_on() {
} }
// Test parsing of aliases // Test parsing of aliases
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2")), joins_from(verified_stmt("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2")),
vec![join_with_constraint( vec![join_with_constraint(
"t2", "t2",
Some("foo".to_string()), Some("foo".to_string()),
JoinOperator::Inner JoinOperator::Inner
)] )]
); );
parses_to( one_statement_parses_to(
"SELECT * FROM t1 JOIN t2 foo ON c1 = c2", "SELECT * FROM t1 JOIN t2 foo ON c1 = c2",
"SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2", "SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2",
); );
// Test parsing of different join operators // Test parsing of different join operators
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), joins_from(verified_stmt("SELECT * FROM t1 JOIN t2 ON c1 = c2")),
vec![join_with_constraint("t2", None, JoinOperator::Inner)] vec![join_with_constraint("t2", None, JoinOperator::Inner)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), joins_from(verified_stmt("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")),
vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), joins_from(verified_stmt("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")),
vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] vec![join_with_constraint("t2", None, JoinOperator::RightOuter)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), joins_from(verified_stmt("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")),
vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] vec![join_with_constraint("t2", None, JoinOperator::FullOuter)]
); );
} }
@ -643,32 +639,32 @@ fn parse_joins_using() {
} }
// Test parsing of aliases // Test parsing of aliases
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 JOIN t2 AS foo USING(c1)")), joins_from(verified_stmt("SELECT * FROM t1 JOIN t2 AS foo USING(c1)")),
vec![join_with_constraint( vec![join_with_constraint(
"t2", "t2",
Some("foo".to_string()), Some("foo".to_string()),
JoinOperator::Inner JoinOperator::Inner
)] )]
); );
parses_to( one_statement_parses_to(
"SELECT * FROM t1 JOIN t2 foo USING(c1)", "SELECT * FROM t1 JOIN t2 foo USING(c1)",
"SELECT * FROM t1 JOIN t2 AS foo USING(c1)", "SELECT * FROM t1 JOIN t2 AS foo USING(c1)",
); );
// Test parsing of different join operators // Test parsing of different join operators
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), joins_from(verified_stmt("SELECT * FROM t1 JOIN t2 USING(c1)")),
vec![join_with_constraint("t2", None, JoinOperator::Inner)] vec![join_with_constraint("t2", None, JoinOperator::Inner)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), joins_from(verified_stmt("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")),
vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), joins_from(verified_stmt("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")),
vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] vec![join_with_constraint("t2", None, JoinOperator::RightOuter)]
); );
assert_eq!( assert_eq!(
joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), joins_from(verified_stmt("SELECT * FROM t1 FULL JOIN t2 USING(c1)")),
vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] vec![join_with_constraint("t2", None, JoinOperator::FullOuter)]
); );
} }
@ -676,54 +672,68 @@ fn parse_joins_using() {
#[test] #[test]
fn parse_complex_join() { fn parse_complex_join() {
let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c";
assert_eq!(sql, parse_sql(sql).to_string()); verified_stmt(sql);
} }
#[test] #[test]
fn parse_join_syntax_variants() { fn parse_join_syntax_variants() {
parses_to( one_statement_parses_to(
"SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)",
"SELECT c1 FROM t1 JOIN t2 USING(c1)", "SELECT c1 FROM t1 JOIN t2 USING(c1)",
); );
parses_to( one_statement_parses_to(
"SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)",
"SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)",
); );
parses_to( one_statement_parses_to(
"SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)",
"SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)",
); );
parses_to( one_statement_parses_to(
"SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)",
"SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)",
); );
} }
fn verified(query: &str) -> ASTNode { fn verified_stmt(query: &str) -> SQLStatement {
let ast = parse_sql(query); one_statement_parses_to(query, query)
}
fn verified_expr(query: &str) -> ASTNode {
let ast = parse_sql_expr(query);
assert_eq!(query, &ast.to_string()); assert_eq!(query, &ast.to_string());
ast ast
} }
fn parses_to(from: &str, to: &str) { fn joins_from(ast: SQLStatement) -> Vec<Join> {
assert_eq!(to, &parse_sql(from).to_string())
}
fn joins_from(ast: ASTNode) -> Vec<Join> {
match ast { match ast {
ASTNode::SQLSelect(SQLSelect { joins, .. }) => joins, SQLStatement::SQLSelect(SQLSelect { joins, .. }) => joins,
_ => panic!("Expected SELECT"), _ => panic!("Expected SELECT"),
} }
} }
fn parse_sql(sql: &str) -> ASTNode { /// Ensures that `sql` parses as a statement, optionally checking that
let generic_ast = parse_sql_with(sql, &GenericSqlDialect {}); /// converting AST back to string equals to `canonical` (unless an empty string
let pg_ast = parse_sql_with(sql, &PostgreSqlDialect {}); /// is provided).
fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string()).unwrap();
let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
assert_eq!(generic_ast, pg_ast);
if !canonical.is_empty() {
assert_eq!(canonical, generic_ast.to_string())
}
generic_ast
}
fn parse_sql_expr(sql: &str) -> ASTNode {
let generic_ast = parse_sql_expr_with(&GenericSqlDialect {}, &sql.to_string());
let pg_ast = parse_sql_expr_with(&PostgreSqlDialect {}, &sql.to_string());
assert_eq!(generic_ast, pg_ast); assert_eq!(generic_ast, pg_ast);
generic_ast generic_ast
} }
fn parse_sql_with(sql: &str, dialect: &Dialect) -> ASTNode { fn parse_sql_expr_with(dialect: &Dialect, sql: &str) -> ASTNode {
let mut tokenizer = Tokenizer::new(dialect, &sql); let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize().unwrap(); let tokens = tokenizer.tokenize().unwrap();
let mut parser = Parser::new(tokens); let mut parser = Parser::new(tokens);

View file

@ -24,8 +24,8 @@ fn test_prev_index() {
#[test] #[test]
fn parse_simple_insert() { fn parse_simple_insert() {
let sql = String::from("INSERT INTO customer VALUES(1, 2, 3)"); let sql = String::from("INSERT INTO customer VALUES(1, 2, 3)");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLInsert { SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values, values,
@ -49,8 +49,8 @@ fn parse_simple_insert() {
#[test] #[test]
fn parse_common_insert() { fn parse_common_insert() {
let sql = String::from("INSERT INTO public.customer VALUES(1, 2, 3)"); let sql = String::from("INSERT INTO public.customer VALUES(1, 2, 3)");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLInsert { SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values, values,
@ -74,8 +74,8 @@ fn parse_common_insert() {
#[test] #[test]
fn parse_complex_insert() { fn parse_complex_insert() {
let sql = String::from("INSERT INTO db.public.customer VALUES(1, 2, 3)"); let sql = String::from("INSERT INTO db.public.customer VALUES(1, 2, 3)");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLInsert { SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values, values,
@ -113,8 +113,8 @@ fn parse_no_table_name() {
#[test] #[test]
fn parse_insert_with_columns() { fn parse_insert_with_columns() {
let sql = String::from("INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)"); let sql = String::from("INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLInsert { SQLStatement::SQLInsert {
table_name, table_name,
columns, columns,
values, values,
@ -141,8 +141,7 @@ fn parse_insert_with_columns() {
#[test] #[test]
fn parse_insert_invalid() { fn parse_insert_invalid() {
let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)"); let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)");
let mut parser = parser(&sql); match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
match parser.parse() {
Err(_) => {} Err(_) => {}
_ => assert!(false), _ => assert!(false),
} }
@ -163,8 +162,8 @@ fn parse_create_table_with_defaults() {
last_update timestamp without time zone DEFAULT now() NOT NULL, last_update timestamp without time zone DEFAULT now() NOT NULL,
active integer NOT NULL)", active integer NOT NULL)",
); );
match parse_sql(&sql) { match one_statement_parses_to(&sql, "") {
ASTNode::SQLCreateTable { name, columns } => { SQLStatement::SQLCreateTable { name, columns } => {
assert_eq!("public.customer", name); assert_eq!("public.customer", name);
assert_eq!(10, columns.len()); assert_eq!(10, columns.len());
@ -204,9 +203,8 @@ fn parse_create_table_from_pg_dump() {
release_year public.year, release_year public.year,
active integer active integer
)"); )");
let ast = parse_sql(&sql); match one_statement_parses_to(&sql, "") {
match ast { SQLStatement::SQLCreateTable { name, columns } => {
ASTNode::SQLCreateTable { name, columns } => {
assert_eq!("public.customer", name); assert_eq!("public.customer", name);
let c_customer_id = &columns[0]; let c_customer_id = &columns[0];
@ -259,8 +257,8 @@ fn parse_create_table_with_inherit() {
use_metric boolean DEFAULT true\ use_metric boolean DEFAULT true\
)", )",
); );
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLCreateTable { name, columns } => { SQLStatement::SQLCreateTable { name, columns } => {
assert_eq!("bazaar.settings", name); assert_eq!("bazaar.settings", name);
let c_name = &columns[0]; let c_name = &columns[0];
@ -288,8 +286,8 @@ fn parse_alter_table_constraint_primary_key() {
ALTER TABLE bazaar.address \ ALTER TABLE bazaar.address \
ADD CONSTRAINT address_pkey PRIMARY KEY (address_id)", ADD CONSTRAINT address_pkey PRIMARY KEY (address_id)",
); );
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLAlterTable { name, .. } => { SQLStatement::SQLAlterTable { name, .. } => {
assert_eq!(name, "bazaar.address"); assert_eq!(name, "bazaar.address");
} }
_ => assert!(false), _ => assert!(false),
@ -301,8 +299,8 @@ fn parse_alter_table_constraint_foreign_key() {
let sql = String::from("\ let sql = String::from("\
ALTER TABLE public.customer \ ALTER TABLE public.customer \
ADD CONSTRAINT customer_address_id_fkey FOREIGN KEY (address_id) REFERENCES public.address(address_id)"); ADD CONSTRAINT customer_address_id_fkey FOREIGN KEY (address_id) REFERENCES public.address(address_id)");
match verified(&sql) { match verified_stmt(&sql) {
ASTNode::SQLAlterTable { name, .. } => { SQLStatement::SQLAlterTable { name, .. } => {
assert_eq!(name, "public.customer"); assert_eq!(name, "public.customer");
} }
_ => assert!(false), _ => assert!(false),
@ -331,7 +329,7 @@ Kwara & Kogi
PHP USD $ PHP USD $
\N Some other value \N Some other value
\\."#); \\."#);
let ast = parse_sql(&sql); let ast = one_statement_parses_to(&sql, "");
println!("{:#?}", ast); println!("{:#?}", ast);
//assert_eq!(sql, ast.to_string()); //assert_eq!(sql, ast.to_string());
} }
@ -339,7 +337,7 @@ PHP ₱ USD $
#[test] #[test]
fn parse_timestamps_example() { fn parse_timestamps_example() {
let sql = "2016-02-15 09:43:33"; let sql = "2016-02-15 09:43:33";
let _ = parse_sql(sql); let _ = parse_sql_expr(sql);
//TODO add assertion //TODO add assertion
//assert_eq!(sql, ast.to_string()); //assert_eq!(sql, ast.to_string());
} }
@ -347,7 +345,7 @@ fn parse_timestamps_example() {
#[test] #[test]
fn parse_timestamps_with_millis_example() { fn parse_timestamps_with_millis_example() {
let sql = "2017-11-02 19:15:42.308637"; let sql = "2017-11-02 19:15:42.308637";
let _ = parse_sql(sql); let _ = parse_sql_expr(sql);
//TODO add assertion //TODO add assertion
//assert_eq!(sql, ast.to_string()); //assert_eq!(sql, ast.to_string());
} }
@ -355,24 +353,33 @@ fn parse_timestamps_with_millis_example() {
#[test] #[test]
fn parse_example_value() { fn parse_example_value() {
let sql = "SARAH.LEWIS@sakilacustomer.org"; let sql = "SARAH.LEWIS@sakilacustomer.org";
let ast = parse_sql(sql); let ast = parse_sql_expr(sql);
assert_eq!(sql, ast.to_string()); assert_eq!(sql, ast.to_string());
} }
#[test] #[test]
fn parse_function_now() { fn parse_function_now() {
let sql = "now()"; let sql = "now()";
let ast = parse_sql(sql); let ast = parse_sql_expr(sql);
assert_eq!(sql, ast.to_string()); assert_eq!(sql, ast.to_string());
} }
fn verified(query: &str) -> ASTNode { fn verified_stmt(query: &str) -> SQLStatement {
let ast = parse_sql(query); one_statement_parses_to(query, query)
assert_eq!(query, &ast.to_string());
ast
} }
fn parse_sql(sql: &str) -> ASTNode { /// Ensures that `sql` parses as a single statement, optionally checking that
/// converting AST back to string equals to `canonical` (unless an empty string
/// is provided).
fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement {
let only_statement = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()).unwrap();
if !canonical.is_empty() {
assert_eq!(canonical, only_statement.to_string())
}
only_statement
}
fn parse_sql_expr(sql: &str) -> ASTNode {
debug!("sql: {}", sql); debug!("sql: {}", sql);
let mut parser = parser(sql); let mut parser = parser(sql);
let ast = parser.parse().unwrap(); let ast = parser.parse().unwrap();