Improve support for cursors for SQL Server (#1831)

Co-authored-by: Ifeanyi Ubah <ify1992@yahoo.com>
This commit is contained in:
Andrew Harper 2025-05-01 23:25:30 -04:00 committed by GitHub
parent 483394cd1a
commit a464f8e8d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 289 additions and 20 deletions

View file

@ -2228,7 +2228,33 @@ impl fmt::Display for IfStatement {
}
}
/// A block within a [Statement::Case] or [Statement::If]-like statement
/// A `WHILE` statement.
///
/// Example:
/// ```sql
/// WHILE @@FETCH_STATUS = 0
/// BEGIN
/// FETCH NEXT FROM c1 INTO @var1, @var2;
/// END
/// ```
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/while-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WhileStatement {
pub while_block: ConditionalStatementBlock,
}
impl fmt::Display for WhileStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let WhileStatement { while_block } = self;
write!(f, "{while_block}")?;
Ok(())
}
}
/// A block within a [Statement::Case] or [Statement::If] or [Statement::While]-like statement
///
/// Example 1:
/// ```sql
@ -2244,6 +2270,14 @@ impl fmt::Display for IfStatement {
/// ```sql
/// ELSE SELECT 1; SELECT 2;
/// ```
///
/// Example 4:
/// ```sql
/// WHILE @@FETCH_STATUS = 0
/// BEGIN
/// FETCH NEXT FROM c1 INTO @var1, @var2;
/// END
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -2983,6 +3017,8 @@ pub enum Statement {
Case(CaseStatement),
/// An `IF` statement.
If(IfStatement),
/// A `WHILE` statement.
While(WhileStatement),
/// A `RAISE` statement.
Raise(RaiseStatement),
/// ```sql
@ -3034,6 +3070,11 @@ pub enum Statement {
partition: Option<Box<Expr>>,
},
/// ```sql
/// OPEN cursor_name
/// ```
/// Opens a cursor.
Open(OpenStatement),
/// ```sql
/// CLOSE
/// ```
/// Closes the portal underlying an open cursor.
@ -3413,6 +3454,7 @@ pub enum Statement {
/// Cursor name
name: Ident,
direction: FetchDirection,
position: FetchPosition,
/// Optional, It's possible to fetch rows form cursor to the table
into: Option<ObjectName>,
},
@ -4235,11 +4277,10 @@ impl fmt::Display for Statement {
Statement::Fetch {
name,
direction,
position,
into,
} => {
write!(f, "FETCH {direction} ")?;
write!(f, "IN {name}")?;
write!(f, "FETCH {direction} {position} {name}")?;
if let Some(into) = into {
write!(f, " INTO {into}")?;
@ -4329,6 +4370,9 @@ impl fmt::Display for Statement {
Statement::If(stmt) => {
write!(f, "{stmt}")
}
Statement::While(stmt) => {
write!(f, "{stmt}")
}
Statement::Raise(stmt) => {
write!(f, "{stmt}")
}
@ -4498,6 +4542,7 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Delete(delete) => write!(f, "{delete}"),
Statement::Open(open) => write!(f, "{open}"),
Statement::Close { cursor } => {
write!(f, "CLOSE {cursor}")?;
@ -6187,6 +6232,28 @@ impl fmt::Display for FetchDirection {
}
}
/// The "position" for a FETCH statement.
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/fetch-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FetchPosition {
From,
In,
}
impl fmt::Display for FetchPosition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FetchPosition::From => f.write_str("FROM")?,
FetchPosition::In => f.write_str("IN")?,
};
Ok(())
}
}
/// A privilege on a database object (table, sequence, etc.).
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -9354,6 +9421,21 @@ pub enum ReturnStatementValue {
Expr(Expr),
}
/// Represents an `OPEN` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct OpenStatement {
/// Cursor name
pub cursor_name: Ident,
}
impl fmt::Display for OpenStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "OPEN {}", self.cursor_name)
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -31,13 +31,13 @@ use super::{
FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WildcardAdditionalOptions, With, WithFill,
Offset, OnConflict, OnConflictAction, OnInsert, OpenStatement, OrderBy, OrderByExpr,
OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement,
RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
WhileStatement, WildcardAdditionalOptions, With, WithFill,
};
/// Given an iterator of spans, return the [Span::union] of all spans.
@ -339,6 +339,7 @@ impl Spanned for Statement {
} => source.span(),
Statement::Case(stmt) => stmt.span(),
Statement::If(stmt) => stmt.span(),
Statement::While(stmt) => stmt.span(),
Statement::Raise(stmt) => stmt.span(),
Statement::Call(function) => function.span(),
Statement::Copy {
@ -365,6 +366,7 @@ impl Spanned for Statement {
from_query: _,
partition: _,
} => Span::empty(),
Statement::Open(open) => open.span(),
Statement::Close { cursor } => match cursor {
CloseCursor::All => Span::empty(),
CloseCursor::Specific { name } => name.span,
@ -776,6 +778,14 @@ impl Spanned for IfStatement {
}
}
impl Spanned for WhileStatement {
fn span(&self) -> Span {
let WhileStatement { while_block } = self;
while_block.span()
}
}
impl Spanned for ConditionalStatements {
fn span(&self) -> Span {
match self {
@ -2297,6 +2307,13 @@ impl Spanned for BeginEndStatements {
}
}
impl Spanned for OpenStatement {
fn span(&self) -> Span {
let OpenStatement { cursor_name } = self;
cursor_name.span
}
}
#[cfg(test)]
pub mod tests {
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};

View file

@ -985,6 +985,7 @@ define_keywords!(
WHEN,
WHENEVER,
WHERE,
WHILE,
WIDTH_BUCKET,
WINDOW,
WITH,
@ -1068,6 +1069,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::SAMPLE,
Keyword::TABLESAMPLE,
Keyword::FROM,
Keyword::OPEN,
];
/// Can't be used as a column alias, so that `SELECT <expr> alias`

View file

@ -536,6 +536,10 @@ impl<'a> Parser<'a> {
self.prev_token();
self.parse_if_stmt()
}
Keyword::WHILE => {
self.prev_token();
self.parse_while()
}
Keyword::RAISE => {
self.prev_token();
self.parse_raise_stmt()
@ -570,6 +574,10 @@ impl<'a> Parser<'a> {
Keyword::ALTER => self.parse_alter(),
Keyword::CALL => self.parse_call(),
Keyword::COPY => self.parse_copy(),
Keyword::OPEN => {
self.prev_token();
self.parse_open()
}
Keyword::CLOSE => self.parse_close(),
Keyword::SET => self.parse_set(),
Keyword::SHOW => self.parse_show(),
@ -700,8 +708,18 @@ impl<'a> Parser<'a> {
}))
}
/// Parse a `WHILE` statement.
///
/// See [Statement::While]
fn parse_while(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword_is(Keyword::WHILE)?;
let while_block = self.parse_conditional_statement_block(&[Keyword::END])?;
Ok(Statement::While(WhileStatement { while_block }))
}
/// Parses an expression and associated list of statements
/// belonging to a conditional statement like `IF` or `WHEN`.
/// belonging to a conditional statement like `IF` or `WHEN` or `WHILE`.
///
/// Example:
/// ```sql
@ -716,6 +734,10 @@ impl<'a> Parser<'a> {
let condition = match &start_token.token {
Token::Word(w) if w.keyword == Keyword::ELSE => None,
Token::Word(w) if w.keyword == Keyword::WHILE => {
let expr = self.parse_expr()?;
Some(expr)
}
_ => {
let expr = self.parse_expr()?;
then_token = Some(AttachedToken(self.expect_keyword(Keyword::THEN)?));
@ -723,13 +745,25 @@ impl<'a> Parser<'a> {
}
};
let statements = self.parse_statement_list(terminal_keywords)?;
let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(terminal_keywords)?;
let end_token = self.expect_keyword(Keyword::END)?;
ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
})
} else {
let statements = self.parse_statement_list(terminal_keywords)?;
ConditionalStatements::Sequence { statements }
};
Ok(ConditionalStatementBlock {
start_token: AttachedToken(start_token),
condition,
then_token,
conditional_statements: ConditionalStatements::Sequence { statements },
conditional_statements,
})
}
@ -4467,11 +4501,16 @@ impl<'a> Parser<'a> {
) -> Result<Vec<Statement>, ParserError> {
let mut values = vec![];
loop {
if let Token::Word(w) = &self.peek_nth_token_ref(0).token {
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
break;
match &self.peek_nth_token_ref(0).token {
Token::EOF => break,
Token::Word(w) => {
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
break;
}
}
_ => {}
}
values.push(self.parse_statement()?);
self.expect_token(&Token::SemiColon)?;
}
@ -6644,7 +6683,15 @@ impl<'a> Parser<'a> {
}
};
self.expect_one_of_keywords(&[Keyword::FROM, Keyword::IN])?;
let position = if self.peek_keyword(Keyword::FROM) {
self.expect_keyword(Keyword::FROM)?;
FetchPosition::From
} else if self.peek_keyword(Keyword::IN) {
self.expect_keyword(Keyword::IN)?;
FetchPosition::In
} else {
return parser_err!("Expected FROM or IN", self.peek_token().span.start);
};
let name = self.parse_identifier()?;
@ -6657,6 +6704,7 @@ impl<'a> Parser<'a> {
Ok(Statement::Fetch {
name,
direction,
position,
into,
})
}
@ -8770,6 +8818,14 @@ impl<'a> Parser<'a> {
})
}
/// Parse [Statement::Open]
fn parse_open(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword(Keyword::OPEN)?;
Ok(Statement::Open(OpenStatement {
cursor_name: self.parse_identifier()?,
}))
}
pub fn parse_close(&mut self) -> Result<Statement, ParserError> {
let cursor = if self.parse_keyword(Keyword::ALL) {
CloseCursor::All

View file

@ -151,6 +151,8 @@ impl TestedDialects {
///
/// 2. re-serializing the result of parsing `sql` produces the same
/// `canonical` sql string
///
/// For multiple statements, use [`statements_parse_to`].
pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
let mut statements = self.parse_sql_statements(sql).expect(sql);
assert_eq!(statements.len(), 1);
@ -166,6 +168,24 @@ impl TestedDialects {
only_statement
}
/// The same as [`one_statement_parses_to`] but it works for a multiple statements
pub fn statements_parse_to(&self, sql: &str, canonical: &str) -> Vec<Statement> {
let statements = self.parse_sql_statements(sql).expect(sql);
if !canonical.is_empty() && sql != canonical {
assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
} else {
assert_eq!(
sql,
statements
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join("; ")
);
}
statements
}
/// Ensures that `sql` parses as an [`Expr`], and that
/// re-serializing the parse result produces canonical
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {

View file

@ -15187,3 +15187,15 @@ fn parse_return() {
let _ = all_dialects().verified_stmt("RETURN 1");
}
#[test]
fn test_open() {
let open_cursor = "OPEN Employee_Cursor";
let stmt = all_dialects().verified_stmt(open_cursor);
assert_eq!(
stmt,
Statement::Open(OpenStatement {
cursor_name: Ident::new("Employee_Cursor"),
})
);
}

View file

@ -23,7 +23,8 @@
mod test_utils;
use helpers::attached_token::AttachedToken;
use sqlparser::tokenizer::{Location, Span};
use sqlparser::keywords::Keyword;
use sqlparser::tokenizer::{Location, Span, Token, TokenWithSpan, Word};
use test_utils::*;
use sqlparser::ast::DataType::{Int, Text, Varbinary};
@ -223,7 +224,7 @@ fn parse_create_function() {
value: Some(ReturnStatementValue::Expr(Expr::Value(
(number("1")).with_empty_span()
))),
}),],
})],
end_token: AttachedToken::empty(),
})),
behavior: None,
@ -1397,6 +1398,85 @@ fn parse_mssql_declare() {
let _ = ms().verified_stmt(declare_cursor_for_select);
}
#[test]
fn test_mssql_cursor() {
let full_cursor_usage = "\
DECLARE Employee_Cursor CURSOR FOR \
SELECT LastName, FirstName \
FROM AdventureWorks2022.HumanResources.vEmployee \
WHERE LastName LIKE 'B%'; \
\
OPEN Employee_Cursor; \
\
FETCH NEXT FROM Employee_Cursor; \
\
WHILE @@FETCH_STATUS = 0 \
BEGIN \
FETCH NEXT FROM Employee_Cursor; \
END; \
\
CLOSE Employee_Cursor; \
DEALLOCATE Employee_Cursor\
";
let _ = ms().statements_parse_to(full_cursor_usage, "");
}
#[test]
fn test_mssql_while_statement() {
let while_single_statement = "WHILE 1 = 0 PRINT 'Hello World';";
let stmt = ms().verified_stmt(while_single_statement);
assert_eq!(
stmt,
Statement::While(sqlparser::ast::WhileStatement {
while_block: ConditionalStatementBlock {
start_token: AttachedToken(TokenWithSpan {
token: Token::Word(Word {
value: "WHILE".to_string(),
quote_style: None,
keyword: Keyword::WHILE
}),
span: Span::empty()
}),
condition: Some(Expr::BinaryOp {
left: Box::new(Expr::Value(
(Value::Number("1".parse().unwrap(), false)).with_empty_span()
)),
op: BinaryOperator::Eq,
right: Box::new(Expr::Value(
(Value::Number("0".parse().unwrap(), false)).with_empty_span()
)),
}),
then_token: None,
conditional_statements: ConditionalStatements::Sequence {
statements: vec![Statement::Print(PrintStatement {
message: Box::new(Expr::Value(
(Value::SingleQuotedString("Hello World".to_string()))
.with_empty_span()
)),
})],
}
}
})
);
let while_begin_end = "\
WHILE @@FETCH_STATUS = 0 \
BEGIN \
FETCH NEXT FROM Employee_Cursor; \
END\
";
let _ = ms().verified_stmt(while_begin_end);
let while_begin_end_multiple_statements = "\
WHILE @@FETCH_STATUS = 0 \
BEGIN \
FETCH NEXT FROM Employee_Cursor; \
PRINT 'Hello World'; \
END\
";
let _ = ms().verified_stmt(while_begin_end_multiple_statements);
}
#[test]
fn test_parse_raiserror() {
let sql = r#"RAISERROR('This is a test', 16, 1)"#;