mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-08-23 15:34:09 +00:00
Enable dialect specific behaviours in the parser (#254)
* Change `Parser { ... }` to store the dialect used: `Parser<'a> { ... dialect: &'a dyn Dialect }`.
  Thanks to @c7hm4r for the initial version of this, submitted as part of https://github.com/ballista-compute/sqlparser-rs/pull/170
* Introduce the `dialect_of!(parser is SQLiteDialect | GenericDialect)` helper to branch on the dialect's type.
* Use the new functionality to make `AUTO_INCREMENT` and `AUTOINCREMENT` parsing dialect-dependent.

Co-authored-by: Christoph Müller <pmzqxfmn@runbox.com>
Co-authored-by: Nickolay Ponomarev <asqueella@gmail.com>
This commit is contained in:
parent
3871bbc5ee
commit
1b46e82eec
5 changed files with 85 additions and 17 deletions
|
@ -18,6 +18,7 @@ mod mysql;
|
||||||
mod postgresql;
|
mod postgresql;
|
||||||
mod sqlite;
|
mod sqlite;
|
||||||
|
|
||||||
|
use std::any::{Any, TypeId};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
pub use self::ansi::AnsiDialect;
|
pub use self::ansi::AnsiDialect;
|
||||||
|
@ -27,7 +28,15 @@ pub use self::mysql::MySqlDialect;
|
||||||
pub use self::postgresql::PostgreSqlDialect;
|
pub use self::postgresql::PostgreSqlDialect;
|
||||||
pub use self::sqlite::SQLiteDialect;
|
pub use self::sqlite::SQLiteDialect;
|
||||||
|
|
||||||
pub trait Dialect: Debug {
|
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
|
||||||
|
/// to `true` iff `parser.dialect` is one of the `Dialect`s specified.
|
||||||
|
macro_rules! dialect_of {
|
||||||
|
( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
|
||||||
|
($($parsed_dialect.dialect.is::<$dialect_type>())||+)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Dialect: Debug + Any {
|
||||||
/// Determine if a character starts a quoted identifier. The default
|
/// Determine if a character starts a quoted identifier. The default
|
||||||
/// implementation, accepting "double quoted" ids is both ANSI-compliant
|
/// implementation, accepting "double quoted" ids is both ANSI-compliant
|
||||||
/// and appropriate for most dialects (with the notable exception of
|
/// and appropriate for most dialects (with the notable exception of
|
||||||
|
@ -41,3 +50,51 @@ pub trait Dialect: Debug {
|
||||||
/// Determine if a character is a valid unquoted identifier character
|
/// Determine if a character is a valid unquoted identifier character
|
||||||
fn is_identifier_part(&self, ch: char) -> bool;
|
fn is_identifier_part(&self, ch: char) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl dyn Dialect {
|
||||||
|
#[inline]
|
||||||
|
pub fn is<T: Dialect>(&self) -> bool {
|
||||||
|
// borrowed from `Any` implementation
|
||||||
|
TypeId::of::<T>() == self.type_id()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::ansi::AnsiDialect;
|
||||||
|
use super::generic::GenericDialect;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
struct DialectHolder<'a> {
|
||||||
|
dialect: &'a dyn Dialect,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_dialect() {
|
||||||
|
let generic_dialect: &dyn Dialect = &GenericDialect {};
|
||||||
|
let ansi_dialect: &dyn Dialect = &AnsiDialect {};
|
||||||
|
|
||||||
|
let generic_holder = DialectHolder {
|
||||||
|
dialect: generic_dialect,
|
||||||
|
};
|
||||||
|
let ansi_holder = DialectHolder {
|
||||||
|
dialect: ansi_dialect,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
dialect_of!(generic_holder is GenericDialect | AnsiDialect),
|
||||||
|
true
|
||||||
|
);
|
||||||
|
assert_eq!(dialect_of!(generic_holder is AnsiDialect), false);
|
||||||
|
|
||||||
|
assert_eq!(dialect_of!(ansi_holder is AnsiDialect), true);
|
||||||
|
assert_eq!(
|
||||||
|
dialect_of!(ansi_holder is GenericDialect | AnsiDialect),
|
||||||
|
true
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
dialect_of!(ansi_holder is GenericDialect | MsSqlDialect),
|
||||||
|
false
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
#![warn(clippy::all)]
|
#![warn(clippy::all)]
|
||||||
|
|
||||||
pub mod ast;
|
pub mod ast;
|
||||||
|
#[macro_use]
|
||||||
pub mod dialect;
|
pub mod dialect;
|
||||||
pub mod parser;
|
pub mod parser;
|
||||||
pub mod tokenizer;
|
pub mod tokenizer;
|
||||||
|
|
|
@ -15,9 +15,8 @@
|
||||||
use log::debug;
|
use log::debug;
|
||||||
|
|
||||||
use super::ast::*;
|
use super::ast::*;
|
||||||
use super::dialect::keywords;
|
|
||||||
use super::dialect::keywords::Keyword;
|
use super::dialect::keywords::Keyword;
|
||||||
use super::dialect::Dialect;
|
use super::dialect::*;
|
||||||
use super::tokenizer::*;
|
use super::tokenizer::*;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
@ -82,24 +81,28 @@ impl fmt::Display for ParserError {
|
||||||
|
|
||||||
impl Error for ParserError {}
|
impl Error for ParserError {}
|
||||||
|
|
||||||
/// SQL Parser
|
pub struct Parser<'a> {
|
||||||
pub struct Parser {
|
|
||||||
tokens: Vec<Token>,
|
tokens: Vec<Token>,
|
||||||
/// The index of the first unprocessed token in `self.tokens`
|
/// The index of the first unprocessed token in `self.tokens`
|
||||||
index: usize,
|
index: usize,
|
||||||
|
dialect: &'a dyn Dialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Parser {
|
impl<'a> Parser<'a> {
|
||||||
/// Parse the specified tokens
|
/// Parse the specified tokens
|
||||||
pub fn new(tokens: Vec<Token>) -> Self {
|
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
|
||||||
Parser { tokens, index: 0 }
|
Parser {
|
||||||
|
tokens,
|
||||||
|
index: 0,
|
||||||
|
dialect,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
|
/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
|
||||||
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
|
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
|
||||||
let mut tokenizer = Tokenizer::new(dialect, &sql);
|
let mut tokenizer = Tokenizer::new(dialect, &sql);
|
||||||
let tokens = tokenizer.tokenize()?;
|
let tokens = tokenizer.tokenize()?;
|
||||||
let mut parser = Parser::new(tokens);
|
let mut parser = Parser::new(tokens, dialect);
|
||||||
let mut stmts = Vec::new();
|
let mut stmts = Vec::new();
|
||||||
let mut expecting_statement_delimiter = false;
|
let mut expecting_statement_delimiter = false;
|
||||||
debug!("Parsing sql '{}'...", sql);
|
debug!("Parsing sql '{}'...", sql);
|
||||||
|
@ -950,7 +953,7 @@ impl Parser {
|
||||||
/// Parse a comma-separated list of 1+ items accepted by `F`
|
/// Parse a comma-separated list of 1+ items accepted by `F`
|
||||||
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
|
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
|
||||||
where
|
where
|
||||||
F: FnMut(&mut Parser) -> Result<T, ParserError>,
|
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
|
||||||
{
|
{
|
||||||
let mut values = vec![];
|
let mut values = vec![];
|
||||||
loop {
|
loop {
|
||||||
|
@ -1285,10 +1288,14 @@ impl Parser {
|
||||||
let expr = self.parse_expr()?;
|
let expr = self.parse_expr()?;
|
||||||
self.expect_token(&Token::RParen)?;
|
self.expect_token(&Token::RParen)?;
|
||||||
ColumnOption::Check(expr)
|
ColumnOption::Check(expr)
|
||||||
} else if self.parse_keyword(Keyword::AUTO_INCREMENT) {
|
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
|
||||||
|
&& dialect_of!(self is MySqlDialect | GenericDialect)
|
||||||
|
{
|
||||||
// Support AUTO_INCREMENT for MySQL
|
// Support AUTO_INCREMENT for MySQL
|
||||||
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
|
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTO_INCREMENT")])
|
||||||
} else if self.parse_keyword(Keyword::AUTOINCREMENT) {
|
} else if self.parse_keyword(Keyword::AUTOINCREMENT)
|
||||||
|
&& dialect_of!(self is SQLiteDialect | GenericDialect)
|
||||||
|
{
|
||||||
// Support AUTOINCREMENT for SQLite
|
// Support AUTOINCREMENT for SQLite
|
||||||
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
|
ColumnOption::DialectSpecific(vec![Token::make_keyword("AUTOINCREMENT")])
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -53,7 +53,7 @@ impl TestedDialects {
|
||||||
self.one_of_identical_results(|dialect| {
|
self.one_of_identical_results(|dialect| {
|
||||||
let mut tokenizer = Tokenizer::new(dialect, sql);
|
let mut tokenizer = Tokenizer::new(dialect, sql);
|
||||||
let tokens = tokenizer.tokenize().unwrap();
|
let tokens = tokenizer.tokenize().unwrap();
|
||||||
f(&mut Parser::new(tokens))
|
f(&mut Parser::new(tokens, dialect))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +104,9 @@ impl TestedDialects {
|
||||||
/// Ensures that `sql` parses as an expression, and is not modified
|
/// Ensures that `sql` parses as an expression, and is not modified
|
||||||
/// after a serialization round-trip.
|
/// after a serialization round-trip.
|
||||||
pub fn verified_expr(&self, sql: &str) -> Expr {
|
pub fn verified_expr(&self, sql: &str) -> Expr {
|
||||||
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
|
let ast = self
|
||||||
|
.run_parser_method(sql, |parser| parser.parse_expr())
|
||||||
|
.unwrap();
|
||||||
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
|
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
|
||||||
ast
|
ast
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ use matches::assert_matches;
|
||||||
|
|
||||||
use sqlparser::ast::*;
|
use sqlparser::ast::*;
|
||||||
use sqlparser::dialect::keywords::ALL_KEYWORDS;
|
use sqlparser::dialect::keywords::ALL_KEYWORDS;
|
||||||
use sqlparser::parser::{Parser, ParserError};
|
use sqlparser::parser::ParserError;
|
||||||
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
|
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -147,13 +147,14 @@ fn parse_update() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_invalid_table_name() {
|
fn parse_invalid_table_name() {
|
||||||
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
|
let ast = all_dialects()
|
||||||
|
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
|
||||||
assert!(ast.is_err());
|
assert!(ast.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_no_table_name() {
|
fn parse_no_table_name() {
|
||||||
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
|
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
|
||||||
assert!(ast.is_err());
|
assert!(ast.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue