Add support for MS-SQL BEGIN/END TRY/CATCH (#1649)

This commit is contained in:
Yoav Cohen 2025-01-08 19:31:24 +01:00 committed by GitHub
parent 397bceb241
commit 687ce2d5f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 112 additions and 25 deletions

View file

@ -548,7 +548,11 @@ mod tests {
#[test] #[test]
pub fn test_from_invalid_statement() { pub fn test_from_invalid_statement() {
let stmt = Statement::Commit { chain: false }; let stmt = Statement::Commit {
chain: false,
end: false,
modifier: None,
};
assert_eq!( assert_eq!(
CreateTableBuilder::try_from(stmt).unwrap_err(), CreateTableBuilder::try_from(stmt).unwrap_err(),

View file

@ -2958,7 +2958,6 @@ pub enum Statement {
modes: Vec<TransactionMode>, modes: Vec<TransactionMode>,
begin: bool, begin: bool,
transaction: Option<BeginTransactionKind>, transaction: Option<BeginTransactionKind>,
/// Only for SQLite
modifier: Option<TransactionModifier>, modifier: Option<TransactionModifier>,
}, },
/// ```sql /// ```sql
@ -2985,7 +2984,17 @@ pub enum Statement {
/// ```sql /// ```sql
/// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] /// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]
/// ``` /// ```
Commit { chain: bool }, /// If `end` is false
///
/// ```sql
/// END [ TRY | CATCH ]
/// ```
/// If `end` is true
Commit {
chain: bool,
end: bool,
modifier: Option<TransactionModifier>,
},
/// ```sql /// ```sql
/// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ] /// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]
/// ``` /// ```
@ -4614,8 +4623,23 @@ impl fmt::Display for Statement {
} }
Ok(()) Ok(())
} }
Statement::Commit { chain } => { Statement::Commit {
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },) chain,
end: end_syntax,
modifier,
} => {
if *end_syntax {
write!(f, "END")?;
if let Some(modifier) = *modifier {
write!(f, " {}", modifier)?;
}
if *chain {
write!(f, " AND CHAIN")?;
}
} else {
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" })?;
}
Ok(())
} }
Statement::Rollback { chain, savepoint } => { Statement::Rollback { chain, savepoint } => {
write!(f, "ROLLBACK")?; write!(f, "ROLLBACK")?;
@ -6388,9 +6412,10 @@ impl fmt::Display for TransactionIsolationLevel {
} }
} }
/// SQLite specific syntax /// Modifier for the transaction in the `BEGIN` syntax
/// ///
/// <https://sqlite.org/lang_transaction.html> /// SQLite: <https://sqlite.org/lang_transaction.html>
/// MS-SQL: <https://learn.microsoft.com/en-us/sql/t-sql/language-elements/try-catch-transact-sql>
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -6398,6 +6423,8 @@ pub enum TransactionModifier {
Deferred, Deferred,
Immediate, Immediate,
Exclusive, Exclusive,
Try,
Catch,
} }
impl fmt::Display for TransactionModifier { impl fmt::Display for TransactionModifier {
@ -6407,6 +6434,8 @@ impl fmt::Display for TransactionModifier {
Deferred => "DEFERRED", Deferred => "DEFERRED",
Immediate => "IMMEDIATE", Immediate => "IMMEDIATE",
Exclusive => "EXCLUSIVE", Exclusive => "EXCLUSIVE",
Try => "TRY",
Catch => "CATCH",
}) })
} }
} }

View file

@ -260,11 +260,16 @@ pub trait Dialect: Debug + Any {
false false
} }
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE} [TRANSACTION]` statements /// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE | TRY | CATCH} [TRANSACTION]` statements
fn supports_start_transaction_modifier(&self) -> bool { fn supports_start_transaction_modifier(&self) -> bool {
false false
} }
/// Returns true if the dialect supports `END {TRY | CATCH}` statements
fn supports_end_transaction_modifier(&self) -> bool {
false
}
/// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`. /// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`.
fn supports_named_fn_args_with_eq_operator(&self) -> bool { fn supports_named_fn_args_with_eq_operator(&self) -> bool {
false false

View file

@ -78,4 +78,11 @@ impl Dialect for MsSqlDialect {
fn supports_named_fn_args_with_rarrow_operator(&self) -> bool { fn supports_named_fn_args_with_rarrow_operator(&self) -> bool {
false false
} }
fn supports_start_transaction_modifier(&self) -> bool {
true
}
fn supports_end_transaction_modifier(&self) -> bool {
true
}
} }

View file

@ -151,6 +151,7 @@ define_keywords!(
CASE, CASE,
CAST, CAST,
CATALOG, CATALOG,
CATCH,
CEIL, CEIL,
CEILING, CEILING,
CENTURY, CENTURY,
@ -812,6 +813,7 @@ define_keywords!(
TRIM_ARRAY, TRIM_ARRAY,
TRUE, TRUE,
TRUNCATE, TRUNCATE,
TRY,
TRY_CAST, TRY_CAST,
TRY_CONVERT, TRY_CONVERT,
TUPLE, TUPLE,

View file

@ -12800,6 +12800,10 @@ impl<'a> Parser<'a> {
Some(TransactionModifier::Immediate) Some(TransactionModifier::Immediate)
} else if self.parse_keyword(Keyword::EXCLUSIVE) { } else if self.parse_keyword(Keyword::EXCLUSIVE) {
Some(TransactionModifier::Exclusive) Some(TransactionModifier::Exclusive)
} else if self.parse_keyword(Keyword::TRY) {
Some(TransactionModifier::Try)
} else if self.parse_keyword(Keyword::CATCH) {
Some(TransactionModifier::Catch)
} else { } else {
None None
}; };
@ -12817,8 +12821,19 @@ impl<'a> Parser<'a> {
} }
pub fn parse_end(&mut self) -> Result<Statement, ParserError> { pub fn parse_end(&mut self) -> Result<Statement, ParserError> {
let modifier = if !self.dialect.supports_end_transaction_modifier() {
None
} else if self.parse_keyword(Keyword::TRY) {
Some(TransactionModifier::Try)
} else if self.parse_keyword(Keyword::CATCH) {
Some(TransactionModifier::Catch)
} else {
None
};
Ok(Statement::Commit { Ok(Statement::Commit {
chain: self.parse_commit_rollback_chain()?, chain: self.parse_commit_rollback_chain()?,
end: true,
modifier,
}) })
} }
@ -12861,6 +12876,8 @@ impl<'a> Parser<'a> {
pub fn parse_commit(&mut self) -> Result<Statement, ParserError> { pub fn parse_commit(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::Commit { Ok(Statement::Commit {
chain: self.parse_commit_rollback_chain()?, chain: self.parse_commit_rollback_chain()?,
end: false,
modifier: None,
}) })
} }

View file

@ -7887,6 +7887,27 @@ fn parse_start_transaction() {
ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()), ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()),
res.unwrap_err() res.unwrap_err()
); );
// MS-SQL syntax
let dialects = all_dialects_where(|d| d.supports_start_transaction_modifier());
dialects.verified_stmt("BEGIN TRY");
dialects.verified_stmt("BEGIN CATCH");
let dialects = all_dialects_where(|d| {
d.supports_start_transaction_modifier() && d.supports_end_transaction_modifier()
});
dialects
.parse_sql_statements(
r#"
BEGIN TRY;
SELECT 1/0;
END TRY;
BEGIN CATCH;
EXECUTE foo;
END CATCH;
"#,
)
.unwrap();
} }
#[test] #[test]
@ -8102,12 +8123,12 @@ fn parse_set_time_zone_alias() {
#[test] #[test]
fn parse_commit() { fn parse_commit() {
match verified_stmt("COMMIT") { match verified_stmt("COMMIT") {
Statement::Commit { chain: false } => (), Statement::Commit { chain: false, .. } => (),
_ => unreachable!(), _ => unreachable!(),
} }
match verified_stmt("COMMIT AND CHAIN") { match verified_stmt("COMMIT AND CHAIN") {
Statement::Commit { chain: true } => (), Statement::Commit { chain: true, .. } => (),
_ => unreachable!(), _ => unreachable!(),
} }
@ -8122,13 +8143,17 @@ fn parse_commit() {
#[test] #[test]
fn parse_end() { fn parse_end() {
one_statement_parses_to("END AND NO CHAIN", "COMMIT"); one_statement_parses_to("END AND NO CHAIN", "END");
one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT"); one_statement_parses_to("END WORK AND NO CHAIN", "END");
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT"); one_statement_parses_to("END TRANSACTION AND NO CHAIN", "END");
one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN"); one_statement_parses_to("END WORK AND CHAIN", "END AND CHAIN");
one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN"); one_statement_parses_to("END TRANSACTION AND CHAIN", "END AND CHAIN");
one_statement_parses_to("END WORK", "COMMIT"); one_statement_parses_to("END WORK", "END");
one_statement_parses_to("END TRANSACTION", "COMMIT"); one_statement_parses_to("END TRANSACTION", "END");
// MS-SQL syntax
let dialects = all_dialects_where(|d| d.supports_end_transaction_modifier());
dialects.verified_stmt("END TRY");
dialects.verified_stmt("END CATCH");
} }
#[test] #[test]

View file

@ -115,7 +115,11 @@ fn custom_statement_parser() -> Result<(), ParserError> {
for _ in 0..3 { for _ in 0..3 {
let _ = parser.next_token(); let _ = parser.next_token();
} }
Some(Ok(Statement::Commit { chain: false })) Some(Ok(Statement::Commit {
chain: false,
end: false,
modifier: None,
}))
} else { } else {
None None
} }

View file

@ -523,13 +523,7 @@ fn parse_start_transaction_with_modifier() {
sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE"); sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE");
sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE"); sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE");
let unsupported_dialects = TestedDialects::new( let unsupported_dialects = all_dialects_except(|d| d.supports_start_transaction_modifier());
all_dialects()
.dialects
.into_iter()
.filter(|x| !(x.is::<SQLiteDialect>() || x.is::<GenericDialect>()))
.collect(),
);
let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED"); let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED");
assert_eq!( assert_eq!(
ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()), ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()),