Enable dialect specific behaviours in the parser (#254)

* Change `Parser { ... }` to store the dialect used:
    `Parser<'a> { ... dialect: &'a dyn Dialect }`

    Thanks to @c7hm4r for the initial version of this submitted as
    part of https://github.com/ballista-compute/sqlparser-rs/pull/170

* Introduce `dialect_of!(parser is SQLiteDialect | GenericDialect)` helper
    to branch on the dialect's type

* Use the new functionality to make `AUTO_INCREMENT` and `AUTOINCREMENT`
  parsing dialect-dependent.


Co-authored-by: Christoph Müller <pmzqxfmn@runbox.com>
Co-authored-by: Nickolay Ponomarev <asqueella@gmail.com>
This commit is contained in:
eyalleshem 2020-08-10 16:51:59 +03:00 committed by GitHub
parent 3871bbc5ee
commit 1b46e82eec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 17 deletions

View file

@@ -18,6 +18,7 @@ mod mysql;
mod postgresql; mod postgresql;
mod sqlite; mod sqlite;
use std::any::{Any, TypeId};
use std::fmt::Debug; use std::fmt::Debug;
pub use self::ansi::AnsiDialect; pub use self::ansi::AnsiDialect;
@@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
pub use self::postgresql::PostgreSqlDialect; pub use self::postgresql::PostgreSqlDialect;
pub use self::sqlite::SQLiteDialect; pub use self::sqlite::SQLiteDialect;
pub trait Dialect: Debug { /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
macro_rules! dialect_of {
( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
($($parsed_dialect.dialect.is::<$dialect_type>())||+)
};
}
pub trait Dialect: Debug + Any {
/// Determine if a character starts a quoted identifier. The default /// Determine if a character starts a quoted identifier. The default
/// implementation, accepting "double quoted" ids is both ANSI-compliant /// implementation, accepting "double quoted" ids is both ANSI-compliant
/// and appropriate for most dialects (with the notable exception of /// and appropriate for most dialects (with the notable exception of
@@ -41,3 +50,51 @@ pub trait Dialect: Debug {
/// Determine if a character is a valid unquoted identifier character /// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool; fn is_identifier_part(&self, ch: char) -> bool;
} }
impl dyn Dialect {
#[inline]
pub fn is<T: Dialect>(&self) -> bool {
// borrowed from `Any` implementation
TypeId::of::<T>() == self.type_id()
}
}
#[cfg(test)]
mod tests {
use super::ansi::AnsiDialect;
use super::generic::GenericDialect;
use super::*;
struct DialectHolder<'a> {
dialect: &'a dyn Dialect,
}
#[test]
fn test_is_dialect() {
let generic_dialect: &dyn Dialect = &GenericDialect {};
let ansi_dialect: &dyn Dialect = &AnsiDialect {};
let generic_holder = DialectHolder {
dialect: generic_dialect,
};
let ansi_holder = DialectHolder {
dialect: ansi_dialect,
};
assert_eq!(
dialect_of!(generic_holder is GenericDialect | AnsiDialect),
true
);
assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
assert_eq!(
dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
true
);
assert_eq!(
dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
false
);
}
}

View file

@@ -35,6 +35,7 @@
#![warn(clippy::all)] #![warn(clippy::all)]
pub mod ast; pub mod ast;
#[macro_use]
pub mod dialect; pub mod dialect;
pub mod parser; pub mod parser;
pub mod tokenizer; pub mod tokenizer;

View file

@@ -15,9 +15,8 @@
use log::debug; use log::debug;
use super::ast::*; use super::ast::*;
use super::dialect::keywords;
use super::dialect::keywords::Keyword; use super::dialect::keywords::Keyword;
use super::dialect::Dialect; use super::dialect::*;
use super::tokenizer::*; use super::tokenizer::*;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
impl Error for ParserError {} impl Error for ParserError {}
/// SQL Parser pub struct Parser<'a> {
pub struct Parser {
tokens: Vec<Token>, tokens: Vec<Token>,
/// The index of the first unprocessed token in `self.tokens` /// The index of the first unprocessed token in `self.tokens`
index: usize, index: usize,
dialect: &'a dyn Dialect,
} }
impl Parser { impl<'a> Parser<'a> {
/// Parse the specified tokens /// Parse the specified tokens
pub fn new(tokens: Vec<Token>) -> Self { pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser { tokens, index: 0 } Parser {
tokens,
index: 0,
dialect,
}
} }
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST) /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> { pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, &sql); let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize()?; let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens); let mut parser = Parser::new(tokens, dialect);
let mut stmts = Vec::new(); let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false; let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql); debug!("Parsing sql '{}'...", sql);
@@ -950,7 +953,7 @@ impl Parser {
/// Parse a comma-separated list of 1+ items accepted by `F` /// Parse a comma-separated list of 1+ items accepted by `F`
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError> pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
where where
F: FnMut(&mut Parser) -> Result<T, ParserError>, F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
{ {
let mut values = vec![]; let mut values = vec![];
loop { loop {
@@ -1285,10 +1288,14 @@ impl Parser {
let expr = self.parse_expr()?; let expr = self.parse_expr()?;
self.expect_token(&Token::RParen)?; self.expect_token(&Token::RParen)?;
ColumnOption::Check(expr) ColumnOption::Check(expr)
} else if self.parse_keyword(Keyword::AUTO_INCREMENT) { } else if self.parse_keyword(Keyword::AUTO_INCREMENT)
&& dialect_of!(self is MySqlDialect | GenericDialect)
{
// Support AUTO_INCREMENT for MySQL // Support AUTO_INCREMENT for MySQL
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")]) ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
} else if self.parse_keyword(Keyword::AUTOINCREMENT) { } else if self.parse_keyword(Keyword::AUTOINCREMENT)
&& dialect_of!(self is SQLiteDialect | GenericDialect)
{
// Support AUTOINCREMENT for SQLite // Support AUTOINCREMENT for SQLite
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")]) ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
} else { } else {

View file

@@ -53,7 +53,7 @@ impl TestedDialects {
self.one_of_identical_results(|dialect| { self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql); let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap(); let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens)) f(&mut Parser::new(tokens, dialect))
}) })
} }
@@ -104,7 +104,9 @@ impl TestedDialects {
/// Ensures that `sql` parses as an expression, and is not modified /// Ensures that `sql` parses as an expression, and is not modified
/// after a serialization round-trip. /// after a serialization round-trip.
pub fn verified_expr(&self, sql: &str) -> Expr { pub fn verified_expr(&self, sql: &str) -> Expr {
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap(); let ast = self
.run_parser_method(sql, |parser| parser.parse_expr())
.unwrap();
assert_eq!(sql, &ast.to_string(), "round-tripping without changes"); assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
ast ast
} }

View file

@@ -22,7 +22,7 @@ use matches::assert_matches;
use sqlparser::ast::*; use sqlparser::ast::*;
use sqlparser::dialect::keywords::ALL_KEYWORDS; use sqlparser::dialect::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError}; use sqlparser::parser::ParserError;
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only}; use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
#[test] #[test]
@@ -147,13 +147,14 @@ fn parse_update() {
#[test] #[test]
fn parse_invalid_table_name() { fn parse_invalid_table_name() {
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name); let ast = all_dialects()
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
assert!(ast.is_err()); assert!(ast.is_err());
} }
#[test] #[test]
fn parse_no_table_name() { fn parse_no_table_name() {
let ast = all_dialects().run_parser_method("", Parser::parse_object_name); let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
assert!(ast.is_err()); assert!(ast.is_err());
} }