Enable dialect specific behaviours in the parser (#254)

* Change `Parser { ... }` to store the dialect used:
    `Parser<'a> { ... dialect: &'a dyn Dialect }`

    Thanks to @c7hm4r for the initial version of this submitted as
    part of https://github.com/ballista-compute/sqlparser-rs/pull/170

* Introduce `dialect_of!(parser is SQLiteDialect |  GenericDialect)` helper
    to branch on the dialect's type

* Use the new functionality to make `AUTO_INCREMENT` and `AUTOINCREMENT`
  parsing dialect-dependent.


Co-authored-by: Christoph Müller <pmzqxfmn@runbox.com>
Co-authored-by: Nickolay Ponomarev <asqueella@gmail.com>
This commit is contained in:
eyalleshem 2020-08-10 16:51:59 +03:00 committed by GitHub
parent 3871bbc5ee
commit 1b46e82eec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 17 deletions

View file

@ -18,6 +18,7 @@ mod mysql;
mod postgresql;
mod sqlite;
use std::any::{Any, TypeId};
use std::fmt::Debug;
pub use self::ansi::AnsiDialect;
@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
pub use self::postgresql::PostgreSqlDialect;
pub use self::sqlite::SQLiteDialect;
pub trait Dialect: Debug {
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
macro_rules! dialect_of {
( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
($($parsed_dialect.dialect.is::<$dialect_type>())||+)
};
}
pub trait Dialect: Debug + Any {
/// Determine if a character starts a quoted identifier. The default
/// implementation, accepting "double quoted" ids is both ANSI-compliant
/// and appropriate for most dialects (with the notable exception of
@ -41,3 +50,51 @@ pub trait Dialect: Debug {
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
}
impl dyn Dialect {
#[inline]
pub fn is<T: Dialect>(&self) -> bool {
// borrowed from `Any` implementation
TypeId::of::<T>() == self.type_id()
}
}
#[cfg(test)]
mod tests {
use super::ansi::AnsiDialect;
use super::generic::GenericDialect;
use super::*;
struct DialectHolder<'a> {
dialect: &'a dyn Dialect,
}
#[test]
fn test_is_dialect() {
let generic_dialect: &dyn Dialect = &GenericDialect {};
let ansi_dialect: &dyn Dialect = &AnsiDialect {};
let generic_holder = DialectHolder {
dialect: generic_dialect,
};
let ansi_holder = DialectHolder {
dialect: ansi_dialect,
};
assert_eq!(
dialect_of!(generic_holder is GenericDialect | AnsiDialect),
true
);
assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
assert_eq!(
dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
true
);
assert_eq!(
dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
false
);
}
}

View file

@ -35,6 +35,7 @@
#![warn(clippy::all)]
pub mod ast;
#[macro_use]
pub mod dialect;
pub mod parser;
pub mod tokenizer;

View file

@ -15,9 +15,8 @@
use log::debug;
use super::ast::*;
use super::dialect::keywords;
use super::dialect::keywords::Keyword;
use super::dialect::Dialect;
use super::dialect::*;
use super::tokenizer::*;
use std::error::Error;
use std::fmt;
@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
impl Error for ParserError {}
/// SQL Parser
pub struct Parser {
pub struct Parser<'a> {
tokens: Vec<Token>,
/// The index of the first unprocessed token in `self.tokens`
index: usize,
dialect: &'a dyn Dialect,
}
impl Parser {
impl<'a> Parser<'a> {
/// Parse the specified tokens
pub fn new(tokens: Vec<Token>) -> Self {
Parser { tokens, index: 0 }
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser {
tokens,
index: 0,
dialect,
}
}
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens);
let mut parser = Parser::new(tokens, dialect);
let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
@ -950,7 +953,7 @@ impl Parser {
/// Parse a comma-separated list of 1+ items accepted by `F`
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
where
F: FnMut(&mut Parser) -> Result<T, ParserError>,
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
{
let mut values = vec![];
loop {
@ -1285,10 +1288,14 @@ impl Parser {
let expr = self.parse_expr()?;
self.expect_token(&Token::RParen)?;
ColumnOption::Check(expr)
} else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
&& dialect_of!(self is MySqlDialect | GenericDialect)
{
// Support AUTO_INCREMENT for MySQL
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
} else if self.parse_keyword(Keyword::AUTOINCREMENT) {
} else if self.parse_keyword(Keyword::AUTOINCREMENT)
&& dialect_of!(self is SQLiteDialect | GenericDialect)
{
// Support AUTOINCREMENT for SQLite
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
} else {

View file

@ -53,7 +53,7 @@ impl TestedDialects {
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens))
f(&mut Parser::new(tokens, dialect))
})
}
@ -104,7 +104,9 @@ impl TestedDialects {
/// Ensures that `sql` parses as an expression, and is not modified
/// after a serialization round-trip.
pub fn verified_expr(&self, sql: &str) -> Expr {
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
let ast = self
.run_parser_method(sql, |parser| parser.parse_expr())
.unwrap();
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
ast
}

View file

@ -22,7 +22,7 @@ use matches::assert_matches;
use sqlparser::ast::*;
use sqlparser::dialect::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
#[test]
@ -147,13 +147,14 @@ fn parse_update() {
#[test]
fn parse_invalid_table_name() {
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
let ast = all_dialects()
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
assert!(ast.is_err());
}
#[test]
fn parse_no_table_name() {
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
assert!(ast.is_err());
}