SET statements: scope modifier for multiple assignments (#1772)

This commit is contained in:
Mohamed Abdeen 2025-03-22 07:38:00 +02:00 committed by GitHub
parent 939fbdd4f6
commit 3a8a3bb7a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 133 additions and 94 deletions

View file

@ -2638,7 +2638,7 @@ pub enum Set {
/// SQL Standard-style /// SQL Standard-style
/// SET a = 1; /// SET a = 1;
SingleAssignment { SingleAssignment {
scope: ContextModifier, scope: Option<ContextModifier>,
hivevar: bool, hivevar: bool,
variable: ObjectName, variable: ObjectName,
values: Vec<Expr>, values: Vec<Expr>,
@ -2668,7 +2668,7 @@ pub enum Set {
/// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm /// [4]: https://docs.oracle.com/cd/B19306_01/server.102/b14200/statements_10004.htm
SetRole { SetRole {
/// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`). /// Non-ANSI optional identifier to inform if the role is defined inside the current session (`SESSION`) or transaction (`LOCAL`).
context_modifier: ContextModifier, context_modifier: Option<ContextModifier>,
/// Role name. If NONE is specified, then the current role name is removed. /// Role name. If NONE is specified, then the current role name is removed.
role_name: Option<Ident>, role_name: Option<Ident>,
}, },
@ -2720,7 +2720,13 @@ impl Display for Set {
role_name, role_name,
} => { } => {
let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE")); let role_name = role_name.clone().unwrap_or_else(|| Ident::new("NONE"));
write!(f, "SET {context_modifier}ROLE {role_name}") write!(
f,
"SET {modifier}ROLE {role_name}",
modifier = context_modifier
.map(|m| format!("{}", m))
.unwrap_or_default()
)
} }
Self::SetSessionParam(kind) => write!(f, "SET {kind}"), Self::SetSessionParam(kind) => write!(f, "SET {kind}"),
Self::SetTransaction { Self::SetTransaction {
@ -2775,7 +2781,7 @@ impl Display for Set {
write!( write!(
f, f,
"SET {}{}{} = {}", "SET {}{}{} = {}",
scope, scope.map(|s| format!("{}", s)).unwrap_or_default(),
if *hivevar { "HIVEVAR:" } else { "" }, if *hivevar { "HIVEVAR:" } else { "" },
variable, variable,
display_comma_separated(values) display_comma_separated(values)
@ -5736,13 +5742,20 @@ impl fmt::Display for SequenceOptions {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct SetAssignment { pub struct SetAssignment {
pub scope: Option<ContextModifier>,
pub name: ObjectName, pub name: ObjectName,
pub value: Expr, pub value: Expr,
} }
impl fmt::Display for SetAssignment { impl fmt::Display for SetAssignment {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} = {}", self.name, self.value) write!(
f,
"{}{} = {}",
self.scope.map(|s| format!("{}", s)).unwrap_or_default(),
self.name,
self.value
)
} }
} }
@ -7969,8 +7982,6 @@ impl fmt::Display for FlushLocation {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ContextModifier { pub enum ContextModifier {
/// No context defined. Each dialect defines the default in this scenario.
None,
/// `LOCAL` identifier, usually related to transactional states. /// `LOCAL` identifier, usually related to transactional states.
Local, Local,
/// `SESSION` identifier /// `SESSION` identifier
@ -7982,9 +7993,6 @@ pub enum ContextModifier {
impl fmt::Display for ContextModifier { impl fmt::Display for ContextModifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Self::None => {
write!(f, "")
}
Self::Local => { Self::Local => {
write!(f, "LOCAL ") write!(f, "LOCAL ")
} }

View file

@ -159,4 +159,8 @@ impl Dialect for GenericDialect {
fn supports_set_names(&self) -> bool { fn supports_set_names(&self) -> bool {
true true
} }
fn supports_comma_separated_set_assignments(&self) -> bool {
true
}
} }

View file

@ -1822,12 +1822,12 @@ impl<'a> Parser<'a> {
}) })
} }
fn keyword_to_modifier(k: Option<Keyword>) -> ContextModifier { fn keyword_to_modifier(k: Keyword) -> Option<ContextModifier> {
match k { match k {
Some(Keyword::LOCAL) => ContextModifier::Local, Keyword::LOCAL => Some(ContextModifier::Local),
Some(Keyword::GLOBAL) => ContextModifier::Global, Keyword::GLOBAL => Some(ContextModifier::Global),
Some(Keyword::SESSION) => ContextModifier::Session, Keyword::SESSION => Some(ContextModifier::Session),
_ => ContextModifier::None, _ => None,
} }
} }
@ -11157,9 +11157,11 @@ impl<'a> Parser<'a> {
} }
/// Parse a `SET ROLE` statement. Expects SET to be consumed already. /// Parse a `SET ROLE` statement. Expects SET to be consumed already.
fn parse_set_role(&mut self, modifier: Option<Keyword>) -> Result<Statement, ParserError> { fn parse_set_role(
&mut self,
modifier: Option<ContextModifier>,
) -> Result<Statement, ParserError> {
self.expect_keyword_is(Keyword::ROLE)?; self.expect_keyword_is(Keyword::ROLE)?;
let context_modifier = Self::keyword_to_modifier(modifier);
let role_name = if self.parse_keyword(Keyword::NONE) { let role_name = if self.parse_keyword(Keyword::NONE) {
None None
@ -11167,7 +11169,7 @@ impl<'a> Parser<'a> {
Some(self.parse_identifier()?) Some(self.parse_identifier()?)
}; };
Ok(Statement::Set(Set::SetRole { Ok(Statement::Set(Set::SetRole {
context_modifier, context_modifier: modifier,
role_name, role_name,
})) }))
} }
@ -11203,46 +11205,52 @@ impl<'a> Parser<'a> {
} }
} }
fn parse_set_assignment( fn parse_context_modifier(&mut self) -> Option<ContextModifier> {
&mut self, let modifier =
) -> Result<(OneOrManyWithParens<ObjectName>, Expr), ParserError> { self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::GLOBAL])?;
let variables = if self.dialect.supports_parenthesized_set_variables()
Self::keyword_to_modifier(modifier)
}
/// Parse a single SET statement assignment `var = expr`.
fn parse_set_assignment(&mut self) -> Result<SetAssignment, ParserError> {
let scope = self.parse_context_modifier();
let name = if self.dialect.supports_parenthesized_set_variables()
&& self.consume_token(&Token::LParen) && self.consume_token(&Token::LParen)
{ {
let vars = OneOrManyWithParens::Many( // Parenthesized assignments are handled in the `parse_set` function after
self.parse_comma_separated(|parser: &mut Parser<'a>| parser.parse_identifier())? // trying to parse list of assignments using this function.
.into_iter() // If a dialect supports both, and we find a LParen, we early exit from this function.
.map(|ident| ObjectName::from(vec![ident])) self.expected("Unparenthesized assignment", self.peek_token())?
.collect(),
);
self.expect_token(&Token::RParen)?;
vars
} else { } else {
OneOrManyWithParens::One(self.parse_object_name(false)?) self.parse_object_name(false)?
}; };
if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) { if !(self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO)) {
return self.expected("assignment operator", self.peek_token()); return self.expected("assignment operator", self.peek_token());
} }
let values = self.parse_expr()?; let value = self.parse_expr()?;
Ok((variables, values)) Ok(SetAssignment { scope, name, value })
} }
fn parse_set(&mut self) -> Result<Statement, ParserError> { fn parse_set(&mut self) -> Result<Statement, ParserError> {
let modifier = self.parse_one_of_keywords(&[ let hivevar = self.parse_keyword(Keyword::HIVEVAR);
Keyword::SESSION,
Keyword::LOCAL,
Keyword::HIVEVAR,
Keyword::GLOBAL,
]);
if let Some(Keyword::HIVEVAR) = modifier { // Modifier is either HIVEVAR: or a ContextModifier (LOCAL, SESSION, etc), not both
let scope = if !hivevar {
self.parse_context_modifier()
} else {
None
};
if hivevar {
self.expect_token(&Token::Colon)?; self.expect_token(&Token::Colon)?;
} }
if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(modifier))? { if let Some(set_role_stmt) = self.maybe_parse(|parser| parser.parse_set_role(scope))? {
return Ok(set_role_stmt); return Ok(set_role_stmt);
} }
@ -11252,8 +11260,8 @@ impl<'a> Parser<'a> {
{ {
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
return Ok(Set::SingleAssignment { return Ok(Set::SingleAssignment {
scope: Self::keyword_to_modifier(modifier), scope,
hivevar: modifier == Some(Keyword::HIVEVAR), hivevar,
variable: ObjectName::from(vec!["TIMEZONE".into()]), variable: ObjectName::from(vec!["TIMEZONE".into()]),
values: self.parse_set_values(false)?, values: self.parse_set_values(false)?,
} }
@ -11263,7 +11271,7 @@ impl<'a> Parser<'a> {
// the assignment operator. It's originally PostgreSQL specific, // the assignment operator. It's originally PostgreSQL specific,
// but we allow it for all the dialects // but we allow it for all the dialects
return Ok(Set::SetTimeZone { return Ok(Set::SetTimeZone {
local: modifier == Some(Keyword::LOCAL), local: scope == Some(ContextModifier::Local),
value: self.parse_expr()?, value: self.parse_expr()?,
} }
.into()); .into());
@ -11311,41 +11319,26 @@ impl<'a> Parser<'a> {
} }
if self.dialect.supports_comma_separated_set_assignments() { if self.dialect.supports_comma_separated_set_assignments() {
if scope.is_some() {
self.prev_token();
}
if let Some(assignments) = self if let Some(assignments) = self
.maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))? .maybe_parse(|parser| parser.parse_comma_separated(Parser::parse_set_assignment))?
{ {
return if assignments.len() > 1 { return if assignments.len() > 1 {
let assignments = assignments
.into_iter()
.map(|(var, val)| match var {
OneOrManyWithParens::One(v) => Ok(SetAssignment {
name: v,
value: val,
}),
OneOrManyWithParens::Many(_) => {
self.expected("List of single identifiers", self.peek_token())
}
})
.collect::<Result<_, _>>()?;
Ok(Set::MultipleAssignments { assignments }.into()) Ok(Set::MultipleAssignments { assignments }.into())
} else { } else {
let (vars, values): (Vec<_>, Vec<_>) = assignments.into_iter().unzip(); let SetAssignment { scope, name, value } =
assignments.into_iter().next().ok_or_else(|| {
let variable = match vars.into_iter().next() { ParserError::ParserError("Expected at least one assignment".to_string())
Some(OneOrManyWithParens::One(v)) => Ok(v), })?;
Some(OneOrManyWithParens::Many(_)) => self.expected(
"Single assignment or list of assignments",
self.peek_token(),
),
None => self.expected("At least one identifier", self.peek_token()),
}?;
Ok(Set::SingleAssignment { Ok(Set::SingleAssignment {
scope: Self::keyword_to_modifier(modifier), scope,
hivevar: modifier == Some(Keyword::HIVEVAR), hivevar,
variable, variable: name,
values, values: vec![value],
} }
.into()) .into())
}; };
@ -11370,8 +11363,8 @@ impl<'a> Parser<'a> {
if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let stmt = match variables { let stmt = match variables {
OneOrManyWithParens::One(var) => Set::SingleAssignment { OneOrManyWithParens::One(var) => Set::SingleAssignment {
scope: Self::keyword_to_modifier(modifier), scope,
hivevar: modifier == Some(Keyword::HIVEVAR), hivevar,
variable: var, variable: var,
values: self.parse_set_values(false)?, values: self.parse_set_values(false)?,
}, },

View file

@ -8635,7 +8635,7 @@ fn parse_set_variable() {
variable, variable,
values, values,
}) => { }) => {
assert_eq!(scope, ContextModifier::None); assert_eq!(scope, None);
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()])); assert_eq!(variable, ObjectName::from(vec!["SOMETHING".into()]));
assert_eq!( assert_eq!(
@ -8655,7 +8655,7 @@ fn parse_set_variable() {
variable, variable,
values, values,
}) => { }) => {
assert_eq!(scope, ContextModifier::Global); assert_eq!(scope, Some(ContextModifier::Global));
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName::from(vec!["VARIABLE".into()])); assert_eq!(variable, ObjectName::from(vec!["VARIABLE".into()]));
assert_eq!( assert_eq!(
@ -8747,7 +8747,7 @@ fn parse_set_role_as_variable() {
variable, variable,
values, values,
}) => { }) => {
assert_eq!(scope, ContextModifier::None); assert_eq!(scope, None);
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName::from(vec!["role".into()])); assert_eq!(variable, ObjectName::from(vec!["role".into()]));
assert_eq!( assert_eq!(
@ -8794,7 +8794,7 @@ fn parse_set_time_zone() {
variable, variable,
values, values,
}) => { }) => {
assert_eq!(scope, ContextModifier::None); assert_eq!(scope, None);
assert!(!hivevar); assert!(!hivevar);
assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()])); assert_eq!(variable, ObjectName::from(vec!["TIMEZONE".into()]));
assert_eq!( assert_eq!(
@ -14859,10 +14859,12 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
assignments, assignments,
vec![ vec![
SetAssignment { SetAssignment {
scope: None,
name: ObjectName::from(vec!["@a".into()]), name: ObjectName::from(vec!["@a".into()]),
value: Expr::value(number("1")) value: Expr::value(number("1"))
}, },
SetAssignment { SetAssignment {
scope: None,
name: ObjectName::from(vec!["b".into()]), name: ObjectName::from(vec!["b".into()]),
value: Expr::value(number("2")) value: Expr::value(number("2"))
} }
@ -14872,6 +14874,39 @@ fn parse_multiple_set_statements() -> Result<(), ParserError> {
_ => panic!("Expected SetVariable with 2 variables and 2 values"), _ => panic!("Expected SetVariable with 2 variables and 2 values"),
}; };
let stmt = dialects.verified_stmt("SET GLOBAL @a = 1, SESSION b = 2, LOCAL c = 3, d = 4");
match stmt {
Statement::Set(Set::MultipleAssignments { assignments }) => {
assert_eq!(
assignments,
vec![
SetAssignment {
scope: Some(ContextModifier::Global),
name: ObjectName::from(vec!["@a".into()]),
value: Expr::value(number("1"))
},
SetAssignment {
scope: Some(ContextModifier::Session),
name: ObjectName::from(vec!["b".into()]),
value: Expr::value(number("2"))
},
SetAssignment {
scope: Some(ContextModifier::Local),
name: ObjectName::from(vec!["c".into()]),
value: Expr::value(number("3"))
},
SetAssignment {
scope: None,
name: ObjectName::from(vec!["d".into()]),
value: Expr::value(number("4"))
}
]
);
}
_ => panic!("Expected MultipleAssignments with 4 scoped variables and 4 values"),
};
Ok(()) Ok(())
} }

View file

@ -21,10 +21,9 @@
//! is also tested (on the inputs it can handle). //! is also tested (on the inputs it can handle).
use sqlparser::ast::{ use sqlparser::ast::{
ClusteredBy, CommentDef, ContextModifier, CreateFunction, CreateFunctionBody, ClusteredBy, CommentDef, CreateFunction, CreateFunctionBody, CreateFunctionUsing, CreateTable,
CreateFunctionUsing, CreateTable, Expr, Function, FunctionArgumentList, FunctionArguments, Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
Ident, ObjectName, OrderByExpr, OrderByOptions, SelectItem, Set, Statement, TableFactor, OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
UnaryOperator, Use, Value,
}; };
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
use sqlparser::parser::ParserError; use sqlparser::parser::ParserError;
@ -370,7 +369,7 @@ fn set_statement_with_minus() {
assert_eq!( assert_eq!(
hive().verified_stmt("SET hive.tez.java.opts = -Xmx4g"), hive().verified_stmt("SET hive.tez.java.opts = -Xmx4g"),
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![ variable: ObjectName::from(vec![
Ident::new("hive"), Ident::new("hive"),

View file

@ -1252,7 +1252,7 @@ fn parse_mssql_declare() {
}] }]
}, },
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("@bar")]), variable: ObjectName::from(vec![Ident::new("@bar")]),
values: vec![Expr::Value( values: vec![Expr::Value(

View file

@ -618,7 +618,7 @@ fn parse_set_variables() {
assert_eq!( assert_eq!(
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"), mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::Local, scope: Some(ContextModifier::Local),
hivevar: false, hivevar: false,
variable: ObjectName::from(vec!["autocommit".into()]), variable: ObjectName::from(vec!["autocommit".into()]),
values: vec![Expr::value(number("1"))], values: vec![Expr::value(number("1"))],

View file

@ -1432,7 +1432,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a")]), variable: ObjectName::from(vec![Ident::new("a")]),
values: vec![Expr::Identifier(Ident { values: vec![Expr::Identifier(Ident {
@ -1447,7 +1447,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a")]), variable: ObjectName::from(vec![Ident::new("a")]),
values: vec![Expr::Value( values: vec![Expr::Value(
@ -1460,7 +1460,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a")]), variable: ObjectName::from(vec![Ident::new("a")]),
values: vec![Expr::value(number("0"))], values: vec![Expr::value(number("0"))],
@ -1471,7 +1471,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a")]), variable: ObjectName::from(vec![Ident::new("a")]),
values: vec![Expr::Identifier(Ident::new("DEFAULT"))], values: vec![Expr::Identifier(Ident::new("DEFAULT"))],
@ -1482,7 +1482,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::Local, scope: Some(ContextModifier::Local),
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a")]), variable: ObjectName::from(vec![Ident::new("a")]),
values: vec![Expr::Identifier("b".into())], values: vec![Expr::Identifier("b".into())],
@ -1493,7 +1493,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), variable: ObjectName::from(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
values: vec![Expr::Identifier(Ident { values: vec![Expr::Identifier(Ident {
@ -1511,7 +1511,7 @@ fn parse_set() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SingleAssignment { Statement::Set(Set::SingleAssignment {
scope: ContextModifier::None, scope: None,
hivevar: false, hivevar: false,
variable: ObjectName::from(vec![ variable: ObjectName::from(vec![
Ident::new("hive"), Ident::new("hive"),
@ -1555,7 +1555,7 @@ fn parse_set_role() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SetRole { Statement::Set(Set::SetRole {
context_modifier: ContextModifier::Session, context_modifier: Some(ContextModifier::Session),
role_name: None, role_name: None,
}) })
); );
@ -1566,7 +1566,7 @@ fn parse_set_role() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SetRole { Statement::Set(Set::SetRole {
context_modifier: ContextModifier::Local, context_modifier: Some(ContextModifier::Local),
role_name: Some(Ident { role_name: Some(Ident {
value: "rolename".to_string(), value: "rolename".to_string(),
quote_style: Some('\"'), quote_style: Some('\"'),
@ -1581,7 +1581,7 @@ fn parse_set_role() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::Set(Set::SetRole { Statement::Set(Set::SetRole {
context_modifier: ContextModifier::None, context_modifier: None,
role_name: Some(Ident { role_name: Some(Ident {
value: "rolename".to_string(), value: "rolename".to_string(),
quote_style: Some('\''), quote_style: Some('\''),