Add ability for dialects to override prefix, infix, and statement parsing (#581)

This commit is contained in:
Andy Grove 2022-08-19 05:44:14 -06:00 committed by GitHub
parent 7c02477151
commit 72559e9b62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 239 additions and 37 deletions

View file

@ -22,6 +22,7 @@ mod redshift;
mod snowflake; mod snowflake;
mod sqlite; mod sqlite;
use crate::ast::{Expr, Statement};
use core::any::{Any, TypeId}; use core::any::{Any, TypeId};
use core::fmt::Debug; use core::fmt::Debug;
use core::iter::Peekable; use core::iter::Peekable;
@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect;
pub use self::snowflake::SnowflakeDialect; pub use self::snowflake::SnowflakeDialect;
pub use self::sqlite::SQLiteDialect; pub use self::sqlite::SQLiteDialect;
pub use crate::keywords; pub use crate::keywords;
use crate::parser::{Parser, ParserError};
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` if `parser.dialect` is one of the `Dialect`s specified. /// to `true` if `parser.dialect` is one of the `Dialect`s specified.
@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any {
fn is_identifier_start(&self, ch: char) -> bool; fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character /// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool; fn is_identifier_part(&self, ch: char) -> bool;
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific infix parser override
fn parse_infix(
&self,
_parser: &mut Parser,
_expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific precedence override
fn get_next_precedence(&self, _parser: &Parser) -> Option<Result<u8, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific statement parser override
fn parse_statement(&self, _parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
// return None to fall back to the default behavior
None
}
} }
impl dyn Dialect { impl dyn Dialect {

View file

@ -10,7 +10,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::ast::{CommentObject, Statement};
use crate::dialect::Dialect; use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::Token;
#[derive(Debug)] #[derive(Debug)]
pub struct PostgreSqlDialect {} pub struct PostgreSqlDialect {}
@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect {
|| ch == '$' || ch == '$'
|| ch == '_' || ch == '_'
} }
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::COMMENT) {
Some(parse_comment(parser))
} else {
None
}
}
}
pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
parser.expect_keyword(Keyword::ON)?;
let token = parser.next_token();
let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = parser.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = parser.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => parser.expected("comment object_type", token)?,
};
parser.expect_keyword(Keyword::IS)?;
let comment = if parser.parse_keyword(Keyword::NULL) {
None
} else {
Some(parser.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
} }

View file

@ -10,7 +10,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::ast::Statement;
use crate::dialect::Dialect; use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};
#[derive(Debug)] #[derive(Debug)]
pub struct SQLiteDialect {} pub struct SQLiteDialect {}
@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect {
fn is_identifier_part(&self, ch: char) -> bool { fn is_identifier_part(&self, ch: char) -> bool {
self.is_identifier_start(ch) || ('0'..='9').contains(&ch) self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
} }
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::REPLACE) {
parser.prev_token();
Some(parser.parse_insert())
} else {
None
}
}
} }

View file

@ -152,6 +152,11 @@ impl<'a> Parser<'a> {
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any. /// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> { pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
// allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) {
return statement;
}
match self.next_token() { match self.next_token() {
Token::Word(w) => match w.keyword { Token::Word(w) => match w.keyword {
Keyword::KILL => Ok(self.parse_kill()?), Keyword::KILL => Ok(self.parse_kill()?),
@ -195,13 +200,6 @@ impl<'a> Parser<'a> {
Keyword::EXECUTE => Ok(self.parse_execute()?), Keyword::EXECUTE => Ok(self.parse_execute()?),
Keyword::PREPARE => Ok(self.parse_prepare()?), Keyword::PREPARE => Ok(self.parse_prepare()?),
Keyword::MERGE => Ok(self.parse_merge()?), Keyword::MERGE => Ok(self.parse_merge()?),
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => {
self.prev_token();
Ok(self.parse_insert()?)
}
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
Ok(self.parse_comment()?)
}
_ => self.expected("an SQL statement", Token::Word(w)), _ => self.expected("an SQL statement", Token::Word(w)),
}, },
Token::LParen => { Token::LParen => {
@ -381,6 +379,11 @@ impl<'a> Parser<'a> {
/// Parse an expression prefix /// Parse an expression prefix
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> { pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
// allow the dialect to override prefix parsing
if let Some(prefix) = self.dialect.parse_prefix(self) {
return prefix;
}
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the // PostgreSQL allows any string literal to be preceded by a type name, indicating that the
// string literal represents a literal of that type. Some examples: // string literal represents a literal of that type. Some examples:
// //
@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> {
/// Parse an operator following an expression /// Parse an operator following an expression
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> { pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> {
// allow the dialect to override infix parsing
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) {
return infix;
}
let tok = self.next_token(); let tok = self.next_token();
let regular_binary_operator = match &tok { let regular_binary_operator = match &tok {
@ -1491,6 +1499,11 @@ impl<'a> Parser<'a> {
/// Get the precedence of the next token /// Get the precedence of the next token
pub fn get_next_precedence(&self) -> Result<u8, ParserError> { pub fn get_next_precedence(&self) -> Result<u8, ParserError> {
// allow the dialect to override precedence logic
if let Some(precedence) = self.dialect.get_next_precedence(self) {
return precedence;
}
let token = self.peek_token(); let token = self.peek_token();
debug!("get_next_precedence() {:?}", token); debug!("get_next_precedence() {:?}", token);
let token_0 = self.peek_nth_token(0); let token_0 = self.peek_nth_token(0);
@ -1618,7 +1631,7 @@ impl<'a> Parser<'a> {
} }
/// Report unexpected token /// Report unexpected token
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> { pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
parser_err!(format!("Expected {}, found: {}", expected, found)) parser_err!(format!("Expected {}, found: {}", expected, found))
} }
@ -4735,35 +4748,6 @@ impl<'a> Parser<'a> {
}) })
} }
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword(Keyword::ON)?;
let token = self.next_token();
let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = self.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = self.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => self.expected("comment object_type", token)?,
};
self.expect_keyword(Keyword::IS)?;
let comment = if self.parse_keyword(Keyword::NULL) {
None
} else {
Some(self.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
}
pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> { pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> {
let mut clauses: Vec<MergeClause> = vec![]; let mut clauses: Vec<MergeClause> = vec![];
loop { loop {

View file

@ -0,0 +1,138 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Test the ability for dialects to override parsing
use sqlparser::{
ast::{BinaryOperator, Expr, Statement, Value},
dialect::Dialect,
keywords::Keyword,
parser::{Parser, ParserError},
tokenizer::Token,
};
#[test]
fn custom_prefix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}
impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}
fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}
fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Number("1".to_string(), false)) {
Some(Ok(Expr::Value(Value::Null)))
} else {
None
}
}
}
let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT NULL + 2", &format!("{}", query));
Ok(())
}
#[test]
fn custom_infix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}
impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}
fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}
fn parse_infix(
&self,
parser: &mut Parser,
expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Plus) {
Some(Ok(Expr::BinaryOp {
left: Box::new(expr.clone()),
op: BinaryOperator::Multiply, // translate Plus to Multiply
right: Box::new(parser.parse_expr().unwrap()),
}))
} else {
None
}
}
}
let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT 1 * 2", &format!("{}", query));
Ok(())
}
#[test]
fn custom_statement_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}
impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}
fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::SELECT) {
for _ in 0..3 {
let _ = parser.next_token();
}
Some(Ok(Statement::Commit { chain: false }))
} else {
None
}
}
}
let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("COMMIT", &format!("{}", query));
Ok(())
}
fn is_identifier_start(ch: char) -> bool {
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
}
fn is_identifier_part(ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
|| ch == '_'
}