Add CASE and IF statement support (#1741)

This commit is contained in:
Ifeanyi Ubah 2025-03-14 07:49:25 +01:00 committed by GitHub
parent cf4ab7f9ab
commit 862e887a66
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 473 additions and 22 deletions

View file

@ -151,6 +151,15 @@ where
DisplaySeparated { slice, sep: ", " }
}
/// Writes the given statements to the formatter, each ending with
/// a semicolon and space separated.
fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
write!(f, "{}", display_separated(statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")
}
/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -2080,6 +2089,173 @@ pub enum Password {
NullPassword,
}
/// A `CASE` statement.
///
/// Examples:
/// ```sql
/// CASE
/// WHEN EXISTS(SELECT 1)
/// THEN SELECT 1 FROM T;
/// WHEN EXISTS(SELECT 2)
/// THEN SELECT 1 FROM U;
/// ELSE
/// SELECT 1 FROM V;
/// END CASE;
/// ```
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseStatement {
pub match_expr: Option<Expr>,
pub when_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
/// TRUE if the statement ends with `END CASE` (vs `END`).
pub has_end_case: bool,
}
impl fmt::Display for CaseStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case,
} = self;
write!(f, "CASE")?;
if let Some(expr) = match_expr {
write!(f, " {expr}")?;
}
if !when_blocks.is_empty() {
write!(f, " {}", display_separated(when_blocks, " "))?;
}
if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
}
write!(f, " END")?;
if *has_end_case {
write!(f, " CASE")?;
}
Ok(())
}
}
/// An `IF` statement.
///
/// Examples:
/// ```sql
/// IF TRUE THEN
/// SELECT 1;
/// SELECT 2;
/// ELSEIF TRUE THEN
/// SELECT 3;
/// ELSE
/// SELECT 4;
/// END IF
/// ```
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct IfStatement {
pub if_block: ConditionalStatements,
pub elseif_blocks: Vec<ConditionalStatements>,
pub else_block: Option<Vec<Statement>>,
}
impl fmt::Display for IfStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let IfStatement {
if_block,
elseif_blocks,
else_block,
} = self;
write!(f, "{if_block}")?;
if !elseif_blocks.is_empty() {
write!(f, " {}", display_separated(elseif_blocks, " "))?;
}
if let Some(else_block) = else_block {
write!(f, " ELSE ")?;
format_statement_list(f, else_block)?;
}
write!(f, " END IF")?;
Ok(())
}
}
/// Represents a type of [ConditionalStatements]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ConditionalStatementKind {
/// `WHEN <condition> THEN <statements>`
When,
/// `IF <condition> THEN <statements>`
If,
/// `ELSEIF <condition> THEN <statements>`
ElseIf,
}
/// A block within a [Statement::Case] or [Statement::If]-like statement
///
/// Examples:
/// ```sql
/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
///
/// IF TRUE THEN SELECT 1; SELECT 2;
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ConditionalStatements {
/// The condition expression.
pub condition: Expr,
/// Statement list of the `THEN` clause.
pub statements: Vec<Statement>,
pub kind: ConditionalStatementKind,
}
impl fmt::Display for ConditionalStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let ConditionalStatements {
condition: expr,
statements,
kind,
} = self;
let kind = match kind {
ConditionalStatementKind::When => "WHEN",
ConditionalStatementKind::If => "IF",
ConditionalStatementKind::ElseIf => "ELSEIF",
};
write!(f, "{kind} {expr} THEN")?;
if !statements.is_empty() {
write!(f, " ")?;
format_statement_list(f, statements)?;
}
Ok(())
}
}
/// Represents an expression assignment within a variable `DECLARE` statement.
///
/// Examples:
@ -2647,6 +2823,10 @@ pub enum Statement {
file_format: Option<FileFormat>,
source: Box<Query>,
},
/// A `CASE` statement.
Case(CaseStatement),
/// An `IF` statement.
If(IfStatement),
/// ```sql
/// CALL <function>
/// ```
@ -3940,6 +4120,12 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::Case(stmt) => {
write!(f, "{stmt}")
}
Statement::If(stmt) => {
write!(f, "{stmt}")
}
Statement::AttachDatabase {
schema_name,
database_file_name,
@ -4942,18 +5128,14 @@ impl fmt::Display for Statement {
write!(f, " {}", display_comma_separated(modes))?;
}
if !statements.is_empty() {
write!(f, " {}", display_separated(statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")?;
write!(f, " ")?;
format_statement_list(f, statements)?;
}
if let Some(exception_statements) = exception_statements {
write!(f, " EXCEPTION WHEN ERROR THEN")?;
if !exception_statements.is_empty() {
write!(f, " {}", display_separated(exception_statements, "; "))?;
// We manually insert semicolon for the last statement,
// since display_separated doesn't handle that case.
write!(f, ";")?;
write!(f, " ")?;
format_statement_list(f, exception_statements)?;
}
}
if *has_end_keyword {

View file

@ -22,20 +22,21 @@ use crate::tokenizer::Span;
use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor,
ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate,
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem,
ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WildcardAdditionalOptions, With, WithFill,
};
/// Given an iterator of spans, return the [Span::union] of all spans.
@ -334,6 +335,8 @@ impl Spanned for Statement {
file_format: _,
source,
} => source.span(),
Statement::Case(stmt) => stmt.span(),
Statement::If(stmt) => stmt.span(),
Statement::Call(function) => function.span(),
Statement::Copy {
source,
@ -732,6 +735,53 @@ impl Spanned for CreateIndex {
}
}
impl Spanned for CaseStatement {
fn span(&self) -> Span {
let CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case: _,
} = self;
union_spans(
match_expr
.iter()
.map(|e| e.span())
.chain(when_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
)
}
}
impl Spanned for IfStatement {
fn span(&self) -> Span {
let IfStatement {
if_block,
elseif_blocks,
else_block,
} = self;
union_spans(
iter::once(if_block.span())
.chain(elseif_blocks.iter().map(|b| b.span()))
.chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
)
}
}
impl Spanned for ConditionalStatements {
fn span(&self) -> Span {
let ConditionalStatements {
condition,
statements,
kind: _,
} = self;
union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
}
}
/// # partial span
///
/// Missing spans:

View file

@ -297,6 +297,7 @@ define_keywords!(
ELEMENT,
ELEMENTS,
ELSE,
ELSEIF,
EMPTY,
ENABLE,
ENABLE_SCHEMA_EVOLUTION,

View file

@ -528,6 +528,14 @@ impl<'a> Parser<'a> {
Keyword::DESCRIBE => self.parse_explain(DescribeAlias::Describe),
Keyword::EXPLAIN => self.parse_explain(DescribeAlias::Explain),
Keyword::ANALYZE => self.parse_analyze(),
Keyword::CASE => {
self.prev_token();
self.parse_case_stmt()
}
Keyword::IF => {
self.prev_token();
self.parse_if_stmt()
}
Keyword::SELECT | Keyword::WITH | Keyword::VALUES | Keyword::FROM => {
self.prev_token();
self.parse_query().map(Statement::Query)
@ -615,6 +623,102 @@ impl<'a> Parser<'a> {
}
}
/// Parse a `CASE` statement.
///
/// See [Statement::Case]
pub fn parse_case_stmt(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword_is(Keyword::CASE)?;
let match_expr = if self.peek_keyword(Keyword::WHEN) {
None
} else {
Some(self.parse_expr()?)
};
self.expect_keyword_is(Keyword::WHEN)?;
let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| {
parser.parse_conditional_statements(
ConditionalStatementKind::When,
&[Keyword::WHEN, Keyword::ELSE, Keyword::END],
)
})?;
let else_block = if self.parse_keyword(Keyword::ELSE) {
Some(self.parse_statement_list(&[Keyword::END])?)
} else {
None
};
self.expect_keyword_is(Keyword::END)?;
let has_end_case = self.parse_keyword(Keyword::CASE);
Ok(Statement::Case(CaseStatement {
match_expr,
when_blocks,
else_block,
has_end_case,
}))
}
/// Parse an `IF` statement.
///
/// See [Statement::If]
pub fn parse_if_stmt(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword_is(Keyword::IF)?;
let if_block = self.parse_conditional_statements(
ConditionalStatementKind::If,
&[Keyword::ELSE, Keyword::ELSEIF, Keyword::END],
)?;
let elseif_blocks = if self.parse_keyword(Keyword::ELSEIF) {
self.parse_keyword_separated(Keyword::ELSEIF, |parser| {
parser.parse_conditional_statements(
ConditionalStatementKind::ElseIf,
&[Keyword::ELSEIF, Keyword::ELSE, Keyword::END],
)
})?
} else {
vec![]
};
let else_block = if self.parse_keyword(Keyword::ELSE) {
Some(self.parse_statement_list(&[Keyword::END])?)
} else {
None
};
self.expect_keywords(&[Keyword::END, Keyword::IF])?;
Ok(Statement::If(IfStatement {
if_block,
elseif_blocks,
else_block,
}))
}
/// Parses an expression and associated list of statements
/// belonging to a conditional statement like `IF` or `WHEN`.
///
/// Example:
/// ```sql
/// IF condition THEN statement1; statement2;
/// ```
fn parse_conditional_statements(
&mut self,
kind: ConditionalStatementKind,
terminal_keywords: &[Keyword],
) -> Result<ConditionalStatements, ParserError> {
let condition = self.parse_expr()?;
self.expect_keyword_is(Keyword::THEN)?;
let statements = self.parse_statement_list(terminal_keywords)?;
Ok(ConditionalStatements {
condition,
statements,
kind,
})
}
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);

View file

@ -14179,6 +14179,120 @@ fn test_visit_order() {
);
}
#[test]
fn parse_case_statement() {
let sql = "CASE 1 WHEN 2 THEN SELECT 1; SELECT 2; ELSE SELECT 3; END CASE";
let Statement::Case(stmt) = verified_stmt(sql) else {
unreachable!()
};
assert_eq!(Some(Expr::value(number("1"))), stmt.match_expr);
assert_eq!(Expr::value(number("2")), stmt.when_blocks[0].condition);
assert_eq!(2, stmt.when_blocks[0].statements.len());
assert_eq!(1, stmt.else_block.unwrap().len());
verified_stmt(concat!(
"CASE 1",
" WHEN a THEN",
" SELECT 1; SELECT 2; SELECT 3;",
" WHEN b THEN",
" SELECT 4; SELECT 5;",
" ELSE",
" SELECT 7; SELECT 8;",
" END CASE"
));
verified_stmt(concat!(
"CASE 1",
" WHEN a THEN",
" SELECT 1; SELECT 2; SELECT 3;",
" WHEN b THEN",
" SELECT 4; SELECT 5;",
" END CASE"
));
verified_stmt(concat!(
"CASE 1",
" WHEN a THEN",
" SELECT 1; SELECT 2; SELECT 3;",
" END CASE"
));
verified_stmt(concat!(
"CASE 1",
" WHEN a THEN",
" SELECT 1; SELECT 2; SELECT 3;",
" END"
));
assert_eq!(
ParserError::ParserError("Expected: THEN, found: END".to_string()),
parse_sql_statements("CASE 1 WHEN a END").unwrap_err()
);
assert_eq!(
ParserError::ParserError("Expected: WHEN, found: ELSE".to_string()),
parse_sql_statements("CASE 1 ELSE SELECT 1; END").unwrap_err()
);
}
#[test]
fn parse_if_statement() {
let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF";
let Statement::If(stmt) = verified_stmt(sql) else {
unreachable!()
};
assert_eq!(Expr::value(number("1")), stmt.if_block.condition);
assert_eq!(Expr::value(number("2")), stmt.elseif_blocks[0].condition);
assert_eq!(1, stmt.else_block.unwrap().len());
verified_stmt(concat!(
"IF 1 THEN",
" SELECT 1;",
" SELECT 2;",
" SELECT 3;",
" ELSEIF 2 THEN",
" SELECT 4;",
" SELECT 5;",
" ELSEIF 3 THEN",
" SELECT 6;",
" SELECT 7;",
" ELSE",
" SELECT 8;",
" SELECT 9;",
" END IF"
));
verified_stmt(concat!(
"IF 1 THEN",
" SELECT 1;",
" SELECT 2;",
" ELSE",
" SELECT 3;",
" SELECT 4;",
" END IF"
));
verified_stmt(concat!(
"IF 1 THEN",
" SELECT 1;",
" SELECT 2;",
" SELECT 3;",
" ELSEIF 2 THEN",
" SELECT 3;",
" SELECT 4;",
" END IF"
));
verified_stmt(concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF"));
verified_stmt(concat!(
"IF (1) THEN",
" SELECT 1;",
" SELECT 2;",
" END IF"
));
verified_stmt("IF 1 THEN END IF");
verified_stmt("IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF");
assert_eq!(
ParserError::ParserError("Expected: IF, found: EOF".to_string()),
parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END").unwrap_err()
);
}
#[test]
fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions());