Add support for MS-SQL BEGIN/END TRY/CATCH (#1649)

This commit is contained in:
Yoav Cohen 2025-01-08 19:31:24 +01:00 committed by GitHub
parent 397bceb241
commit 687ce2d5f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 112 additions and 25 deletions

View file

@ -548,7 +548,11 @@ mod tests {
#[test]
pub fn test_from_invalid_statement() {
let stmt = Statement::Commit { chain: false };
let stmt = Statement::Commit {
chain: false,
end: false,
modifier: None,
};
assert_eq!(
CreateTableBuilder::try_from(stmt).unwrap_err(),

View file

@ -2958,7 +2958,6 @@ pub enum Statement {
modes: Vec<TransactionMode>,
begin: bool,
transaction: Option<BeginTransactionKind>,
/// Only for SQLite
modifier: Option<TransactionModifier>,
},
/// ```sql
@ -2985,7 +2984,17 @@ pub enum Statement {
/// ```sql
/// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]
/// ```
Commit { chain: bool },
/// If `end` is false
///
/// ```sql
/// END [ TRY | CATCH ]
/// ```
/// If `end` is true
Commit {
chain: bool,
end: bool,
modifier: Option<TransactionModifier>,
},
/// ```sql
/// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]
/// ```
@ -4614,8 +4623,23 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::Commit { chain } => {
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },)
Statement::Commit {
chain,
end: end_syntax,
modifier,
} => {
if *end_syntax {
write!(f, "END")?;
if let Some(modifier) = *modifier {
write!(f, " {}", modifier)?;
}
if *chain {
write!(f, " AND CHAIN")?;
}
} else {
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" })?;
}
Ok(())
}
Statement::Rollback { chain, savepoint } => {
write!(f, "ROLLBACK")?;
@ -6388,9 +6412,10 @@ impl fmt::Display for TransactionIsolationLevel {
}
}
/// SQLite specific syntax
/// Modifier for the transaction in the `BEGIN` syntax
///
/// <https://sqlite.org/lang_transaction.html>
/// SQLite: <https://sqlite.org/lang_transaction.html>
/// MS-SQL: <https://learn.microsoft.com/en-us/sql/t-sql/language-elements/try-catch-transact-sql>
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -6398,6 +6423,8 @@ pub enum TransactionModifier {
Deferred,
Immediate,
Exclusive,
Try,
Catch,
}
impl fmt::Display for TransactionModifier {
@ -6407,6 +6434,8 @@ impl fmt::Display for TransactionModifier {
Deferred => "DEFERRED",
Immediate => "IMMEDIATE",
Exclusive => "EXCLUSIVE",
Try => "TRY",
Catch => "CATCH",
})
}
}

View file

@ -260,11 +260,16 @@ pub trait Dialect: Debug + Any {
false
}
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE} [TRANSACTION]` statements
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE | TRY | CATCH} [TRANSACTION]` statements
fn supports_start_transaction_modifier(&self) -> bool {
false
}
/// Returns true if the dialect supports `END {TRY | CATCH}` statements
fn supports_end_transaction_modifier(&self) -> bool {
false
}
/// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`.
fn supports_named_fn_args_with_eq_operator(&self) -> bool {
false

View file

@ -78,4 +78,11 @@ impl Dialect for MsSqlDialect {
fn supports_named_fn_args_with_rarrow_operator(&self) -> bool {
false
}
fn supports_start_transaction_modifier(&self) -> bool {
true
}
fn supports_end_transaction_modifier(&self) -> bool {
true
}
}

View file

@ -151,6 +151,7 @@ define_keywords!(
CASE,
CAST,
CATALOG,
CATCH,
CEIL,
CEILING,
CENTURY,
@ -812,6 +813,7 @@ define_keywords!(
TRIM_ARRAY,
TRUE,
TRUNCATE,
TRY,
TRY_CAST,
TRY_CONVERT,
TUPLE,

View file

@ -12800,6 +12800,10 @@ impl<'a> Parser<'a> {
Some(TransactionModifier::Immediate)
} else if self.parse_keyword(Keyword::EXCLUSIVE) {
Some(TransactionModifier::Exclusive)
} else if self.parse_keyword(Keyword::TRY) {
Some(TransactionModifier::Try)
} else if self.parse_keyword(Keyword::CATCH) {
Some(TransactionModifier::Catch)
} else {
None
};
@ -12817,8 +12821,19 @@ impl<'a> Parser<'a> {
}
pub fn parse_end(&mut self) -> Result<Statement, ParserError> {
let modifier = if !self.dialect.supports_end_transaction_modifier() {
None
} else if self.parse_keyword(Keyword::TRY) {
Some(TransactionModifier::Try)
} else if self.parse_keyword(Keyword::CATCH) {
Some(TransactionModifier::Catch)
} else {
None
};
Ok(Statement::Commit {
chain: self.parse_commit_rollback_chain()?,
end: true,
modifier,
})
}
@ -12861,6 +12876,8 @@ impl<'a> Parser<'a> {
pub fn parse_commit(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::Commit {
chain: self.parse_commit_rollback_chain()?,
end: false,
modifier: None,
})
}

View file

@ -7887,6 +7887,27 @@ fn parse_start_transaction() {
ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()),
res.unwrap_err()
);
// MS-SQL syntax
let dialects = all_dialects_where(|d| d.supports_start_transaction_modifier());
dialects.verified_stmt("BEGIN TRY");
dialects.verified_stmt("BEGIN CATCH");
let dialects = all_dialects_where(|d| {
d.supports_start_transaction_modifier() && d.supports_end_transaction_modifier()
});
dialects
.parse_sql_statements(
r#"
BEGIN TRY;
SELECT 1/0;
END TRY;
BEGIN CATCH;
EXECUTE foo;
END CATCH;
"#,
)
.unwrap();
}
#[test]
@ -8102,12 +8123,12 @@ fn parse_set_time_zone_alias() {
#[test]
fn parse_commit() {
match verified_stmt("COMMIT") {
Statement::Commit { chain: false } => (),
Statement::Commit { chain: false, .. } => (),
_ => unreachable!(),
}
match verified_stmt("COMMIT AND CHAIN") {
Statement::Commit { chain: true } => (),
Statement::Commit { chain: true, .. } => (),
_ => unreachable!(),
}
@ -8122,13 +8143,17 @@ fn parse_commit() {
#[test]
fn parse_end() {
one_statement_parses_to("END AND NO CHAIN", "COMMIT");
one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT");
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT");
one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN");
one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN");
one_statement_parses_to("END WORK", "COMMIT");
one_statement_parses_to("END TRANSACTION", "COMMIT");
one_statement_parses_to("END AND NO CHAIN", "END");
one_statement_parses_to("END WORK AND NO CHAIN", "END");
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "END");
one_statement_parses_to("END WORK AND CHAIN", "END AND CHAIN");
one_statement_parses_to("END TRANSACTION AND CHAIN", "END AND CHAIN");
one_statement_parses_to("END WORK", "END");
one_statement_parses_to("END TRANSACTION", "END");
// MS-SQL syntax
let dialects = all_dialects_where(|d| d.supports_end_transaction_modifier());
dialects.verified_stmt("END TRY");
dialects.verified_stmt("END CATCH");
}
#[test]

View file

@ -115,7 +115,11 @@ fn custom_statement_parser() -> Result<(), ParserError> {
for _ in 0..3 {
let _ = parser.next_token();
}
Some(Ok(Statement::Commit { chain: false }))
Some(Ok(Statement::Commit {
chain: false,
end: false,
modifier: None,
}))
} else {
None
}

View file

@ -523,13 +523,7 @@ fn parse_start_transaction_with_modifier() {
sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE");
sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE");
let unsupported_dialects = TestedDialects::new(
all_dialects()
.dialects
.into_iter()
.filter(|x| !(x.is::<SQLiteDialect>() || x.is::<GenericDialect>()))
.collect(),
);
let unsupported_dialects = all_dialects_except(|d| d.supports_start_transaction_modifier());
let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED");
assert_eq!(
ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()),