mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-08-31 19:27:21 +00:00
Add ability for dialects to override prefix, infix, and statement parsing (#581)
This commit is contained in:
parent
7c02477151
commit
72559e9b62
5 changed files with 239 additions and 37 deletions
|
@ -22,6 +22,7 @@ mod redshift;
|
|||
mod snowflake;
|
||||
mod sqlite;
|
||||
|
||||
use crate::ast::{Expr, Statement};
|
||||
use core::any::{Any, TypeId};
|
||||
use core::fmt::Debug;
|
||||
use core::iter::Peekable;
|
||||
|
@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect;
|
|||
pub use self::snowflake::SnowflakeDialect;
|
||||
pub use self::sqlite::SQLiteDialect;
|
||||
pub use crate::keywords;
|
||||
use crate::parser::{Parser, ParserError};
|
||||
|
||||
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
|
||||
/// to `true` if `parser.dialect` is one of the `Dialect`s specified.
|
||||
|
@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any {
|
|||
fn is_identifier_start(&self, ch: char) -> bool;
|
||||
/// Determine if a character is a valid unquoted identifier character
|
||||
fn is_identifier_part(&self, ch: char) -> bool;
|
||||
/// Dialect-specific prefix parser override
|
||||
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
None
|
||||
}
|
||||
/// Dialect-specific infix parser override
|
||||
fn parse_infix(
|
||||
&self,
|
||||
_parser: &mut Parser,
|
||||
_expr: &Expr,
|
||||
_precendence: u8,
|
||||
) -> Option<Result<Expr, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
None
|
||||
}
|
||||
/// Dialect-specific precedence override
|
||||
fn get_next_precedence(&self, _parser: &Parser) -> Option<Result<u8, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
None
|
||||
}
|
||||
/// Dialect-specific statement parser override
|
||||
fn parse_statement(&self, _parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn Dialect {
|
||||
|
|
|
@ -10,7 +10,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::ast::{CommentObject, Statement};
|
||||
use crate::dialect::Dialect;
|
||||
use crate::keywords::Keyword;
|
||||
use crate::parser::{Parser, ParserError};
|
||||
use crate::tokenizer::Token;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PostgreSqlDialect {}
|
||||
|
@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect {
|
|||
|| ch == '$'
|
||||
|| ch == '_'
|
||||
}
|
||||
|
||||
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
|
||||
if parser.parse_keyword(Keyword::COMMENT) {
|
||||
Some(parse_comment(parser))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
|
||||
parser.expect_keyword(Keyword::ON)?;
|
||||
let token = parser.next_token();
|
||||
|
||||
let (object_type, object_name) = match token {
|
||||
Token::Word(w) if w.keyword == Keyword::COLUMN => {
|
||||
let object_name = parser.parse_object_name()?;
|
||||
(CommentObject::Column, object_name)
|
||||
}
|
||||
Token::Word(w) if w.keyword == Keyword::TABLE => {
|
||||
let object_name = parser.parse_object_name()?;
|
||||
(CommentObject::Table, object_name)
|
||||
}
|
||||
_ => parser.expected("comment object_type", token)?,
|
||||
};
|
||||
|
||||
parser.expect_keyword(Keyword::IS)?;
|
||||
let comment = if parser.parse_keyword(Keyword::NULL) {
|
||||
None
|
||||
} else {
|
||||
Some(parser.parse_literal_string()?)
|
||||
};
|
||||
Ok(Statement::Comment {
|
||||
object_type,
|
||||
object_name,
|
||||
comment,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -10,7 +10,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::ast::Statement;
|
||||
use crate::dialect::Dialect;
|
||||
use crate::keywords::Keyword;
|
||||
use crate::parser::{Parser, ParserError};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SQLiteDialect {}
|
||||
|
@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect {
|
|||
fn is_identifier_part(&self, ch: char) -> bool {
|
||||
self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
|
||||
}
|
||||
|
||||
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
|
||||
if parser.parse_keyword(Keyword::REPLACE) {
|
||||
parser.prev_token();
|
||||
Some(parser.parse_insert())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -152,6 +152,11 @@ impl<'a> Parser<'a> {
|
|||
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
|
||||
/// stopping before the statement separator, if any.
|
||||
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
|
||||
// allow the dialect to override statement parsing
|
||||
if let Some(statement) = self.dialect.parse_statement(self) {
|
||||
return statement;
|
||||
}
|
||||
|
||||
match self.next_token() {
|
||||
Token::Word(w) => match w.keyword {
|
||||
Keyword::KILL => Ok(self.parse_kill()?),
|
||||
|
@ -195,13 +200,6 @@ impl<'a> Parser<'a> {
|
|||
Keyword::EXECUTE => Ok(self.parse_execute()?),
|
||||
Keyword::PREPARE => Ok(self.parse_prepare()?),
|
||||
Keyword::MERGE => Ok(self.parse_merge()?),
|
||||
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => {
|
||||
self.prev_token();
|
||||
Ok(self.parse_insert()?)
|
||||
}
|
||||
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
|
||||
Ok(self.parse_comment()?)
|
||||
}
|
||||
_ => self.expected("an SQL statement", Token::Word(w)),
|
||||
},
|
||||
Token::LParen => {
|
||||
|
@ -381,6 +379,11 @@ impl<'a> Parser<'a> {
|
|||
|
||||
/// Parse an expression prefix
|
||||
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
|
||||
// allow the dialect to override prefix parsing
|
||||
if let Some(prefix) = self.dialect.parse_prefix(self) {
|
||||
return prefix;
|
||||
}
|
||||
|
||||
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
|
||||
// string literal represents a literal of that type. Some examples:
|
||||
//
|
||||
|
@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> {
|
|||
|
||||
/// Parse an operator following an expression
|
||||
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> {
|
||||
// allow the dialect to override infix parsing
|
||||
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) {
|
||||
return infix;
|
||||
}
|
||||
|
||||
let tok = self.next_token();
|
||||
|
||||
let regular_binary_operator = match &tok {
|
||||
|
@ -1491,6 +1499,11 @@ impl<'a> Parser<'a> {
|
|||
|
||||
/// Get the precedence of the next token
|
||||
pub fn get_next_precedence(&self) -> Result<u8, ParserError> {
|
||||
// allow the dialect to override precedence logic
|
||||
if let Some(precedence) = self.dialect.get_next_precedence(self) {
|
||||
return precedence;
|
||||
}
|
||||
|
||||
let token = self.peek_token();
|
||||
debug!("get_next_precedence() {:?}", token);
|
||||
let token_0 = self.peek_nth_token(0);
|
||||
|
@ -1618,7 +1631,7 @@ impl<'a> Parser<'a> {
|
|||
}
|
||||
|
||||
/// Report unexpected token
|
||||
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
|
||||
pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
|
||||
parser_err!(format!("Expected {}, found: {}", expected, found))
|
||||
}
|
||||
|
||||
|
@ -4735,35 +4748,6 @@ impl<'a> Parser<'a> {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
|
||||
self.expect_keyword(Keyword::ON)?;
|
||||
let token = self.next_token();
|
||||
|
||||
let (object_type, object_name) = match token {
|
||||
Token::Word(w) if w.keyword == Keyword::COLUMN => {
|
||||
let object_name = self.parse_object_name()?;
|
||||
(CommentObject::Column, object_name)
|
||||
}
|
||||
Token::Word(w) if w.keyword == Keyword::TABLE => {
|
||||
let object_name = self.parse_object_name()?;
|
||||
(CommentObject::Table, object_name)
|
||||
}
|
||||
_ => self.expected("comment object_type", token)?,
|
||||
};
|
||||
|
||||
self.expect_keyword(Keyword::IS)?;
|
||||
let comment = if self.parse_keyword(Keyword::NULL) {
|
||||
None
|
||||
} else {
|
||||
Some(self.parse_literal_string()?)
|
||||
};
|
||||
Ok(Statement::Comment {
|
||||
object_type,
|
||||
object_name,
|
||||
comment,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> {
|
||||
let mut clauses: Vec<MergeClause> = vec![];
|
||||
loop {
|
||||
|
|
138
tests/sqlparser_custom_dialect.rs
Normal file
138
tests/sqlparser_custom_dialect.rs
Normal file
|
@ -0,0 +1,138 @@
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Test the ability for dialects to override parsing
|
||||
|
||||
use sqlparser::{
|
||||
ast::{BinaryOperator, Expr, Statement, Value},
|
||||
dialect::Dialect,
|
||||
keywords::Keyword,
|
||||
parser::{Parser, ParserError},
|
||||
tokenizer::Token,
|
||||
};
|
||||
|
||||
#[test]
/// A dialect's `parse_prefix` override is applied to expression prefixes:
/// here the numeric literal `1` is rewritten to `NULL`.
fn custom_prefix_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
            // Only the literal `1` is intercepted; any other prefix is left
            // to the default parser.
            let one = Token::Number("1".to_string(), false);
            if !parser.consume_token(&one) {
                return None;
            }
            Some(Ok(Expr::Value(Value::Null)))
        }
    }

    let statements = Parser::parse_sql(&MyDialect {}, "SELECT 1 + 2")?;
    assert_eq!("SELECT NULL + 2", statements[0].to_string());
    Ok(())
}
|
||||
|
||||
#[test]
/// A dialect's `parse_infix` override is applied to infix operators: here
/// `+` is consumed by the dialect and emitted as a multiplication.
fn custom_infix_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_infix(
            &self,
            parser: &mut Parser,
            expr: &Expr,
            _precedence: u8,
        ) -> Option<Result<Expr, ParserError>> {
            // Only `+` is intercepted; any other operator is left to the
            // default parser.
            if !parser.consume_token(&Token::Plus) {
                return None;
            }
            // The hook already returns a `Result`, so a failure to parse the
            // right-hand side is reported to the caller instead of panicking.
            Some(parser.parse_expr().map(|right| Expr::BinaryOp {
                left: Box::new(expr.clone()),
                op: BinaryOperator::Multiply, // translate Plus to Multiply
                right: Box::new(right),
            }))
        }
    }

    let dialect = MyDialect {};
    let sql = "SELECT 1 + 2";
    let ast = Parser::parse_sql(&dialect, sql)?;
    let query = &ast[0];
    assert_eq!("SELECT 1 * 2", &format!("{}", query));
    Ok(())
}
|
||||
|
||||
#[test]
/// A dialect's `parse_statement` override replaces statement parsing: here a
/// statement starting with `SELECT` is swallowed and reported as `COMMIT`.
fn custom_statement_parser() -> Result<(), ParserError> {
    #[derive(Debug)]
    struct MyDialect {}

    impl Dialect for MyDialect {
        fn is_identifier_start(&self, ch: char) -> bool {
            is_identifier_start(ch)
        }

        fn is_identifier_part(&self, ch: char) -> bool {
            is_identifier_part(ch)
        }

        fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
            if !parser.parse_keyword(Keyword::SELECT) {
                return None;
            }
            // Discard the next three tokens so nothing of the statement is
            // left for the default parser.
            (0..3).for_each(|_| {
                let _ = parser.next_token();
            });
            Some(Ok(Statement::Commit { chain: false }))
        }
    }

    let statements = Parser::parse_sql(&MyDialect {}, "SELECT 1 + 2")?;
    assert_eq!("COMMIT", statements[0].to_string());
    Ok(())
}
|
||||
|
||||
/// Returns `true` if `ch` may start an unquoted identifier in the test
/// dialects: an ASCII letter or an underscore.
fn is_identifier_start(ch: char) -> bool {
    ch.is_ascii_alphabetic() || ch == '_'
}
|
||||
|
||||
/// Returns `true` if `ch` may appear inside an unquoted identifier in the
/// test dialects: an ASCII letter, an ASCII digit, `$` or `_`.
fn is_identifier_part(ch: char) -> bool {
    ch.is_ascii_alphanumeric() || matches!(ch, '$' | '_')
}
|
Loading…
Add table
Add a link
Reference in a new issue