Support multiple SET variables (#1252)

This commit is contained in:
Ifeanyi Ubah 2024-05-07 18:51:39 +02:00 committed by GitHub
parent c4f3ef9600
commit eb36bd7138
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 145 additions and 33 deletions

View file

@ -2275,7 +2275,8 @@ pub enum Statement {
role_name: Option<Ident>, role_name: Option<Ident>,
}, },
/// ```sql /// ```sql
/// SET <variable> /// SET <variable> = expression;
/// SET (variable[, ...]) = (expression[, ...]);
/// ``` /// ```
/// ///
/// Note: this is not a standard SQL statement, but it is supported by at /// Note: this is not a standard SQL statement, but it is supported by at
@ -2284,7 +2285,7 @@ pub enum Statement {
SetVariable { SetVariable {
local: bool, local: bool,
hivevar: bool, hivevar: bool,
variable: ObjectName, variables: OneOrManyWithParens<ObjectName>,
value: Vec<Expr>, value: Vec<Expr>,
}, },
/// ```sql /// ```sql
@ -3823,7 +3824,7 @@ impl fmt::Display for Statement {
} }
Statement::SetVariable { Statement::SetVariable {
local, local,
variable, variables,
hivevar, hivevar,
value, value,
} => { } => {
@ -3831,12 +3832,15 @@ impl fmt::Display for Statement {
if *local { if *local {
f.write_str("LOCAL ")?; f.write_str("LOCAL ")?;
} }
let parenthesized = matches!(variables, OneOrManyWithParens::Many(_));
write!( write!(
f, f,
"{hivevar}{name} = {value}", "{hivevar}{name} = {l_paren}{value}{r_paren}",
hivevar = if *hivevar { "HIVEVAR:" } else { "" }, hivevar = if *hivevar { "HIVEVAR:" } else { "" },
name = variable, name = variables,
value = display_comma_separated(value) l_paren = parenthesized.then_some("(").unwrap_or_default(),
value = display_comma_separated(value),
r_paren = parenthesized.then_some(")").unwrap_or_default(),
) )
} }
Statement::SetTimeZone { local, value } => { Statement::SetTimeZone { local, value } => {

View file

@ -44,4 +44,9 @@ impl Dialect for BigQueryDialect {
fn supports_window_clause_named_window_reference(&self) -> bool { fn supports_window_clause_named_window_reference(&self) -> bool {
true true
} }
/// See [doc](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#set)
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
} }

View file

@ -62,4 +62,8 @@ impl Dialect for GenericDialect {
fn supports_window_clause_named_window_reference(&self) -> bool { fn supports_window_clause_named_window_reference(&self) -> bool {
true true
} }
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
} }

View file

@ -217,6 +217,15 @@ pub trait Dialect: Debug + Any {
fn supports_lambda_functions(&self) -> bool { fn supports_lambda_functions(&self) -> bool {
false false
} }
/// Returns true if the dialect supports multiple variable assignment
/// using parentheses in a `SET` variable declaration.
///
/// ```sql
/// SET (variable[, ...]) = (expression[, ...]);
/// ```
fn supports_parenthesized_set_variables(&self) -> bool {
false
}
/// Returns true if the dialect has a CONVERT function which accepts a type first /// Returns true if the dialect has a CONVERT function which accepts a type first
/// and an expression second, e.g. `CONVERT(varchar, 1)` /// and an expression second, e.g. `CONVERT(varchar, 1)`
fn convert_type_before_value(&self) -> bool { fn convert_type_before_value(&self) -> bool {

View file

@ -71,6 +71,11 @@ impl Dialect for SnowflakeDialect {
true true
} }
/// See [doc](https://docs.snowflake.com/en/sql-reference/sql/set#syntax)
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> { fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::CREATE) { if parser.parse_keyword(Keyword::CREATE) {
// possibly CREATE STAGE // possibly CREATE STAGE

View file

@ -8030,14 +8030,27 @@ impl<'a> Parser<'a> {
}); });
} }
let variable = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) { let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
ObjectName(vec!["TIMEZONE".into()]) OneOrManyWithParens::One(ObjectName(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| {
parser.parse_identifier(false)
})?
.into_iter()
.map(|ident| ObjectName(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
} else { } else {
self.parse_object_name(false)? OneOrManyWithParens::One(self.parse_object_name(false)?)
}; };
if variable.to_string().eq_ignore_ascii_case("NAMES") if matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES")
&& dialect_of!(self is MySqlDialect | GenericDialect) && dialect_of!(self is MySqlDialect | GenericDialect))
{ {
if self.parse_keyword(Keyword::DEFAULT) { if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {}); return Ok(Statement::SetNamesDefault {});
@ -8050,11 +8063,19 @@ impl<'a> Parser<'a> {
None None
}; };
Ok(Statement::SetNames { return Ok(Statement::SetNames {
charset_name, charset_name,
collation_name, collation_name,
}) });
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { }
let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}
let mut values = vec![]; let mut values = vec![];
loop { loop {
let value = if let Ok(expr) = self.parse_expr() { let value = if let Ok(expr) = self.parse_expr() {
@ -8067,14 +8088,24 @@ impl<'a> Parser<'a> {
if self.consume_token(&Token::Comma) { if self.consume_token(&Token::Comma) {
continue; continue;
} }
if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(Statement::SetVariable { return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL), local: modifier == Some(Keyword::LOCAL),
hivevar: Some(Keyword::HIVEVAR) == modifier, hivevar: Some(Keyword::HIVEVAR) == modifier,
variable, variables,
value: values, value: values,
}); });
} }
} else if variable.to_string().eq_ignore_ascii_case("TIMEZONE") { }
let OneOrManyWithParens::One(variable) = variables else {
return self.expected("set variable", self.peek_token());
};
if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
// for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value> // for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value>
match self.parse_expr() { match self.parse_expr() {
Ok(expr) => Ok(Statement::SetTimeZone { Ok(expr) => Ok(Statement::SetTimeZone {

View file

@ -6958,12 +6958,15 @@ fn parse_set_variable() {
Statement::SetVariable { Statement::SetVariable {
local, local,
hivevar, hivevar,
variable, variables,
value, value,
} => { } => {
assert!(!local); assert!(!local);
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName(vec!["SOMETHING".into()])); assert_eq!(
variables,
OneOrManyWithParens::One(ObjectName(vec!["SOMETHING".into()]))
);
assert_eq!( assert_eq!(
value, value,
vec![Expr::Value(Value::SingleQuotedString("1".into()))] vec![Expr::Value(Value::SingleQuotedString("1".into()))]
@ -6972,6 +6975,50 @@ fn parse_set_variable() {
_ => unreachable!(), _ => unreachable!(),
} }
let multi_variable_dialects = all_dialects_where(|d| d.supports_parenthesized_set_variables());
let sql = r#"SET (a, b, c) = (1, 2, 3)"#;
match multi_variable_dialects.verified_stmt(sql) {
Statement::SetVariable {
local,
hivevar,
variables,
value,
} => {
assert!(!local);
assert!(!hivevar);
assert_eq!(
variables,
OneOrManyWithParens::Many(vec![
ObjectName(vec!["a".into()]),
ObjectName(vec!["b".into()]),
ObjectName(vec!["c".into()]),
])
);
assert_eq!(
value,
vec![
Expr::Value(number("1")),
Expr::Value(number("2")),
Expr::Value(number("3")),
]
);
}
_ => unreachable!(),
}
let error_sqls = [
("SET (a, b, c) = (1, 2, 3", "Expected ), found: EOF"),
("SET (a, b, c) = 1, 2, 3", "Expected (, found: 1"),
];
for (sql, error) in error_sqls {
assert_eq!(
ParserError::ParserError(error.to_string()),
multi_variable_dialects
.parse_sql_statements(sql)
.unwrap_err()
);
}
one_statement_parses_to("SET SOMETHING TO '1'", "SET SOMETHING = '1'"); one_statement_parses_to("SET SOMETHING TO '1'", "SET SOMETHING = '1'");
} }
@ -6981,12 +7028,15 @@ fn parse_set_time_zone() {
Statement::SetVariable { Statement::SetVariable {
local, local,
hivevar, hivevar,
variable, variables: variable,
value, value,
} => { } => {
assert!(!local); assert!(!local);
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName(vec!["TIMEZONE".into()])); assert_eq!(
variable,
OneOrManyWithParens::One(ObjectName(vec!["TIMEZONE".into()]))
);
assert_eq!( assert_eq!(
value, value,
vec![Expr::Value(Value::SingleQuotedString("UTC".into()))] vec![Expr::Value(Value::SingleQuotedString("UTC".into()))]

View file

@ -17,8 +17,8 @@
use sqlparser::ast::{ use sqlparser::ast::{
CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionArgumentList, CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionArgumentList,
FunctionArguments, FunctionDefinition, Ident, ObjectName, SelectItem, Statement, TableFactor, FunctionArguments, FunctionDefinition, Ident, ObjectName, OneOrManyWithParens, SelectItem,
UnaryOperator, Statement, TableFactor, UnaryOperator,
}; };
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
use sqlparser::parser::{ParserError, ParserOptions}; use sqlparser::parser::{ParserError, ParserOptions};
@ -268,12 +268,12 @@ fn set_statement_with_minus() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![ variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("hive"), Ident::new("hive"),
Ident::new("tez"), Ident::new("tez"),
Ident::new("java"), Ident::new("java"),
Ident::new("opts") Ident::new("opts")
]), ])),
value: vec![Expr::UnaryOp { value: vec![Expr::UnaryOp {
op: UnaryOperator::Minus, op: UnaryOperator::Minus,
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g"))) expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))

View file

@ -460,7 +460,7 @@ fn parse_set_variables() {
Statement::SetVariable { Statement::SetVariable {
local: true, local: true,
hivevar: false, hivevar: false,
variable: ObjectName(vec!["autocommit".into()]), variables: OneOrManyWithParens::One(ObjectName(vec!["autocommit".into()])),
value: vec![Expr::Value(number("1"))], value: vec![Expr::Value(number("1"))],
} }
); );

View file

@ -1201,7 +1201,7 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier(Ident { value: vec![Expr::Identifier(Ident {
value: "b".into(), value: "b".into(),
quote_style: None quote_style: None
@ -1215,7 +1215,7 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))], value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
} }
); );
@ -1226,7 +1226,7 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Value(number("0"))], value: vec![Expr::Value(number("0"))],
} }
); );
@ -1237,7 +1237,7 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier(Ident { value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(), value: "DEFAULT".into(),
quote_style: None quote_style: None
@ -1251,7 +1251,7 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: true, local: true,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier("b".into())], value: vec![Expr::Identifier("b".into())],
} }
); );
@ -1262,7 +1262,11 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("a"),
Ident::new("b"),
Ident::new("c")
])),
value: vec![Expr::Identifier(Ident { value: vec![Expr::Identifier(Ident {
value: "b".into(), value: "b".into(),
quote_style: None quote_style: None
@ -1279,13 +1283,13 @@ fn parse_set() {
Statement::SetVariable { Statement::SetVariable {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![ variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("hive"), Ident::new("hive"),
Ident::new("tez"), Ident::new("tez"),
Ident::new("auto"), Ident::new("auto"),
Ident::new("reducer"), Ident::new("reducer"),
Ident::new("parallelism") Ident::new("parallelism")
]), ])),
value: vec![Expr::Value(Value::Boolean(false))], value: vec![Expr::Value(Value::Boolean(false))],
} }
); );