Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case (#1733)

This commit is contained in:
Ophir LOJKINE 2025-02-22 07:23:36 +01:00 committed by GitHub
parent 7fc37a76e5
commit 72312ba82a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 160 additions and 50 deletions

View file

@ -600,6 +600,22 @@ pub enum CeilFloorKind {
Scale(Value),
}
/// A WHEN clause in a CASE expression containing both
/// the condition and its corresponding result
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseWhen {
pub condition: Expr,
pub result: Expr,
}
impl fmt::Display for CaseWhen {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WHEN {} THEN {}", self.condition, self.result)
}
}
/// An SQL expression of any type.
///
/// # Semantics / Type Checking
@ -918,8 +934,7 @@ pub enum Expr {
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case {
operand: Option<Box<Expr>>,
conditions: Vec<Expr>,
results: Vec<Expr>,
conditions: Vec<CaseWhen>,
else_result: Option<Box<Expr>>,
},
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@ -1621,17 +1636,15 @@ impl fmt::Display for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
write!(f, "CASE")?;
if let Some(operand) = operand {
write!(f, " {operand}")?;
}
for (c, r) in conditions.iter().zip(results) {
write!(f, " WHEN {c} THEN {r}")?;
for when in conditions {
write!(f, " {when}")?;
}
if let Some(else_result) = else_result {
write!(f, " ELSE {else_result}")?;
}

View file

@ -1450,15 +1450,15 @@ impl Spanned for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => union_spans(
operand
.as_ref()
.map(|i| i.span())
.into_iter()
.chain(conditions.iter().map(|i| i.span()))
.chain(results.iter().map(|i| i.span()))
.chain(conditions.iter().flat_map(|case_when| {
[case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())),
),
Expr::Exists { subquery, .. } => subquery.span(),

View file

@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
self.expect_keyword_is(Keyword::WHEN)?;
}
let mut conditions = vec![];
let mut results = vec![];
loop {
conditions.push(self.parse_expr()?);
let condition = self.parse_expr()?;
self.expect_keyword_is(Keyword::THEN)?;
results.push(self.parse_expr()?);
let result = self.parse_expr()?;
conditions.push(CaseWhen { condition, result });
if !self.parse_keyword(Keyword::WHEN) {
break;
}
@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
Ok(Expr::Case {
operand,
conditions,
results,
else_result,
})
}

View file

@ -6695,22 +6695,26 @@ fn parse_searched_case_expr() {
&Case {
operand: None,
conditions: vec![
IsNull(Box::new(Identifier(Ident::new("bar")))),
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
},
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
},
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
},
],
results: vec![
Expr::Value(Value::SingleQuotedString("null".to_string())),
Expr::Value(Value::SingleQuotedString("=0".to_string())),
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"<0".to_string()
@ -6729,8 +6733,10 @@ fn parse_simple_case_expr() {
assert_eq!(
&Case {
operand: Some(Box::new(Identifier(Ident::new("foo")))),
conditions: vec![Expr::Value(number("1"))],
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
conditions: vec![CaseWhen {
condition: Expr::Value(number("1")),
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
}],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"N".to_string()
)))),
@ -13902,6 +13908,31 @@ fn test_trailing_commas_in_from() {
);
}
#[test]
#[cfg(feature = "visitor")]
fn test_visit_order() {
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
let stmt = verified_stmt(sql);
let mut visited = vec![];
sqlparser::ast::visit_expressions(&stmt, |expr| {
visited.push(expr.to_string());
core::ops::ControlFlow::<()>::Continue(())
});
assert_eq!(
visited,
[
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
"a",
"1",
"2",
"3",
"4",
"5"
]
);
}
#[test]
fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
@ -13929,28 +13960,30 @@ fn test_lambdas() {
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
}
],
results: vec![
Expr::Value(number("0")),
Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
}
],
else_result: Some(Box::new(Expr::Value(number("1"))))

View file

@ -83,6 +83,71 @@ fn test_databricks_exists() {
);
}
#[test]
fn test_databricks_lambdas() {
#[rustfmt::skip]
let sql = concat!(
"SELECT array_sort(array('Hello', 'World'), ",
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
"ELSE 1 END)",
);
pretty_assertions::assert_eq!(
SelectItem::UnnamedExpr(call(
"array_sort",
[
call(
"array",
[
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
Expr::Value(Value::SingleQuotedString("World".to_owned()))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
)),
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
},
],
else_result: Some(Box::new(Expr::Value(number("1"))))
})
})
]
)),
databricks().verified_only_select(sql).projection[0]
);
databricks().verified_expr(
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
}
#[test]
fn test_values_clause() {
let values = Values {