diff --git a/Cargo.toml b/Cargo.toml index b24348b6..15376f2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ path = "src/lib.rs" [dependencies] log = "0.4.5" -chrono = "0.4.6" [dev-dependencies] simple_logger = "1.0.1" diff --git a/src/dialect/ansi_sql.rs b/src/dialect/ansi_sql.rs index 522ed0ef..939f8546 100644 --- a/src/dialect/ansi_sql.rs +++ b/src/dialect/ansi_sql.rs @@ -1,5 +1,6 @@ use crate::dialect::Dialect; +#[derive(Debug)] pub struct AnsiSqlDialect {} impl Dialect for AnsiSqlDialect { diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index fe48ab2d..09de4745 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -1,9 +1,11 @@ use crate::dialect::Dialect; + +#[derive(Debug)] pub struct GenericSqlDialect {} impl Dialect for GenericSqlDialect { fn is_identifier_start(&self, ch: char) -> bool { - (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '@' + (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' || ch == '#' || ch == '@' } fn is_identifier_part(&self, ch: char) -> bool { @@ -11,6 +13,8 @@ impl Dialect for GenericSqlDialect { || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '@' + || ch == '$' + || ch == '#' || ch == '_' } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 95ecf792..df6f2603 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -1,13 +1,17 @@ mod ansi_sql; mod generic_sql; pub mod keywords; +mod mssql; mod postgresql; +use std::fmt::Debug; + pub use self::ansi_sql::AnsiSqlDialect; pub use self::generic_sql::GenericSqlDialect; +pub use self::mssql::MsSqlDialect; pub use self::postgresql::PostgreSqlDialect; -pub trait Dialect { +pub trait Dialect: Debug { /// Determine if a character starts a quoted identifier. The default /// implementation, accepting "double quoted" ids is both ANSI-compliant /// and appropriate for most dialects (with the notable exception of diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs new file mode 100644 index 00000000..65daeb8d --- /dev/null +++ b/src/dialect/mssql.rs @@ -0,0 +1,22 @@ +use crate::dialect::Dialect; + +#[derive(Debug)] +pub struct MsSqlDialect {} + +impl Dialect for MsSqlDialect { + fn is_identifier_start(&self, ch: char) -> bool { + // See https://docs.microsoft.com/en-us/sql/relational-databases/databases/database-identifiers?view=sql-server-2017#rules-for-regular-identifiers + // We don't support non-latin "letters" currently. + (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' || ch == '#' || ch == '@' + } + + fn is_identifier_part(&self, ch: char) -> bool { + (ch >= 'a' && ch <= 'z') + || (ch >= 'A' && ch <= 'Z') + || (ch >= '0' && ch <= '9') + || ch == '@' + || ch == '$' + || ch == '#' + || ch == '_' + } +} diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index dac3740d..5433b440 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -1,17 +1,21 @@ use crate::dialect::Dialect; +#[derive(Debug)] pub struct PostgreSqlDialect {} impl Dialect for PostgreSqlDialect { fn is_identifier_start(&self, ch: char) -> bool { - (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '@' + // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + // We don't yet support identifiers beginning with "letters with + // diacritical marks and non-Latin letters" + (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' } fn is_identifier_part(&self, ch: char) -> bool { (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') - || ch == '@' + || ch == '$' || ch == '_' } } diff --git a/src/lib.rs b/src/lib.rs index fc90be85..420e9c3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,3 +40,8 @@ pub mod dialect; pub mod sqlast; pub mod sqlparser; pub mod sqltokenizer; + +#[doc(hidden)] +// This is required to make utilities accessible by both the crate-internal +// unit-tests and by the integration tests +pub mod test_utils; diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 0a5bb6ba..81668b2f 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -41,8 +41,11 @@ fn comma_separated_string(vec: &[T]) -> String { /// Identifier name, in the originally quoted form (e.g. `"id"`) pub type SQLIdent = String; -/// Represents a parsed SQL expression, which is a common building -/// block of SQL statements (the part after SELECT, WHERE, etc.) +/// An SQL expression of any type. +/// +/// The parser does not distinguish between expressions of different types +/// (e.g. boolean vs string), so the caller must handle expressions of +/// inappropriate type, like `WHERE 1` or `SELECT 1=1`, as necessary. #[derive(Debug, Clone, PartialEq)] pub enum ASTNode { /// Identifier e.g. table name or column name @@ -72,7 +75,7 @@ pub enum ASTNode { subquery: Box, negated: bool, }, - /// [ NOT ] BETWEEN AND + /// ` [ NOT ] BETWEEN AND ` SQLBetween { expr: Box, negated: bool, diff --git a/src/sqlast/value.rs b/src/sqlast/value.rs index a36f8d27..1ca4b62b 100644 --- a/src/sqlast/value.rs +++ b/src/sqlast/value.rs @@ -1,5 +1,3 @@ -use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime}; - /// SQL values such as int, double, string, timestamp #[derive(Debug, Clone, PartialEq)] pub enum Value { @@ -13,14 +11,6 @@ pub enum Value { NationalStringLiteral(String), /// Boolean value true or false, Boolean(bool), - /// Date value - Date(NaiveDate), - // Time - Time(NaiveTime), - /// Date and time - DateTime(NaiveDateTime), - /// Timstamp with time zone - Timestamp(DateTime), /// NULL value in insert statements, Null, } @@ -33,10 +23,6 @@ impl ToString for Value { Value::SingleQuotedString(v) => format!("'{}'", escape_single_quote_string(v)), Value::NationalStringLiteral(v) => format!("N'{}'", v), Value::Boolean(v) => v.to_string(), - Value::Date(v) => v.to_string(), - Value::Time(v) => v.to_string(), - Value::DateTime(v) => v.to_string(), - Value::Timestamp(v) => format!("{}", v), Value::Null => "NULL".to_string(), } } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 3b6f9c28..ab2755fc 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -20,7 +20,6 @@ use super::dialect::keywords; use super::dialect::Dialect; use super::sqlast::*; use super::sqltokenizer::*; -use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[derive(Debug, Clone, PartialEq)] pub enum ParserError { @@ -922,37 +921,29 @@ impl Parser { /// Parse a literal value (numbers, strings, date/time, booleans) fn parse_value(&mut self) -> Result { match self.next_token() { - Some(t) => { - match t { - Token::SQLWord(k) => match k.keyword.as_ref() { - "TRUE" => Ok(Value::Boolean(true)), - "FALSE" => Ok(Value::Boolean(false)), - "NULL" => Ok(Value::Null), - _ => { - return parser_err!(format!( - "No value parser for keyword {}", - k.keyword - )); - } - }, - //TODO: parse the timestamp here (see parse_timestamp_value()) - Token::Number(ref n) if n.contains('.') => match n.parse::() { - Ok(n) => Ok(Value::Double(n)), - Err(e) => parser_err!(format!("Could not parse '{}' as f64: {}", n, e)), - }, - Token::Number(ref n) => match n.parse::() { - Ok(n) => Ok(Value::Long(n)), - Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), - }, - Token::SingleQuotedString(ref s) => { - Ok(Value::SingleQuotedString(s.to_string())) + Some(t) => match t { + Token::SQLWord(k) => match k.keyword.as_ref() { + "TRUE" => Ok(Value::Boolean(true)), + "FALSE" => Ok(Value::Boolean(false)), + "NULL" => Ok(Value::Null), + _ => { + return parser_err!(format!("No value parser for keyword {}", k.keyword)); } - Token::NationalStringLiteral(ref s) => { - Ok(Value::NationalStringLiteral(s.to_string())) - } - _ => parser_err!(format!("Unsupported value: {:?}", t)), + }, + Token::Number(ref n) if n.contains('.') => match n.parse::() { + Ok(n) => Ok(Value::Double(n)), + Err(e) => parser_err!(format!("Could not parse '{}' as f64: {}", n, e)), + }, + Token::Number(ref n) => match n.parse::() { + Ok(n) => Ok(Value::Long(n)), + Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), + }, + Token::SingleQuotedString(ref s) => Ok(Value::SingleQuotedString(s.to_string())), + Token::NationalStringLiteral(ref s) => { + Ok(Value::NationalStringLiteral(s.to_string())) } - } + _ => parser_err!(format!("Unsupported value: {:?}", t)), + }, None => parser_err!("Expecting a value, but found EOF"), } } @@ -985,86 +976,6 @@ impl Parser { } } - pub fn parse_timezone_offset(&mut self) -> Result { - match self.next_token() { - Some(Token::Plus) => { - let n = self.parse_literal_int()?; - Ok(n as i8) - } - Some(Token::Minus) => { - let n = self.parse_literal_int()?; - Ok(-n as i8) - } - other => parser_err!(format!( - "Expecting `+` or `-` in timezone, but found {:?}", - other - )), - } - } - - pub fn parse_timestamp_value(&mut self) -> Result { - let year = self.parse_literal_int()?; - let date = self.parse_date(year)?; - if let Ok(time) = self.parse_time() { - let date_time = NaiveDateTime::new(date, time); - match self.peek_token() { - Some(token) => match token { - Token::Plus | Token::Minus => { - let tz = self.parse_timezone_offset()?; - let offset = FixedOffset::east(i32::from(tz) * 3600); - Ok(Value::Timestamp(DateTime::from_utc(date_time, offset))) - } - _ => Ok(Value::DateTime(date_time)), - }, - _ => Ok(Value::DateTime(date_time)), - } - } else { - parser_err!(format!( - "Expecting time after date, but found {:?}", - self.peek_token() - )) - } - } - - pub fn parse_date(&mut self, year: i64) -> Result { - if self.consume_token(&Token::Minus) { - let month = self.parse_literal_int()?; - if self.consume_token(&Token::Minus) { - let day = self.parse_literal_int()?; - let date = NaiveDate::from_ymd(year as i32, month as u32, day as u32); - Ok(date) - } else { - parser_err!(format!( - "Expecting `-` for date separator, found {:?}", - self.peek_token() - )) - } - } else { - parser_err!(format!( - "Expecting `-` for date separator, found {:?}", - self.peek_token() - )) - } - } - - pub fn parse_time(&mut self) -> Result { - let hour = self.parse_literal_int()?; - self.expect_token(&Token::Colon)?; - let min = self.parse_literal_int()?; - self.expect_token(&Token::Colon)?; - // On one hand, the SQL specs defines ::= , - // so it would be more correct to parse it as such - let sec = self.parse_literal_double()?; - // On the other, chrono only supports nanoseconds, which should(?) fit in seconds-as-f64... - let nanos = (sec.fract() * 1_000_000_000.0).round(); - Ok(NaiveTime::from_hms_nano( - hour as u32, - min as u32, - sec as u32, - nanos as u32, - )) - } - /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { match self.next_token() { @@ -1671,3 +1582,23 @@ impl SQLWord { self.to_string() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::all_dialects; + + #[test] + fn test_prev_index() { + let sql = "SELECT version()"; + all_dialects().run_parser_method(sql, |parser| { + assert_eq!(parser.prev_token(), None); + assert_eq!(parser.next_token(), Some(Token::make_keyword("SELECT"))); + assert_eq!(parser.next_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.prev_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.peek_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.prev_token(), Some(Token::make_keyword("SELECT"))); + assert_eq!(parser.prev_token(), None); + }); + } +} diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 00000000..16216bfe --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,122 @@ +use std::fmt::Debug; + +use super::dialect::*; +use super::sqlast::*; +use super::sqlparser::{Parser, ParserError}; +use super::sqltokenizer::Tokenizer; + +/// Tests use the methods on this struct to invoke the parser on one or +/// multiple dialects. +pub struct TestedDialects { + pub dialects: Vec>, +} + +impl TestedDialects { + /// Run the given function for all of `self.dialects`, assert that they + /// return the same result, and return that result. + pub fn one_of_identical_results(&self, f: F) -> T + where + F: Fn(&dyn Dialect) -> T, + { + let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect))); + parse_results + .fold(None, |s, (dialect, parsed)| { + if let Some((prev_dialect, prev_parsed)) = s { + assert_eq!( + prev_parsed, parsed, + "Parse results with {:?} are different from {:?}", + prev_dialect, dialect + ); + } + Some((dialect, parsed)) + }) + .unwrap() + .1 + } + + pub fn run_parser_method(&self, sql: &str, f: F) -> T + where + F: Fn(&mut Parser) -> T, + { + self.one_of_identical_results(|dialect| { + let mut tokenizer = Tokenizer::new(dialect, sql); + let tokens = tokenizer.tokenize().unwrap(); + f(&mut Parser::new(tokens)) + }) + } + + pub fn parse_sql_statements(&self, sql: &str) -> Result, ParserError> { + self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql.to_string())) + // To fail the `ensure_multiple_dialects_are_tested` test: + // Parser::parse_sql(&**self.dialects.first().unwrap(), sql.to_string()) + } + + /// Ensures that `sql` parses as a single statement, optionally checking + /// that converting AST back to string equals to `canonical` (unless an + /// empty canonical string is provided). + pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> SQLStatement { + let mut statements = self.parse_sql_statements(&sql).unwrap(); + assert_eq!(statements.len(), 1); + + let only_statement = statements.pop().unwrap(); + if !canonical.is_empty() { + assert_eq!(canonical, only_statement.to_string()) + } + only_statement + } + + /// Ensures that `sql` parses as a single SQLStatement, and is not modified + /// after a serialization round-trip. + pub fn verified_stmt(&self, query: &str) -> SQLStatement { + self.one_statement_parses_to(query, query) + } + + /// Ensures that `sql` parses as a single SQLQuery, and is not modified + /// after a serialization round-trip. + pub fn verified_query(&self, sql: &str) -> SQLQuery { + match self.verified_stmt(sql) { + SQLStatement::SQLQuery(query) => *query, + _ => panic!("Expected SQLQuery"), + } + } + + /// Ensures that `sql` parses as a single SQLSelect, and is not modified + /// after a serialization round-trip. + pub fn verified_only_select(&self, query: &str) -> SQLSelect { + match self.verified_query(query).body { + SQLSetExpr::Select(s) => *s, + _ => panic!("Expected SQLSetExpr::Select"), + } + } + + /// Ensures that `sql` parses as an expression, and is not modified + /// after a serialization round-trip. + pub fn verified_expr(&self, sql: &str) -> ASTNode { + let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap(); + assert_eq!(sql, &ast.to_string(), "round-tripping without changes"); + ast + } +} + +pub fn all_dialects() -> TestedDialects { + TestedDialects { + dialects: vec![ + Box::new(GenericSqlDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiSqlDialect {}), + ], + } +} + +pub fn only(v: &[T]) -> &T { + assert_eq!(1, v.len()); + v.first().unwrap() +} + +pub fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode { + match item { + SQLSelectItem::UnnamedExpression(expr) => expr, + _ => panic!("Expected UnnamedExpression"), + } +} diff --git a/tests/sqlparser_ansi.rs b/tests/sqlparser_ansi.rs deleted file mode 100644 index ab80ae92..00000000 --- a/tests/sqlparser_ansi.rs +++ /dev/null @@ -1,24 +0,0 @@ -#![warn(clippy::all)] - -use sqlparser::dialect::AnsiSqlDialect; -use sqlparser::sqlast::*; -use sqlparser::sqlparser::*; - -#[test] -fn parse_simple_select() { - let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1"); - let mut ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap(); - assert_eq!(1, ast.len()); - match ast.pop().unwrap() { - SQLStatement::SQLQuery(q) => match *q { - SQLQuery { - body: SQLSetExpr::Select(select), - .. - } => { - assert_eq!(3, select.projection.len()); - } - _ => unreachable!(), - }, - _ => unreachable!(), - } -} diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_common.rs similarity index 88% rename from tests/sqlparser_generic.rs rename to tests/sqlparser_common.rs index e1e96972..14498ce2 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_common.rs @@ -1,11 +1,80 @@ #![warn(clippy::all)] +//! Test SQL syntax, which all sqlparser dialects must parse in the same way. +//! +//! Note that it does not mean all SQL here is valid in all the dialects, only +//! that 1) it's either standard or widely supported and 2) it can be parsed by +//! sqlparser regardless of the chosen dialect (i.e. it doesn't conflict with +//! dialect-specific parsing rules). use matches::assert_matches; -use sqlparser::dialect::*; use sqlparser::sqlast::*; use sqlparser::sqlparser::*; -use sqlparser::sqltokenizer::*; +use sqlparser::test_utils::{all_dialects, expr_from_projection, only}; + +#[test] +fn parse_insert_values() { + let sql = "INSERT INTO customer VALUES(1, 2, 3)"; + check_one(sql, "customer", vec![]); + + let sql = "INSERT INTO public.customer VALUES(1, 2, 3)"; + check_one(sql, "public.customer", vec![]); + + let sql = "INSERT INTO db.public.customer VALUES(1, 2, 3)"; + check_one(sql, "db.public.customer", vec![]); + + let sql = "INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)"; + check_one( + sql, + "public.customer", + vec!["id".to_string(), "name".to_string(), "active".to_string()], + ); + + fn check_one(sql: &str, expected_table_name: &str, expected_columns: Vec) { + match verified_stmt(sql) { + SQLStatement::SQLInsert { + table_name, + columns, + values, + .. + } => { + assert_eq!(table_name.to_string(), expected_table_name); + assert_eq!(columns, expected_columns); + assert_eq!( + vec![vec![ + ASTNode::SQLValue(Value::Long(1)), + ASTNode::SQLValue(Value::Long(2)), + ASTNode::SQLValue(Value::Long(3)) + ]], + values + ); + } + _ => unreachable!(), + } + } +} + +#[test] +fn parse_insert_invalid() { + let sql = "INSERT public.customer (id, name, active) VALUES (1, 2, 3)"; + let res = parse_sql_statements(sql); + assert_eq!( + ParserError::ParserError("Expected INTO, found: public".to_string()), + res.unwrap_err() + ); +} + +#[test] +fn parse_invalid_table_name() { + let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name); + assert!(ast.is_err()); +} + +#[test] +fn parse_no_table_name() { + let ast = all_dialects().run_parser_method("", Parser::parse_object_name); + assert!(ast.is_err()); +} #[test] fn parse_delete_statement() { @@ -452,14 +521,12 @@ fn parse_cast() { #[test] fn parse_create_table() { - let sql = String::from( - "CREATE TABLE uk_cities (\ - name VARCHAR(100) NOT NULL,\ - lat DOUBLE NULL,\ - lng DOUBLE NULL)", - ); + let sql = "CREATE TABLE uk_cities (\ + name VARCHAR(100) NOT NULL,\ + lat DOUBLE NULL,\ + lng DOUBLE NULL)"; let ast = one_statement_parses_to( - &sql, + sql, "CREATE TABLE uk_cities (\ name character varying(100) NOT NULL, \ lat double, \ @@ -497,15 +564,13 @@ fn parse_create_table() { #[test] fn parse_create_external_table() { - let sql = String::from( - "CREATE EXTERNAL TABLE uk_cities (\ - name VARCHAR(100) NOT NULL,\ - lat DOUBLE NULL,\ - lng DOUBLE NULL)\ - STORED AS TEXTFILE LOCATION '/tmp/example.csv", - ); + let sql = "CREATE EXTERNAL TABLE uk_cities (\ + name VARCHAR(100) NOT NULL,\ + lat DOUBLE NULL,\ + lng DOUBLE NULL)\ + STORED AS TEXTFILE LOCATION '/tmp/example.csv"; let ast = one_statement_parses_to( - &sql, + sql, "CREATE EXTERNAL TABLE uk_cities (\ name character varying(100) NOT NULL, \ lat double, \ @@ -546,14 +611,38 @@ fn parse_create_external_table() { } } +#[test] +fn parse_alter_table_constraint_primary_key() { + let sql = "ALTER TABLE bazaar.address \ + ADD CONSTRAINT address_pkey PRIMARY KEY (address_id)"; + match verified_stmt(sql) { + SQLStatement::SQLAlterTable { name, .. } => { + assert_eq!(name.to_string(), "bazaar.address"); + } + _ => unreachable!(), + } +} + +#[test] +fn parse_alter_table_constraint_foreign_key() { + let sql = "ALTER TABLE public.customer \ + ADD CONSTRAINT customer_address_id_fkey FOREIGN KEY (address_id) REFERENCES public.address(address_id)"; + match verified_stmt(sql) { + SQLStatement::SQLAlterTable { name, .. } => { + assert_eq!(name.to_string(), "public.customer"); + } + _ => unreachable!(), + } +} + #[test] fn parse_scalar_function_in_projection() { let sql = "SELECT sqrt(id) FROM foo"; let select = verified_only_select(sql); assert_eq!( &ASTNode::SQLFunction { - name: SQLObjectName(vec![String::from("sqrt")]), - args: vec![ASTNode::SQLIdentifier(String::from("id"))], + name: SQLObjectName(vec!["sqrt".to_string()]), + args: vec![ASTNode::SQLIdentifier("id".to_string())], over: None, }, expr_from_projection(only(&select.projection)) @@ -623,16 +712,6 @@ fn parse_simple_math_expr_minus() { verified_only_select(sql); } -#[test] -fn parse_select_version() { - let sql = "SELECT @@version"; - let select = verified_only_select(sql); - assert_eq!( - &ASTNode::SQLIdentifier("@@version".to_string()), - expr_from_projection(only(&select.projection)), - ); -} - #[test] fn parse_delimited_identifiers() { // check that quoted identifiers in any position remain quoted after serialization @@ -1097,73 +1176,37 @@ fn parse_invalid_subquery_without_parens() { ); } -fn only(v: &[T]) -> &T { - assert_eq!(1, v.len()); - v.first().unwrap() -} - -fn verified_query(query: &str) -> SQLQuery { - match verified_stmt(query) { - SQLStatement::SQLQuery(query) => *query, - _ => panic!("Expected SQLQuery"), - } -} - -fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode { - match item { - SQLSelectItem::UnnamedExpression(expr) => expr, - _ => panic!("Expected UnnamedExpression"), - } -} - -fn verified_only_select(query: &str) -> SQLSelect { - match verified_query(query).body { - SQLSetExpr::Select(s) => *s, - _ => panic!("Expected SQLSetExpr::Select"), - } -} - -fn verified_stmt(query: &str) -> SQLStatement { - one_statement_parses_to(query, query) -} - -fn verified_expr(query: &str) -> ASTNode { - let ast = parse_sql_expr(query); - assert_eq!(query, &ast.to_string()); - ast -} - -/// Ensures that `sql` parses as a single statement, optionally checking that -/// converting AST back to string equals to `canonical` (unless an empty string -/// is provided). -fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement { - let mut statements = parse_sql_statements(&sql).unwrap(); - assert_eq!(statements.len(), 1); - - let only_statement = statements.pop().unwrap(); - if !canonical.is_empty() { - assert_eq!(canonical, only_statement.to_string()) - } - only_statement +#[test] +#[should_panic( + expected = "Parse results with GenericSqlDialect are different from PostgreSqlDialect" +)] +fn ensure_multiple_dialects_are_tested() { + // The SQL here must be parsed differently by different dialects. + // At the time of writing, `@foo` is accepted as a valid identifier + // by the Generic and the MSSQL dialect, but not by Postgres and ANSI. + let _ = parse_sql_statements("SELECT @foo"); } fn parse_sql_statements(sql: &str) -> Result, ParserError> { - let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string()); - let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()); - assert_eq!(generic_ast, pg_ast); - generic_ast + all_dialects().parse_sql_statements(sql) } -fn parse_sql_expr(sql: &str) -> ASTNode { - let generic_ast = parse_sql_expr_with(&GenericSqlDialect {}, &sql.to_string()); - let pg_ast = parse_sql_expr_with(&PostgreSqlDialect {}, &sql.to_string()); - assert_eq!(generic_ast, pg_ast); - generic_ast +fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement { + all_dialects().one_statement_parses_to(sql, canonical) } -fn parse_sql_expr_with(dialect: &dyn Dialect, sql: &str) -> ASTNode { - let mut tokenizer = Tokenizer::new(dialect, &sql); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens); - parser.parse_expr().unwrap() +fn verified_stmt(query: &str) -> SQLStatement { + all_dialects().verified_stmt(query) +} + +fn verified_query(query: &str) -> SQLQuery { + all_dialects().verified_query(query) +} + +fn verified_only_select(query: &str) -> SQLSelect { + all_dialects().verified_only_select(query) +} + +fn verified_expr(query: &str) -> ASTNode { + all_dialects().verified_expr(query) } diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs new file mode 100644 index 00000000..d209ed94 --- /dev/null +++ b/tests/sqlparser_mssql.rs @@ -0,0 +1,40 @@ +#![warn(clippy::all)] +//! Test SQL syntax specific to Microsoft's T-SQL. The parser based on the +//! generic dialect is also tested (on the inputs it can handle). + +use sqlparser::dialect::{GenericSqlDialect, MsSqlDialect}; +use sqlparser::sqlast::*; +use sqlparser::test_utils::*; + +#[test] +fn parse_mssql_identifiers() { + let sql = "SELECT @@version, _foo$123 FROM ##temp"; + let select = ms_and_generic().verified_only_select(sql); + assert_eq!( + &ASTNode::SQLIdentifier("@@version".to_string()), + expr_from_projection(&select.projection[0]), + ); + assert_eq!( + &ASTNode::SQLIdentifier("_foo$123".to_string()), + expr_from_projection(&select.projection[1]), + ); + assert_eq!(2, select.projection.len()); + match select.relation { + Some(TableFactor::Table { name, .. }) => { + assert_eq!("##temp".to_string(), name.to_string()); + } + _ => unreachable!(), + }; +} + +#[allow(dead_code)] +fn ms() -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(MsSqlDialect {})], + } +} +fn ms_and_generic() -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(MsSqlDialect {}), Box::new(GenericSqlDialect {})], + } +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 98cdb902..522bd74f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1,155 +1,14 @@ #![warn(clippy::all)] +//! Test SQL syntax specific to PostgreSQL. The parser based on the +//! generic dialect is also tested (on the inputs it can handle). -use log::debug; - -use sqlparser::dialect::PostgreSqlDialect; +use sqlparser::dialect::{GenericSqlDialect, PostgreSqlDialect}; use sqlparser::sqlast::*; -use sqlparser::sqlparser::*; -use sqlparser::sqltokenizer::*; - -#[test] -fn test_prev_index() { - let sql: &str = "SELECT version()"; - let mut parser = parser(sql); - assert_eq!(parser.prev_token(), None); - assert_eq!(parser.next_token(), Some(Token::make_keyword("SELECT"))); - assert_eq!(parser.next_token(), Some(Token::make_word("version", None))); - assert_eq!(parser.prev_token(), Some(Token::make_word("version", None))); - assert_eq!(parser.peek_token(), Some(Token::make_word("version", None))); - assert_eq!(parser.prev_token(), Some(Token::make_keyword("SELECT"))); - assert_eq!(parser.prev_token(), None); -} - -#[test] -fn parse_simple_insert() { - let sql = String::from("INSERT INTO customer VALUES(1, 2, 3)"); - match verified_stmt(&sql) { - SQLStatement::SQLInsert { - table_name, - columns, - values, - .. - } => { - assert_eq!(table_name.to_string(), "customer"); - assert!(columns.is_empty()); - assert_eq!( - vec![vec![ - ASTNode::SQLValue(Value::Long(1)), - ASTNode::SQLValue(Value::Long(2)), - ASTNode::SQLValue(Value::Long(3)) - ]], - values - ); - } - _ => unreachable!(), - } -} - -#[test] -fn parse_common_insert() { - let sql = String::from("INSERT INTO public.customer VALUES(1, 2, 3)"); - match verified_stmt(&sql) { - SQLStatement::SQLInsert { - table_name, - columns, - values, - .. - } => { - assert_eq!(table_name.to_string(), "public.customer"); - assert!(columns.is_empty()); - assert_eq!( - vec![vec![ - ASTNode::SQLValue(Value::Long(1)), - ASTNode::SQLValue(Value::Long(2)), - ASTNode::SQLValue(Value::Long(3)) - ]], - values - ); - } - _ => unreachable!(), - } -} - -#[test] -fn parse_complex_insert() { - let sql = String::from("INSERT INTO db.public.customer VALUES(1, 2, 3)"); - match verified_stmt(&sql) { - SQLStatement::SQLInsert { - table_name, - columns, - values, - .. - } => { - assert_eq!(table_name.to_string(), "db.public.customer"); - assert!(columns.is_empty()); - assert_eq!( - vec![vec![ - ASTNode::SQLValue(Value::Long(1)), - ASTNode::SQLValue(Value::Long(2)), - ASTNode::SQLValue(Value::Long(3)) - ]], - values - ); - } - _ => unreachable!(), - } -} - -#[test] -fn parse_invalid_table_name() { - let mut parser = parser("db.public..customer"); - let ast = parser.parse_object_name(); - assert!(ast.is_err()); -} - -#[test] -fn parse_no_table_name() { - let mut parser = parser(""); - let ast = parser.parse_object_name(); - assert!(ast.is_err()); -} - -#[test] -fn parse_insert_with_columns() { - let sql = String::from("INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)"); - match verified_stmt(&sql) { - SQLStatement::SQLInsert { - table_name, - columns, - values, - .. - } => { - assert_eq!(table_name.to_string(), "public.customer"); - assert_eq!( - columns, - vec!["id".to_string(), "name".to_string(), "active".to_string()] - ); - assert_eq!( - vec![vec![ - ASTNode::SQLValue(Value::Long(1)), - ASTNode::SQLValue(Value::Long(2)), - ASTNode::SQLValue(Value::Long(3)) - ]], - values - ); - } - _ => unreachable!(), - } -} - -#[test] -fn parse_insert_invalid() { - let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)"); - match Parser::parse_sql(&PostgreSqlDialect {}, sql) { - Err(_) => {} - _ => unreachable!(), - } -} +use sqlparser::test_utils::*; #[test] fn parse_create_table_with_defaults() { - let sql = String::from( - "CREATE TABLE public.customer ( + let sql = "CREATE TABLE public.customer ( customer_id integer DEFAULT nextval(public.customer_customer_id_seq) NOT NULL, store_id smallint NOT NULL, first_name character varying(45) NOT NULL, @@ -159,9 +18,8 @@ fn parse_create_table_with_defaults() { activebool boolean DEFAULT true NOT NULL, create_date date DEFAULT now()::text NOT NULL, last_update timestamp without time zone DEFAULT now() NOT NULL, - active integer NOT NULL)", - ); - match one_statement_parses_to(&sql, "") { + active integer NOT NULL)"; + match pg_and_generic().one_statement_parses_to(sql, "") { SQLStatement::SQLCreateTable { name, columns, @@ -193,8 +51,7 @@ fn parse_create_table_with_defaults() { #[test] fn parse_create_table_from_pg_dump() { - let sql = String::from(" - CREATE TABLE public.customer ( + let sql = "CREATE TABLE public.customer ( customer_id integer DEFAULT nextval('public.customer_customer_id_seq'::regclass) NOT NULL, store_id smallint NOT NULL, first_name character varying(45) NOT NULL, @@ -207,8 +64,8 @@ fn parse_create_table_from_pg_dump() { last_update timestamp without time zone DEFAULT now(), release_year public.year, active integer - )"); - match one_statement_parses_to(&sql, "") { + )"; + match pg().one_statement_parses_to(sql, "") { SQLStatement::SQLCreateTable { name, columns, @@ -262,16 +119,14 @@ fn parse_create_table_from_pg_dump() { #[test] fn parse_create_table_with_inherit() { - let sql = String::from( - "\ - CREATE TABLE bazaar.settings (\ - settings_id uuid PRIMARY KEY DEFAULT uuid_generate_v4() NOT NULL, \ - user_id uuid UNIQUE, \ - value text[], \ - use_metric boolean DEFAULT true\ - )", - ); - match verified_stmt(&sql) { + let sql = "\ + CREATE TABLE bazaar.settings (\ + settings_id uuid PRIMARY KEY DEFAULT uuid_generate_v4() NOT NULL, \ + user_id uuid UNIQUE, \ + value text[], \ + use_metric boolean DEFAULT true\ + )"; + match pg().verified_stmt(sql) { SQLStatement::SQLCreateTable { name, columns, @@ -299,37 +154,9 @@ fn parse_create_table_with_inherit() { } } -#[test] -fn parse_alter_table_constraint_primary_key() { - let sql = String::from( - "\ - ALTER TABLE bazaar.address \ - ADD CONSTRAINT address_pkey PRIMARY KEY (address_id)", - ); - match verified_stmt(&sql) { - SQLStatement::SQLAlterTable { name, .. } => { - assert_eq!(name.to_string(), "bazaar.address"); - } - _ => unreachable!(), - } -} - -#[test] -fn parse_alter_table_constraint_foreign_key() { - let sql = String::from("\ - ALTER TABLE public.customer \ - ADD CONSTRAINT customer_address_id_fkey FOREIGN KEY (address_id) REFERENCES public.address(address_id)"); - match verified_stmt(&sql) { - SQLStatement::SQLAlterTable { name, .. } => { - assert_eq!(name.to_string(), "public.customer"); - } - _ => unreachable!(), - } -} - #[test] fn parse_copy_example() { - let sql = String::from(r#"COPY public.actor (actor_id, first_name, last_name, last_update, value) FROM stdin; + let sql = r#"COPY public.actor (actor_id, first_name, last_name, last_update, value) FROM stdin; 1 PENELOPE GUINESS 2006-02-15 09:34:33 0.11111 2 NICK WAHLBERG 2006-02-15 09:34:33 0.22222 3 ED CHASE 2006-02-15 09:34:33 0.312323 @@ -348,74 +175,23 @@ Kwara & Kogi 'awe':5 'awe-inspir':4 'barbarella':1 'cat':13 'conquer':16 'dog':18 'feminist':10 'inspir':6 'monasteri':21 'must':15 'stori':7 'streetcar':2 PHP ₱ USD $ \N Some other value -\\."#); - let ast = one_statement_parses_to(&sql, ""); +\\."#; + let ast = pg_and_generic().one_statement_parses_to(sql, ""); println!("{:#?}", ast); //assert_eq!(sql, ast.to_string()); } -#[test] -fn parse_timestamps_example() { - let sql = "2016-02-15 09:43:33"; - let _ = parse_sql_expr(sql); - //TODO add assertion - //assert_eq!(sql, ast.to_string()); -} - -#[test] -fn parse_timestamps_with_millis_example() { - let sql = "2017-11-02 19:15:42.308637"; - let _ = parse_sql_expr(sql); - //TODO add assertion - //assert_eq!(sql, ast.to_string()); -} - -#[test] -fn parse_example_value() { - let sql = "SARAH.LEWIS@sakilacustomer.org"; - let ast = parse_sql_expr(sql); - assert_eq!(sql, ast.to_string()); -} - -#[test] -fn parse_function_now() { - let sql = "now()"; - let ast = parse_sql_expr(sql); - assert_eq!(sql, ast.to_string()); -} - -fn verified_stmt(query: &str) -> SQLStatement { - one_statement_parses_to(query, query) -} - -/// Ensures that `sql` parses as a single statement, optionally checking that -/// converting AST back to string equals to `canonical` (unless an empty string -/// is provided). -fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement { - let mut statements = parse_sql_statements(&sql).unwrap(); - assert_eq!(statements.len(), 1); - - let only_statement = statements.pop().unwrap(); - if !canonical.is_empty() { - assert_eq!(canonical, only_statement.to_string()) +fn pg() -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(PostgreSqlDialect {})], } - only_statement } -fn parse_sql_statements(sql: &str) -> Result, ParserError> { - Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()) -} - -fn parse_sql_expr(sql: &str) -> ASTNode { - debug!("sql: {}", sql); - let mut parser = parser(sql); - parser.parse_expr().unwrap() -} - -fn parser(sql: &str) -> Parser { - let dialect = PostgreSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); - let tokens = tokenizer.tokenize().unwrap(); - debug!("tokens: {:#?}", tokens); - Parser::new(tokens) +fn pg_and_generic() -> TestedDialects { + TestedDialects { + dialects: vec![ + Box::new(PostgreSqlDialect {}), + Box::new(GenericSqlDialect {}), + ], + } }