Fallback to identifier parsing if expression parsing fails (#1513)

This commit is contained in:
Yoav Cohen 2024-11-25 22:01:02 +01:00 committed by GitHub
parent 0fb2ef331e
commit fd21fae297
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 276 additions and 190 deletions

View file

@ -681,6 +681,12 @@ pub trait Dialect: Debug + Any {
fn supports_partiql(&self) -> bool {
false
}
/// Returns true if the specified keyword is reserved and cannot be
/// used as an identifier without special handling like quoting.
fn is_reserved_for_identifier(&self, kw: Keyword) -> bool {
keywords::RESERVED_FOR_IDENTIFIER.contains(&kw)
}
}
/// This represents the operators for which precedence must be defined

View file

@ -38,6 +38,8 @@ use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use alloc::{format, vec};
use super::keywords::RESERVED_FOR_IDENTIFIER;
/// A [`Dialect`] for [Snowflake](https://www.snowflake.com/)
#[derive(Debug, Default)]
pub struct SnowflakeDialect;
@ -214,6 +216,16 @@ impl Dialect for SnowflakeDialect {
fn supports_show_like_before_in(&self) -> bool {
true
}
fn is_reserved_for_identifier(&self, kw: Keyword) -> bool {
// Unreserve some keywords that Snowflake accepts as identifiers
// See: https://docs.snowflake.com/en/sql-reference/reserved-keywords
if matches!(kw, Keyword::INTERVAL) {
false
} else {
RESERVED_FOR_IDENTIFIER.contains(&kw)
}
}
}
/// Parse snowflake create table statement.

View file

@ -948,3 +948,13 @@ pub const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[
Keyword::INTO,
Keyword::END,
];
/// Global list of reserved keywords that cannot be parsed as identifiers
/// without special handling like quoting. Parser should call `Dialect::is_reserved_for_identifier`
/// to allow for each dialect to customize the list.
pub const RESERVED_FOR_IDENTIFIER: &[Keyword] = &[
Keyword::EXISTS,
Keyword::INTERVAL,
Keyword::STRUCT,
Keyword::TRIM,
];

View file

@ -1025,6 +1025,183 @@ impl<'a> Parser<'a> {
Ok(Statement::NOTIFY { channel, payload })
}
// Tries to parse an expression by matching the specified word to known keywords that have a special meaning in the dialect.
// Returns `None if no match is found.
fn parse_expr_prefix_by_reserved_word(
&mut self,
w: &Word,
) -> Result<Option<Expr>, ParserError> {
match w.keyword {
Keyword::TRUE | Keyword::FALSE if self.dialect.supports_boolean_literals() => {
self.prev_token();
Ok(Some(Expr::Value(self.parse_value()?)))
}
Keyword::NULL => {
self.prev_token();
Ok(Some(Expr::Value(self.parse_value()?)))
}
Keyword::CURRENT_CATALOG
| Keyword::CURRENT_USER
| Keyword::SESSION_USER
| Keyword::USER
if dialect_of!(self is PostgreSqlDialect | GenericDialect) =>
{
Ok(Some(Expr::Function(Function {
name: ObjectName(vec![w.to_ident()]),
parameters: FunctionArguments::None,
args: FunctionArguments::None,
null_treatment: None,
filter: None,
over: None,
within_group: vec![],
})))
}
Keyword::CURRENT_TIMESTAMP
| Keyword::CURRENT_TIME
| Keyword::CURRENT_DATE
| Keyword::LOCALTIME
| Keyword::LOCALTIMESTAMP => {
Ok(Some(self.parse_time_functions(ObjectName(vec![w.to_ident()]))?))
}
Keyword::CASE => Ok(Some(self.parse_case_expr()?)),
Keyword::CONVERT => Ok(Some(self.parse_convert_expr(false)?)),
Keyword::TRY_CONVERT if self.dialect.supports_try_convert() => Ok(Some(self.parse_convert_expr(true)?)),
Keyword::CAST => Ok(Some(self.parse_cast_expr(CastKind::Cast)?)),
Keyword::TRY_CAST => Ok(Some(self.parse_cast_expr(CastKind::TryCast)?)),
Keyword::SAFE_CAST => Ok(Some(self.parse_cast_expr(CastKind::SafeCast)?)),
Keyword::EXISTS
// Support parsing Databricks has a function named `exists`.
if !dialect_of!(self is DatabricksDialect)
|| matches!(
self.peek_nth_token(1).token,
Token::Word(Word {
keyword: Keyword::SELECT | Keyword::WITH,
..
})
) =>
{
Ok(Some(self.parse_exists_expr(false)?))
}
Keyword::EXTRACT => Ok(Some(self.parse_extract_expr()?)),
Keyword::CEIL => Ok(Some(self.parse_ceil_floor_expr(true)?)),
Keyword::FLOOR => Ok(Some(self.parse_ceil_floor_expr(false)?)),
Keyword::POSITION if self.peek_token().token == Token::LParen => {
Ok(Some(self.parse_position_expr(w.to_ident())?))
}
Keyword::SUBSTRING => Ok(Some(self.parse_substring_expr()?)),
Keyword::OVERLAY => Ok(Some(self.parse_overlay_expr()?)),
Keyword::TRIM => Ok(Some(self.parse_trim_expr()?)),
Keyword::INTERVAL => Ok(Some(self.parse_interval()?)),
// Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as subquery or a function call
Keyword::ARRAY if self.peek_token() == Token::LBracket => {
self.expect_token(&Token::LBracket)?;
Ok(Some(self.parse_array_expr(true)?))
}
Keyword::ARRAY
if self.peek_token() == Token::LParen
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
{
self.expect_token(&Token::LParen)?;
let query = self.parse_query()?;
self.expect_token(&Token::RParen)?;
Ok(Some(Expr::Function(Function {
name: ObjectName(vec![w.to_ident()]),
parameters: FunctionArguments::None,
args: FunctionArguments::Subquery(query),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
})))
}
Keyword::NOT => Ok(Some(self.parse_not()?)),
Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => {
Ok(Some(self.parse_match_against()?))
}
Keyword::STRUCT if dialect_of!(self is BigQueryDialect | GenericDialect) => {
self.prev_token();
Ok(Some(self.parse_bigquery_struct_literal()?))
}
Keyword::PRIOR if matches!(self.state, ParserState::ConnectBy) => {
let expr = self.parse_subexpr(self.dialect.prec_value(Precedence::PlusMinus))?;
Ok(Some(Expr::Prior(Box::new(expr))))
}
Keyword::MAP if self.peek_token() == Token::LBrace && self.dialect.support_map_literal_syntax() => {
Ok(Some(self.parse_duckdb_map_literal()?))
}
_ => Ok(None)
}
}
// Tries to parse an expression by a word that is not known to have a special meaning in the dialect.
fn parse_expr_prefix_by_unreserved_word(&mut self, w: &Word) -> Result<Expr, ParserError> {
match self.peek_token().token {
Token::LParen | Token::Period => {
let mut id_parts: Vec<Ident> = vec![w.to_ident()];
let mut ends_with_wildcard = false;
while self.consume_token(&Token::Period) {
let next_token = self.next_token();
match next_token.token {
Token::Word(w) => id_parts.push(w.to_ident()),
Token::Mul => {
// Postgres explicitly allows funcnm(tablenm.*) and the
// function array_agg traverses this control flow
if dialect_of!(self is PostgreSqlDialect) {
ends_with_wildcard = true;
break;
} else {
return self.expected("an identifier after '.'", next_token);
}
}
Token::SingleQuotedString(s) => id_parts.push(Ident::with_quote('\'', s)),
_ => {
return self.expected("an identifier or a '*' after '.'", next_token);
}
}
}
if ends_with_wildcard {
Ok(Expr::QualifiedWildcard(ObjectName(id_parts)))
} else if self.consume_token(&Token::LParen) {
if dialect_of!(self is SnowflakeDialect | MsSqlDialect)
&& self.consume_tokens(&[Token::Plus, Token::RParen])
{
Ok(Expr::OuterJoin(Box::new(
match <[Ident; 1]>::try_from(id_parts) {
Ok([ident]) => Expr::Identifier(ident),
Err(parts) => Expr::CompoundIdentifier(parts),
},
)))
} else {
self.prev_token();
self.parse_function(ObjectName(id_parts))
}
} else {
Ok(Expr::CompoundIdentifier(id_parts))
}
}
// string introducer https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
Token::SingleQuotedString(_)
| Token::DoubleQuotedString(_)
| Token::HexStringLiteral(_)
if w.value.starts_with('_') =>
{
Ok(Expr::IntroducedString {
introducer: w.value.clone(),
value: self.parse_introduced_string_value()?,
})
}
Token::Arrow if self.dialect.supports_lambda_functions() => {
self.expect_token(&Token::Arrow)?;
Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(w.to_ident()),
body: Box::new(self.parse_expr()?),
}))
}
_ => Ok(Expr::Identifier(w.to_ident())),
}
}
/// Parse an expression prefix.
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
// allow the dialect to override prefix parsing
@ -1073,176 +1250,40 @@ impl<'a> Parser<'a> {
let next_token = self.next_token();
let expr = match next_token.token {
Token::Word(w) => match w.keyword {
Keyword::TRUE | Keyword::FALSE if self.dialect.supports_boolean_literals() => {
self.prev_token();
Ok(Expr::Value(self.parse_value()?))
}
Keyword::NULL => {
self.prev_token();
Ok(Expr::Value(self.parse_value()?))
}
Keyword::CURRENT_CATALOG
| Keyword::CURRENT_USER
| Keyword::SESSION_USER
| Keyword::USER
if dialect_of!(self is PostgreSqlDialect | GenericDialect) =>
{
Ok(Expr::Function(Function {
name: ObjectName(vec![w.to_ident()]),
parameters: FunctionArguments::None,
args: FunctionArguments::None,
null_treatment: None,
filter: None,
over: None,
within_group: vec![],
}))
}
Keyword::CURRENT_TIMESTAMP
| Keyword::CURRENT_TIME
| Keyword::CURRENT_DATE
| Keyword::LOCALTIME
| Keyword::LOCALTIMESTAMP => {
self.parse_time_functions(ObjectName(vec![w.to_ident()]))
}
Keyword::CASE => self.parse_case_expr(),
Keyword::CONVERT => self.parse_convert_expr(false),
Keyword::TRY_CONVERT if self.dialect.supports_try_convert() => self.parse_convert_expr(true),
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
Keyword::EXISTS
// Support parsing Databricks has a function named `exists`.
if !dialect_of!(self is DatabricksDialect)
|| matches!(
self.peek_nth_token(1).token,
Token::Word(Word {
keyword: Keyword::SELECT | Keyword::WITH,
..
})
) =>
{
self.parse_exists_expr(false)
}
Keyword::EXTRACT => self.parse_extract_expr(),
Keyword::CEIL => self.parse_ceil_floor_expr(true),
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
Keyword::POSITION if self.peek_token().token == Token::LParen => {
self.parse_position_expr(w.to_ident())
}
Keyword::SUBSTRING => self.parse_substring_expr(),
Keyword::OVERLAY => self.parse_overlay_expr(),
Keyword::TRIM => self.parse_trim_expr(),
Keyword::INTERVAL => self.parse_interval(),
// Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as subquery or a function call
Keyword::ARRAY if self.peek_token() == Token::LBracket => {
self.expect_token(&Token::LBracket)?;
self.parse_array_expr(true)
}
Keyword::ARRAY
if self.peek_token() == Token::LParen
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
{
self.expect_token(&Token::LParen)?;
let query = self.parse_query()?;
self.expect_token(&Token::RParen)?;
Ok(Expr::Function(Function {
name: ObjectName(vec![w.to_ident()]),
parameters: FunctionArguments::None,
args: FunctionArguments::Subquery(query),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
}))
}
Keyword::NOT => self.parse_not(),
Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => {
self.parse_match_against()
}
Keyword::STRUCT if dialect_of!(self is BigQueryDialect | GenericDialect) => {
self.prev_token();
self.parse_bigquery_struct_literal()
}
Keyword::PRIOR if matches!(self.state, ParserState::ConnectBy) => {
let expr = self.parse_subexpr(self.dialect.prec_value(Precedence::PlusMinus))?;
Ok(Expr::Prior(Box::new(expr)))
}
Keyword::MAP if self.peek_token() == Token::LBrace && self.dialect.support_map_literal_syntax() => {
self.parse_duckdb_map_literal()
}
// Here `w` is a word, check if it's a part of a multipart
// identifier, a function call, or a simple identifier:
_ => match self.peek_token().token {
Token::LParen | Token::Period => {
let mut id_parts: Vec<Ident> = vec![w.to_ident()];
let mut ends_with_wildcard = false;
while self.consume_token(&Token::Period) {
let next_token = self.next_token();
match next_token.token {
Token::Word(w) => id_parts.push(w.to_ident()),
Token::Mul => {
// Postgres explicitly allows funcnm(tablenm.*) and the
// function array_agg traverses this control flow
if dialect_of!(self is PostgreSqlDialect) {
ends_with_wildcard = true;
break;
} else {
return self
.expected("an identifier after '.'", next_token);
}
}
Token::SingleQuotedString(s) => {
id_parts.push(Ident::with_quote('\'', s))
}
_ => {
return self
.expected("an identifier or a '*' after '.'", next_token);
}
}
}
Token::Word(w) => {
// The word we consumed may fall into one of two cases: it has a special meaning, or not.
// For example, in Snowflake, the word `interval` may have two meanings depending on the context:
// `SELECT CURRENT_DATE() + INTERVAL '1 DAY', MAX(interval) FROM tbl;`
// ^^^^^^^^^^^^^^^^ ^^^^^^^^
// interval expression identifier
//
// We first try to parse the word and following tokens as a special expression, and if that fails,
// we rollback and try to parse it as an identifier.
match self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w)) {
// This word indicated an expression prefix and parsing was successful
Ok(Some(expr)) => Ok(expr),
if ends_with_wildcard {
Ok(Expr::QualifiedWildcard(ObjectName(id_parts)))
} else if self.consume_token(&Token::LParen) {
if dialect_of!(self is SnowflakeDialect | MsSqlDialect)
&& self.consume_tokens(&[Token::Plus, Token::RParen])
{
Ok(Expr::OuterJoin(Box::new(
match <[Ident; 1]>::try_from(id_parts) {
Ok([ident]) => Expr::Identifier(ident),
Err(parts) => Expr::CompoundIdentifier(parts),
},
)))
} else {
self.prev_token();
self.parse_function(ObjectName(id_parts))
// No expression prefix associated with this word
Ok(None) => Ok(self.parse_expr_prefix_by_unreserved_word(&w)?),
// If parsing of the word as a special expression failed, we are facing two options:
// 1. The statement is malformed, e.g. `SELECT INTERVAL '1 DAI` (`DAI` instead of `DAY`)
// 2. The word is used as an identifier, e.g. `SELECT MAX(interval) FROM tbl`
// We first try to parse the word as an identifier and if that fails
// we rollback and return the parsing error we got from trying to parse a
// special expression (to maintain backwards compatibility of parsing errors).
Err(e) => {
if !self.dialect.is_reserved_for_identifier(w.keyword) {
if let Ok(Some(expr)) = self.maybe_parse(|parser| {
parser.parse_expr_prefix_by_unreserved_word(&w)
}) {
return Ok(expr);
}
} else {
Ok(Expr::CompoundIdentifier(id_parts))
}
return Err(e);
}
// string introducer https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
Token::SingleQuotedString(_)
| Token::DoubleQuotedString(_)
| Token::HexStringLiteral(_)
if w.value.starts_with('_') =>
{
Ok(Expr::IntroducedString {
introducer: w.value,
value: self.parse_introduced_string_value()?,
})
}
Token::Arrow if self.dialect.supports_lambda_functions() => {
self.expect_token(&Token::Arrow)?;
return Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(w.to_ident()),
body: Box::new(self.parse_expr()?),
}));
}
_ => Ok(Expr::Identifier(w.to_ident())),
},
}, // End of Token::Word
}
} // End of Token::Word
// array `[1, 2, 3]`
Token::LBracket => self.parse_array_expr(false),
tok @ Token::Minus | tok @ Token::Plus => {
@ -3677,18 +3718,30 @@ impl<'a> Parser<'a> {
}
/// Run a parser method `f`, reverting back to the current position if unsuccessful.
pub fn maybe_parse<T, F>(&mut self, mut f: F) -> Result<Option<T>, ParserError>
/// Returns `None` if `f` returns an error
pub fn maybe_parse<T, F>(&mut self, f: F) -> Result<Option<T>, ParserError>
where
F: FnMut(&mut Parser) -> Result<T, ParserError>,
{
match self.try_parse(f) {
Ok(t) => Ok(Some(t)),
Err(ParserError::RecursionLimitExceeded) => Err(ParserError::RecursionLimitExceeded),
_ => Ok(None),
}
}
/// Run a parser method `f`, reverting back to the current position if unsuccessful.
pub fn try_parse<T, F>(&mut self, mut f: F) -> Result<T, ParserError>
where
F: FnMut(&mut Parser) -> Result<T, ParserError>,
{
let index = self.index;
match f(self) {
Ok(t) => Ok(Some(t)),
// Unwind stack if limit exceeded
Err(ParserError::RecursionLimitExceeded) => Err(ParserError::RecursionLimitExceeded),
Err(_) => {
Ok(t) => Ok(t),
Err(e) => {
// Unwind stack if limit exceeded
self.index = index;
Ok(None)
Err(e)
}
}
}

View file

@ -34,7 +34,7 @@ use sqlparser::dialect::{
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect,
SQLiteDialect, SnowflakeDialect,
};
use sqlparser::keywords::ALL_KEYWORDS;
use sqlparser::keywords::{Keyword, ALL_KEYWORDS};
use sqlparser::parser::{Parser, ParserError, ParserOptions};
use sqlparser::tokenizer::Tokenizer;
use test_utils::{
@ -5113,7 +5113,6 @@ fn parse_interval_dont_require_unit() {
#[test]
fn parse_interval_require_unit() {
let dialects = all_dialects_where(|d| d.require_interval_qualifier());
let sql = "SELECT INTERVAL '1 DAY'";
let err = dialects.parse_sql_statements(sql).unwrap_err();
assert_eq!(
@ -12198,3 +12197,21 @@ fn parse_create_table_select() {
);
}
}
#[test]
fn test_reserved_keywords_for_identifiers() {
let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL));
// Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier
let sql = "SELECT MAX(interval) FROM tbl";
assert_eq!(
dialects.parse_sql_statements(sql),
Err(ParserError::ParserError(
"Expected: an expression, found: )".to_string()
))
);
// Dialects that do not reserve the word INTERVAL will allow it
let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL));
let sql = "SELECT MAX(interval) FROM tbl";
dialects.parse_sql_statements(sql).unwrap();
}

View file

@ -1352,10 +1352,7 @@ fn parse_set() {
local: false,
hivevar: false,
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(),
quote_style: None
})],
value: vec![Expr::Identifier(Ident::new("DEFAULT"))],
}
);
@ -4229,10 +4226,7 @@ fn test_simple_postgres_insert_with_alias() {
body: Box::new(SetExpr::Values(Values {
explicit_row: false,
rows: vec![vec![
Expr::Identifier(Ident {
value: "DEFAULT".to_string(),
quote_style: None
}),
Expr::Identifier(Ident::new("DEFAULT")),
Expr::Value(Value::Number("123".to_string(), false))
]]
})),
@ -4295,10 +4289,7 @@ fn test_simple_postgres_insert_with_alias() {
body: Box::new(SetExpr::Values(Values {
explicit_row: false,
rows: vec![vec![
Expr::Identifier(Ident {
value: "DEFAULT".to_string(),
quote_style: None
}),
Expr::Identifier(Ident::new("DEFAULT")),
Expr::Value(Value::Number(
bigdecimal::BigDecimal::new(123.into(), 0),
false
@ -4363,10 +4354,7 @@ fn test_simple_insert_with_quoted_alias() {
body: Box::new(SetExpr::Values(Values {
explicit_row: false,
rows: vec![vec![
Expr::Identifier(Ident {
value: "DEFAULT".to_string(),
quote_style: None
}),
Expr::Identifier(Ident::new("DEFAULT")),
Expr::Value(Value::SingleQuotedString("0123".to_string()))
]]
})),