Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case (#1733)

This commit is contained in:
Ophir LOJKINE 2025-02-22 07:23:36 +01:00 committed by GitHub
parent 7fc37a76e5
commit 72312ba82a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 160 additions and 50 deletions

View file

@ -600,6 +600,22 @@ pub enum CeilFloorKind {
Scale(Value), Scale(Value),
} }
/// A WHEN clause in a CASE expression containing both
/// the condition and its corresponding result
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseWhen {
pub condition: Expr,
pub result: Expr,
}
impl fmt::Display for CaseWhen {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WHEN {} THEN {}", self.condition, self.result)
}
}
/// An SQL expression of any type. /// An SQL expression of any type.
/// ///
/// # Semantics / Type Checking /// # Semantics / Type Checking
@ -918,8 +934,7 @@ pub enum Expr {
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause> /// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case { Case {
operand: Option<Box<Expr>>, operand: Option<Box<Expr>>,
conditions: Vec<Expr>, conditions: Vec<CaseWhen>,
results: Vec<Expr>,
else_result: Option<Box<Expr>>, else_result: Option<Box<Expr>>,
}, },
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like /// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@ -1621,17 +1636,15 @@ impl fmt::Display for Expr {
Expr::Case { Expr::Case {
operand, operand,
conditions, conditions,
results,
else_result, else_result,
} => { } => {
write!(f, "CASE")?; write!(f, "CASE")?;
if let Some(operand) = operand { if let Some(operand) = operand {
write!(f, " {operand}")?; write!(f, " {operand}")?;
} }
for (c, r) in conditions.iter().zip(results) { for when in conditions {
write!(f, " WHEN {c} THEN {r}")?; write!(f, " {when}")?;
} }
if let Some(else_result) = else_result { if let Some(else_result) = else_result {
write!(f, " ELSE {else_result}")?; write!(f, " ELSE {else_result}")?;
} }

View file

@ -1450,15 +1450,15 @@ impl Spanned for Expr {
Expr::Case { Expr::Case {
operand, operand,
conditions, conditions,
results,
else_result, else_result,
} => union_spans( } => union_spans(
operand operand
.as_ref() .as_ref()
.map(|i| i.span()) .map(|i| i.span())
.into_iter() .into_iter()
.chain(conditions.iter().map(|i| i.span())) .chain(conditions.iter().flat_map(|case_when| {
.chain(results.iter().map(|i| i.span())) [case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())), .chain(else_result.as_ref().map(|i| i.span())),
), ),
Expr::Exists { subquery, .. } => subquery.span(), Expr::Exists { subquery, .. } => subquery.span(),

View file

@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
self.expect_keyword_is(Keyword::WHEN)?; self.expect_keyword_is(Keyword::WHEN)?;
} }
let mut conditions = vec![]; let mut conditions = vec![];
let mut results = vec![];
loop { loop {
conditions.push(self.parse_expr()?); let condition = self.parse_expr()?;
self.expect_keyword_is(Keyword::THEN)?; self.expect_keyword_is(Keyword::THEN)?;
results.push(self.parse_expr()?); let result = self.parse_expr()?;
conditions.push(CaseWhen { condition, result });
if !self.parse_keyword(Keyword::WHEN) { if !self.parse_keyword(Keyword::WHEN) {
break; break;
} }
@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
Ok(Expr::Case { Ok(Expr::Case {
operand, operand,
conditions, conditions,
results,
else_result, else_result,
}) })
} }

View file

@ -6695,22 +6695,26 @@ fn parse_searched_case_expr() {
&Case { &Case {
operand: None, operand: None,
conditions: vec![ conditions: vec![
IsNull(Box::new(Identifier(Ident::new("bar")))), CaseWhen {
BinaryOp { condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
},
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))), left: Box::new(Identifier(Ident::new("bar"))),
op: Eq, op: Eq,
right: Box::new(Expr::Value(number("0"))), right: Box::new(Expr::Value(number("0"))),
}, },
BinaryOp { result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
},
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))), left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq, op: GtEq,
right: Box::new(Expr::Value(number("0"))), right: Box::new(Expr::Value(number("0"))),
}, },
], result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
results: vec![ },
Expr::Value(Value::SingleQuotedString("null".to_string())),
Expr::Value(Value::SingleQuotedString("=0".to_string())),
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
], ],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString( else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"<0".to_string() "<0".to_string()
@ -6729,8 +6733,10 @@ fn parse_simple_case_expr() {
assert_eq!( assert_eq!(
&Case { &Case {
operand: Some(Box::new(Identifier(Ident::new("foo")))), operand: Some(Box::new(Identifier(Ident::new("foo")))),
conditions: vec![Expr::Value(number("1"))], conditions: vec![CaseWhen {
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))], condition: Expr::Value(number("1")),
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
}],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString( else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"N".to_string() "N".to_string()
)))), )))),
@ -13902,6 +13908,31 @@ fn test_trailing_commas_in_from() {
); );
} }
#[test]
#[cfg(feature = "visitor")]
fn test_visit_order() {
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
let stmt = verified_stmt(sql);
let mut visited = vec![];
sqlparser::ast::visit_expressions(&stmt, |expr| {
visited.push(expr.to_string());
core::ops::ControlFlow::<()>::Continue(())
});
assert_eq!(
visited,
[
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
"a",
"1",
"2",
"3",
"4",
"5"
]
);
}
#[test] #[test]
fn test_lambdas() { fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions()); let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@ -13929,12 +13960,16 @@ fn test_lambdas() {
body: Box::new(Expr::Case { body: Box::new(Expr::Case {
operand: None, operand: None,
conditions: vec![ conditions: vec![
Expr::BinaryOp { CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))), left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq, op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2"))) right: Box::new(Expr::Identifier(Ident::new("p2")))
}, },
Expr::BinaryOp { result: Expr::Value(number("0"))
},
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call( left: Box::new(call(
"reverse", "reverse",
[Expr::Identifier(Ident::new("p1"))] [Expr::Identifier(Ident::new("p1"))]
@ -13944,14 +13979,12 @@ fn test_lambdas() {
"reverse", "reverse",
[Expr::Identifier(Ident::new("p2"))] [Expr::Identifier(Ident::new("p2"))]
)) ))
} },
], result: Expr::UnaryOp {
results: vec![
Expr::Value(number("0")),
Expr::UnaryOp {
op: UnaryOperator::Minus, op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1"))) expr: Box::new(Expr::Value(number("1")))
} }
}
], ],
else_result: Some(Box::new(Expr::Value(number("1")))) else_result: Some(Box::new(Expr::Value(number("1"))))
}) })

View file

@ -83,6 +83,71 @@ fn test_databricks_exists() {
); );
} }
#[test]
fn test_databricks_lambdas() {
#[rustfmt::skip]
let sql = concat!(
"SELECT array_sort(array('Hello', 'World'), ",
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
"ELSE 1 END)",
);
pretty_assertions::assert_eq!(
SelectItem::UnnamedExpr(call(
"array_sort",
[
call(
"array",
[
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
Expr::Value(Value::SingleQuotedString("World".to_owned()))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
)),
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
},
],
else_result: Some(Box::new(Expr::Value(number("1"))))
})
})
]
)),
databricks().verified_only_select(sql).projection[0]
);
databricks().verified_expr(
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
}
#[test] #[test]
fn test_values_clause() { fn test_values_clause() {
let values = Values { let values = Values {