Support multiple SET variables (#1252)

This commit is contained in:
Ifeanyi Ubah 2024-05-07 18:51:39 +02:00 committed by GitHub
parent c4f3ef9600
commit eb36bd7138
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 145 additions and 33 deletions

View file

@ -2275,7 +2275,8 @@ pub enum Statement {
role_name: Option<Ident>,
},
/// ```sql
/// SET <variable>
/// SET <variable> = expression;
/// SET (variable[, ...]) = (expression[, ...]);
/// ```
///
/// Note: this is not a standard SQL statement, but it is supported by at
@ -2284,7 +2285,7 @@ pub enum Statement {
SetVariable {
local: bool,
hivevar: bool,
variable: ObjectName,
variables: OneOrManyWithParens<ObjectName>,
value: Vec<Expr>,
},
/// ```sql
@ -3823,7 +3824,7 @@ impl fmt::Display for Statement {
}
Statement::SetVariable {
local,
variable,
variables,
hivevar,
value,
} => {
@ -3831,12 +3832,15 @@ impl fmt::Display for Statement {
if *local {
f.write_str("LOCAL ")?;
}
let parenthesized = matches!(variables, OneOrManyWithParens::Many(_));
write!(
f,
"{hivevar}{name} = {value}",
"{hivevar}{name} = {l_paren}{value}{r_paren}",
hivevar = if *hivevar { "HIVEVAR:" } else { "" },
name = variable,
value = display_comma_separated(value)
name = variables,
l_paren = parenthesized.then_some("(").unwrap_or_default(),
value = display_comma_separated(value),
r_paren = parenthesized.then_some(")").unwrap_or_default(),
)
}
Statement::SetTimeZone { local, value } => {

View file

@ -44,4 +44,9 @@ impl Dialect for BigQueryDialect {
fn supports_window_clause_named_window_reference(&self) -> bool {
true
}
/// See [doc](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#set)
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
}

View file

@ -62,4 +62,8 @@ impl Dialect for GenericDialect {
fn supports_window_clause_named_window_reference(&self) -> bool {
true
}
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
}

View file

@ -217,6 +217,15 @@ pub trait Dialect: Debug + Any {
fn supports_lambda_functions(&self) -> bool {
false
}
/// Returns true if the dialect supports multiple variable assignment
/// using parentheses in a `SET` variable declaration.
///
/// ```sql
/// SET (variable[, ...]) = (expression[, ...]);
/// ```
fn supports_parenthesized_set_variables(&self) -> bool {
false
}
/// Returns true if the dialect has a CONVERT function which accepts a type first
/// and an expression second, e.g. `CONVERT(varchar, 1)`
fn convert_type_before_value(&self) -> bool {

View file

@ -71,6 +71,11 @@ impl Dialect for SnowflakeDialect {
true
}
/// See [doc](https://docs.snowflake.com/en/sql-reference/sql/set#syntax)
fn supports_parenthesized_set_variables(&self) -> bool {
true
}
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::CREATE) {
// possibly CREATE STAGE

View file

@ -8030,14 +8030,27 @@ impl<'a> Parser<'a> {
});
}
let variable = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
ObjectName(vec!["TIMEZONE".into()])
let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
OneOrManyWithParens::One(ObjectName(vec!["TIMEZONE".into()]))
} else if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen)
{
let variables = OneOrManyWithParens::Many(
self.parse_comma_separated(|parser: &mut Parser<'a>| {
parser.parse_identifier(false)
})?
.into_iter()
.map(|ident| ObjectName(vec![ident]))
.collect(),
);
self.expect_token(&Token::RParen)?;
variables
} else {
self.parse_object_name(false)?
OneOrManyWithParens::One(self.parse_object_name(false)?)
};
if variable.to_string().eq_ignore_ascii_case("NAMES")
&& dialect_of!(self is MySqlDialect | GenericDialect)
if matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES")
&& dialect_of!(self is MySqlDialect | GenericDialect))
{
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
@ -8050,11 +8063,19 @@ impl<'a> Parser<'a> {
None
};
Ok(Statement::SetNames {
return Ok(Statement::SetNames {
charset_name,
collation_name,
})
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
});
}
let parenthesized_assignment = matches!(&variables, OneOrManyWithParens::Many(_));
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
if parenthesized_assignment {
self.expect_token(&Token::LParen)?;
}
let mut values = vec![];
loop {
let value = if let Ok(expr) = self.parse_expr() {
@ -8067,14 +8088,24 @@ impl<'a> Parser<'a> {
if self.consume_token(&Token::Comma) {
continue;
}
if parenthesized_assignment {
self.expect_token(&Token::RParen)?;
}
return Ok(Statement::SetVariable {
local: modifier == Some(Keyword::LOCAL),
hivevar: Some(Keyword::HIVEVAR) == modifier,
variable,
variables,
value: values,
});
}
} else if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
}
let OneOrManyWithParens::One(variable) = variables else {
return self.expected("set variable", self.peek_token());
};
if variable.to_string().eq_ignore_ascii_case("TIMEZONE") {
// for some db (e.g. postgresql), SET TIME ZONE <value> is an alias for SET TIMEZONE [TO|=] <value>
match self.parse_expr() {
Ok(expr) => Ok(Statement::SetTimeZone {

View file

@ -6958,12 +6958,15 @@ fn parse_set_variable() {
Statement::SetVariable {
local,
hivevar,
variable,
variables,
value,
} => {
assert!(!local);
assert!(!hivevar);
assert_eq!(variable, ObjectName(vec!["SOMETHING".into()]));
assert_eq!(
variables,
OneOrManyWithParens::One(ObjectName(vec!["SOMETHING".into()]))
);
assert_eq!(
value,
vec![Expr::Value(Value::SingleQuotedString("1".into()))]
@ -6972,6 +6975,50 @@ fn parse_set_variable() {
_ => unreachable!(),
}
let multi_variable_dialects = all_dialects_where(|d| d.supports_parenthesized_set_variables());
let sql = r#"SET (a, b, c) = (1, 2, 3)"#;
match multi_variable_dialects.verified_stmt(sql) {
Statement::SetVariable {
local,
hivevar,
variables,
value,
} => {
assert!(!local);
assert!(!hivevar);
assert_eq!(
variables,
OneOrManyWithParens::Many(vec![
ObjectName(vec!["a".into()]),
ObjectName(vec!["b".into()]),
ObjectName(vec!["c".into()]),
])
);
assert_eq!(
value,
vec![
Expr::Value(number("1")),
Expr::Value(number("2")),
Expr::Value(number("3")),
]
);
}
_ => unreachable!(),
}
let error_sqls = [
("SET (a, b, c) = (1, 2, 3", "Expected ), found: EOF"),
("SET (a, b, c) = 1, 2, 3", "Expected (, found: 1"),
];
for (sql, error) in error_sqls {
assert_eq!(
ParserError::ParserError(error.to_string()),
multi_variable_dialects
.parse_sql_statements(sql)
.unwrap_err()
);
}
one_statement_parses_to("SET SOMETHING TO '1'", "SET SOMETHING = '1'");
}
@ -6981,12 +7028,15 @@ fn parse_set_time_zone() {
Statement::SetVariable {
local,
hivevar,
variable,
variables: variable,
value,
} => {
assert!(!local);
assert!(!hivevar);
assert_eq!(variable, ObjectName(vec!["TIMEZONE".into()]));
assert_eq!(
variable,
OneOrManyWithParens::One(ObjectName(vec!["TIMEZONE".into()]))
);
assert_eq!(
value,
vec![Expr::Value(Value::SingleQuotedString("UTC".into()))]

View file

@ -17,8 +17,8 @@
use sqlparser::ast::{
CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionArgumentList,
FunctionArguments, FunctionDefinition, Ident, ObjectName, SelectItem, Statement, TableFactor,
UnaryOperator,
FunctionArguments, FunctionDefinition, Ident, ObjectName, OneOrManyWithParens, SelectItem,
Statement, TableFactor, UnaryOperator,
};
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
use sqlparser::parser::{ParserError, ParserOptions};
@ -268,12 +268,12 @@ fn set_statement_with_minus() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![
variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("hive"),
Ident::new("tez"),
Ident::new("java"),
Ident::new("opts")
]),
])),
value: vec![Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))

View file

@ -460,7 +460,7 @@ fn parse_set_variables() {
Statement::SetVariable {
local: true,
hivevar: false,
variable: ObjectName(vec!["autocommit".into()]),
variables: OneOrManyWithParens::One(ObjectName(vec!["autocommit".into()])),
value: vec![Expr::Value(number("1"))],
}
);

View file

@ -1201,7 +1201,7 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
@ -1215,7 +1215,7 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
}
);
@ -1226,7 +1226,7 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Value(number("0"))],
}
);
@ -1237,7 +1237,7 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(),
quote_style: None
@ -1251,7 +1251,7 @@ fn parse_set() {
Statement::SetVariable {
local: true,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])),
value: vec![Expr::Identifier("b".into())],
}
);
@ -1262,7 +1262,11 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("a"),
Ident::new("b"),
Ident::new("c")
])),
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
@ -1279,13 +1283,13 @@ fn parse_set() {
Statement::SetVariable {
local: false,
hivevar: false,
variable: ObjectName(vec![
variables: OneOrManyWithParens::One(ObjectName(vec![
Ident::new("hive"),
Ident::new("tez"),
Ident::new("auto"),
Ident::new("reducer"),
Ident::new("parallelism")
]),
])),
value: vec![Expr::Value(Value::Boolean(false))],
}
);