Add CREATE FUNCTION support for SQL Server (#1808)

This commit is contained in:
Andrew Harper 2025-04-23 12:10:57 -04:00 committed by GitHub
parent 945f8e0534
commit 2eb1e7bdd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 313 additions and 50 deletions

View file

@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateFunction {
/// True if this is a `CREATE OR ALTER FUNCTION` statement
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter)
pub or_alter: bool,
pub or_replace: bool,
pub temporary: bool,
pub if_not_exists: bool,
@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}",
"CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}",
name = self.name,
temp = if self.temporary { "TEMPORARY " } else { "" },
or_alter = if self.or_alter { "OR ALTER " } else { "" },
or_replace = if self.or_replace { "OR REPLACE " } else { "" },
if_not_exists = if self.if_not_exists {
"IF NOT EXISTS "
@ -2272,6 +2277,9 @@ impl fmt::Display for CreateFunction {
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
write!(f, " AS {function_body}")?;
}
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
write!(f, " AS {bes}")?;
}
Ok(())
}
}

View file

@ -2293,18 +2293,14 @@ pub enum ConditionalStatements {
/// SELECT 1; SELECT 2; SELECT 3; ...
Sequence { statements: Vec<Statement> },
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
BeginEnd {
begin_token: AttachedToken,
statements: Vec<Statement>,
end_token: AttachedToken,
},
BeginEnd(BeginEndStatements),
}
impl ConditionalStatements {
pub fn statements(&self) -> &Vec<Statement> {
match self {
ConditionalStatements::Sequence { statements } => statements,
ConditionalStatements::BeginEnd { statements, .. } => statements,
ConditionalStatements::BeginEnd(bes) => &bes.statements,
}
}
}
@ -2318,15 +2314,44 @@ impl fmt::Display for ConditionalStatements {
}
Ok(())
}
ConditionalStatements::BeginEnd { statements, .. } => {
write!(f, "BEGIN ")?;
format_statement_list(f, statements)?;
write!(f, " END")
}
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
}
}
}
/// Represents a list of statements enclosed within `BEGIN` and `END` keywords.
/// Example:
/// ```sql
/// BEGIN
/// SELECT 1;
/// SELECT 2;
/// END
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct BeginEndStatements {
pub begin_token: AttachedToken,
pub statements: Vec<Statement>,
pub end_token: AttachedToken,
}
impl fmt::Display for BeginEndStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
} = self;
write!(f, "{begin_token} ")?;
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
write!(f, " {end_token}")
}
}
/// A `RAISE` statement.
///
/// Examples:
@ -3615,6 +3640,7 @@ pub enum Statement {
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
CreateFunction(CreateFunction),
/// CREATE TRIGGER
///
@ -4061,6 +4087,12 @@ pub enum Statement {
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
Print(PrintStatement),
/// ```sql
/// RETURN [ expression ]
/// ```
///
/// See [ReturnStatement]
Return(ReturnStatement),
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@ -5753,6 +5785,7 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Print(s) => write!(f, "{s}"),
Statement::Return(r) => write!(f, "{r}"),
Statement::List(command) => write!(f, "LIST {command}"),
Statement::Remove(command) => write!(f, "REMOVE {command}"),
}
@ -8355,6 +8388,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -8383,6 +8417,22 @@ pub enum CreateFunctionBody {
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
AsAfterOptions(Expr),
/// Function body with statements before the `RETURN` keyword.
///
/// Example:
/// ```sql
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
/// RETURNS INT
/// AS
/// BEGIN
/// DECLARE c INT;
/// SET c = a + b;
/// RETURN c;
/// END
/// ```
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
AsBeginEnd(BeginEndStatements),
/// Function body expression using the 'RETURN' keyword.
///
/// Example:
@ -9231,6 +9281,34 @@ impl fmt::Display for PrintStatement {
}
}
/// Represents a `Return` statement.
///
/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql)
/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ReturnStatement {
pub value: Option<ReturnStatementValue>,
}
impl fmt::Display for ReturnStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.value {
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
None => write!(f, "RETURN"),
}
}
}
/// Variants of a `RETURN` statement
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReturnStatementValue {
Expr(Expr),
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -23,8 +23,8 @@ use crate::tokenizer::Span;
use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption,
ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
@ -520,6 +520,7 @@ impl Spanned for Statement {
Statement::RenameTable { .. } => Span::empty(),
Statement::RaisError { .. } => Span::empty(),
Statement::Print { .. } => Span::empty(),
Statement::Return { .. } => Span::empty(),
Statement::List(..) | Statement::Remove(..) => Span::empty(),
}
}
@ -778,11 +779,7 @@ impl Spanned for ConditionalStatements {
ConditionalStatements::Sequence { statements } => {
union_spans(statements.iter().map(|s| s.span()))
}
ConditionalStatements::BeginEnd {
begin_token: AttachedToken(start),
statements: _,
end_token: AttachedToken(end),
} => union_spans([start.span, end.span].into_iter()),
ConditionalStatements::BeginEnd(bes) => bes.span(),
}
}
}
@ -2282,6 +2279,21 @@ impl Spanned for TableObject {
}
}
impl Spanned for BeginEndStatements {
fn span(&self) -> Span {
let BeginEndStatements {
begin_token,
statements,
end_token,
} = self;
union_spans(
core::iter::once(begin_token.0.span)
.chain(statements.iter().map(|i| i.span()))
.chain(core::iter::once(end_token.0.span)),
)
}
}
#[cfg(test)]
pub mod tests {
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};

View file

@ -16,7 +16,9 @@
// under the License.
use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
use crate::parser::{Parser, ParserError};
@ -149,11 +151,11 @@ impl MsSqlDialect {
start_token: AttachedToken(if_token),
condition: Some(condition),
then_token: None,
conditional_statements: ConditionalStatements::BeginEnd {
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
},
}),
}
} else {
let stmt = parser.parse_statement()?;
@ -167,8 +169,10 @@ impl MsSqlDialect {
}
};
let mut prior_statement_ended_with_semi_colon = false;
while let Token::SemiColon = parser.peek_token_ref().token {
parser.advance_token();
prior_statement_ended_with_semi_colon = true;
}
let mut else_block = None;
@ -182,11 +186,11 @@ impl MsSqlDialect {
start_token: AttachedToken(else_token),
condition: None,
then_token: None,
conditional_statements: ConditionalStatements::BeginEnd {
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
},
}),
});
} else {
let stmt = parser.parse_statement()?;
@ -199,6 +203,8 @@ impl MsSqlDialect {
},
});
}
} else if prior_statement_ended_with_semi_colon {
parser.prev_token();
}
Ok(Statement::If(IfStatement {

View file

@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
Keyword::GRANT => self.parse_grant(),
Keyword::REVOKE => self.parse_revoke(),
Keyword::START => self.parse_start_transaction(),
// `BEGIN` is a nonstandard but common alias for the
// standard `START TRANSACTION` statement. It is supported
// by at least PostgreSQL and MySQL.
Keyword::BEGIN => self.parse_begin(),
// `END` is a nonstandard but common alias for the
// standard `COMMIT TRANSACTION` statement. It is supported
// by PostgreSQL.
Keyword::END => self.parse_end(),
Keyword::SAVEPOINT => self.parse_savepoint(),
Keyword::RELEASE => self.parse_release(),
@ -618,6 +612,7 @@ impl<'a> Parser<'a> {
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
Keyword::PRINT => self.parse_print(),
Keyword::RETURN => self.parse_return(),
_ => self.expected("an SQL statement", next_token),
},
Token::LParen => {
@ -4458,7 +4453,6 @@ impl<'a> Parser<'a> {
break;
}
}
values.push(self.parse_statement()?);
self.expect_token(&Token::SemiColon)?;
}
@ -4560,7 +4554,7 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::EXTERNAL) {
self.parse_create_external_table(or_replace)
} else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_replace, temporary)
self.parse_create_function(or_alter, or_replace, temporary)
} else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
@ -4869,6 +4863,7 @@ impl<'a> Parser<'a> {
pub fn parse_create_function(
&mut self,
or_alter: bool,
or_replace: bool,
temporary: bool,
) -> Result<Statement, ParserError> {
@ -4880,6 +4875,8 @@ impl<'a> Parser<'a> {
self.parse_create_macro(or_replace, temporary)
} else if dialect_of!(self is BigQueryDialect) {
self.parse_bigquery_create_function(or_replace, temporary)
} else if dialect_of!(self is MsSqlDialect) {
self.parse_mssql_create_function(or_alter, or_replace, temporary)
} else {
self.prev_token();
self.expected("an object type after CREATE", self.peek_token())
@ -4994,6 +4991,7 @@ impl<'a> Parser<'a> {
}
Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace,
temporary,
name,
@ -5027,6 +5025,7 @@ impl<'a> Parser<'a> {
let using = self.parse_optional_create_function_using()?;
Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace,
temporary,
name,
@ -5054,22 +5053,7 @@ impl<'a> Parser<'a> {
temporary: bool,
) -> Result<Statement, ParserError> {
let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]);
let name = self.parse_object_name(false)?;
let parse_function_param =
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
let name = parser.parse_identifier()?;
let data_type = parser.parse_data_type()?;
Ok(OperateFunctionArg {
mode: None,
name: Some(name),
data_type,
default_expr: None,
})
};
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
self.expect_token(&Token::RParen)?;
let (name, args) = self.parse_create_function_name_and_params()?;
let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?)
@ -5116,6 +5100,7 @@ impl<'a> Parser<'a> {
};
Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace,
temporary,
if_not_exists,
@ -5134,6 +5119,73 @@ impl<'a> Parser<'a> {
}))
}
/// Parse `CREATE FUNCTION` for [MsSql]
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
fn parse_mssql_create_function(
&mut self,
or_alter: bool,
or_replace: bool,
temporary: bool,
) -> Result<Statement, ParserError> {
let (name, args) = self.parse_create_function_name_and_params()?;
self.expect_keyword(Keyword::RETURNS)?;
let return_type = Some(self.parse_data_type()?);
self.expect_keyword_is(Keyword::AS)?;
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(&[Keyword::END])?;
let end_token = self.expect_keyword(Keyword::END)?;
let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
}));
Ok(Statement::CreateFunction(CreateFunction {
or_alter,
or_replace,
temporary,
if_not_exists: false,
name,
args: Some(args),
return_type,
function_body,
language: None,
determinism_specifier: None,
options: None,
remote_connection: None,
using: None,
behavior: None,
called_on_null: None,
parallel: None,
}))
}
fn parse_create_function_name_and_params(
&mut self,
) -> Result<(ObjectName, Vec<OperateFunctionArg>), ParserError> {
let name = self.parse_object_name(false)?;
let parse_function_param =
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
let name = parser.parse_identifier()?;
let data_type = parser.parse_data_type()?;
Ok(OperateFunctionArg {
mode: None,
name: Some(name),
data_type,
default_expr: None,
})
};
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
self.expect_token(&Token::RParen)?;
Ok((name, args))
}
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
let mode = if self.parse_keyword(Keyword::IN) {
Some(ArgMode::In)
@ -15161,6 +15213,16 @@ impl<'a> Parser<'a> {
}))
}
/// Parse [Statement::Return]
fn parse_return(&mut self) -> Result<Statement, ParserError> {
match self.maybe_parse(|p| p.parse_expr())? {
Some(expr) => Ok(Statement::Return(ReturnStatement {
value: Some(ReturnStatementValue::Expr(expr)),
})),
None => Ok(Statement::Return(ReturnStatement { value: None })),
}
}
/// Consume the parser and return its underlying token buffer
pub fn into_tokens(self) -> Vec<TokenWithSpan> {
self.tokens

View file

@ -2134,6 +2134,7 @@ fn test_bigquery_create_function() {
assert_eq!(
stmt,
Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: true,
temporary: true,
if_not_exists: false,

View file

@ -15079,3 +15079,11 @@ fn parse_set_time_zone_alias() {
_ => unreachable!(),
}
}
#[test]
fn parse_return() {
let stmt = all_dialects().verified_stmt("RETURN");
assert_eq!(stmt, Statement::Return(ReturnStatement { value: None }));
let _ = all_dialects().verified_stmt("RETURN 1");
}

View file

@ -25,7 +25,7 @@ use sqlparser::ast::{
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
};
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::*;
@ -423,7 +423,7 @@ fn parse_create_function() {
}
// Test error in dialect that doesn't support parsing CREATE FUNCTION
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
assert_eq!(
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

View file

@ -187,6 +187,92 @@ fn parse_mssql_create_procedure() {
let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END");
}
#[test]
fn parse_create_function() {
let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1; END";
assert_eq!(
ms().verified_stmt(return_expression_function),
sqlparser::ast::Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false,
temporary: false,
if_not_exists: false,
name: ObjectName::from(vec![Ident::new("some_scalar_udf")]),
args: Some(vec![
OperateFunctionArg {
mode: None,
name: Some(Ident::new("@foo")),
data_type: DataType::Int(None),
default_expr: None,
},
OperateFunctionArg {
mode: None,
name: Some(Ident::new("@bar")),
data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
length: 256,
unit: None
})),
default_expr: None,
},
]),
return_type: Some(DataType::Int(None)),
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken::empty(),
statements: vec![Statement::Return(ReturnStatement {
value: Some(ReturnStatementValue::Expr(Expr::Value(
(number("1")).with_empty_span()
))),
}),],
end_token: AttachedToken::empty(),
})),
behavior: None,
called_on_null: None,
parallel: None,
using: None,
language: None,
determinism_specifier: None,
options: None,
remote_connection: None,
}),
);
let multi_statement_function = "\
CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
RETURNS INT \
AS \
BEGIN \
SET @foo = @foo + 1; \
RETURN @foo; \
END\
";
let _ = ms().verified_stmt(multi_statement_function);
let create_function_with_conditional = "\
CREATE FUNCTION some_scalar_udf() \
RETURNS INT \
AS \
BEGIN \
IF 1 = 2 \
BEGIN \
RETURN 1; \
END; \
RETURN 0; \
END\
";
let _ = ms().verified_stmt(create_function_with_conditional);
let create_or_alter_function = "\
CREATE OR ALTER FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
RETURNS INT \
AS \
BEGIN \
SET @foo = @foo + 1; \
RETURN @foo; \
END\
";
let _ = ms().verified_stmt(create_or_alter_function);
}
#[test]
fn parse_mssql_apply_join() {
let _ = ms_and_generic().verified_only_select(

View file

@ -4104,6 +4104,7 @@ fn parse_create_function() {
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false,
temporary: false,
name: ObjectName::from(vec![Ident::new("add")]),
@ -5485,6 +5486,7 @@ fn parse_trigger_related_functions() {
assert_eq!(
create_function,
Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false,
temporary: false,
if_not_exists: false,