Add CREATE FUNCTION support for SQL Server (#1808)

This commit is contained in:
Andrew Harper 2025-04-23 12:10:57 -04:00 committed by GitHub
parent 945f8e0534
commit 2eb1e7bdd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 313 additions and 50 deletions

View file

@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateFunction { pub struct CreateFunction {
/// True if this is a `CREATE OR ALTER FUNCTION` statement
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter)
pub or_alter: bool,
pub or_replace: bool, pub or_replace: bool,
pub temporary: bool, pub temporary: bool,
pub if_not_exists: bool, pub if_not_exists: bool,
@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,
"CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}", "CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}",
name = self.name, name = self.name,
temp = if self.temporary { "TEMPORARY " } else { "" }, temp = if self.temporary { "TEMPORARY " } else { "" },
or_alter = if self.or_alter { "OR ALTER " } else { "" },
or_replace = if self.or_replace { "OR REPLACE " } else { "" }, or_replace = if self.or_replace { "OR REPLACE " } else { "" },
if_not_exists = if self.if_not_exists { if_not_exists = if self.if_not_exists {
"IF NOT EXISTS " "IF NOT EXISTS "
@ -2272,6 +2277,9 @@ impl fmt::Display for CreateFunction {
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body { if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
write!(f, " AS {function_body}")?; write!(f, " AS {function_body}")?;
} }
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
write!(f, " AS {bes}")?;
}
Ok(()) Ok(())
} }
} }

View file

@ -2293,18 +2293,14 @@ pub enum ConditionalStatements {
/// SELECT 1; SELECT 2; SELECT 3; ... /// SELECT 1; SELECT 2; SELECT 3; ...
Sequence { statements: Vec<Statement> }, Sequence { statements: Vec<Statement> },
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END /// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
BeginEnd { BeginEnd(BeginEndStatements),
begin_token: AttachedToken,
statements: Vec<Statement>,
end_token: AttachedToken,
},
} }
impl ConditionalStatements { impl ConditionalStatements {
pub fn statements(&self) -> &Vec<Statement> { pub fn statements(&self) -> &Vec<Statement> {
match self { match self {
ConditionalStatements::Sequence { statements } => statements, ConditionalStatements::Sequence { statements } => statements,
ConditionalStatements::BeginEnd { statements, .. } => statements, ConditionalStatements::BeginEnd(bes) => &bes.statements,
} }
} }
} }
@ -2318,15 +2314,44 @@ impl fmt::Display for ConditionalStatements {
} }
Ok(()) Ok(())
} }
ConditionalStatements::BeginEnd { statements, .. } => { ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
write!(f, "BEGIN ")?;
format_statement_list(f, statements)?;
write!(f, " END")
}
} }
} }
} }
/// Represents a list of statements enclosed within `BEGIN` and `END` keywords.
/// Example:
/// ```sql
/// BEGIN
/// SELECT 1;
/// SELECT 2;
/// END
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct BeginEndStatements {
pub begin_token: AttachedToken,
pub statements: Vec<Statement>,
pub end_token: AttachedToken,
}
impl fmt::Display for BeginEndStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
} = self;
write!(f, "{begin_token} ")?;
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
write!(f, " {end_token}")
}
}
/// A `RAISE` statement. /// A `RAISE` statement.
/// ///
/// Examples: /// Examples:
@ -3615,6 +3640,7 @@ pub enum Statement {
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction) /// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html) /// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement) /// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
CreateFunction(CreateFunction), CreateFunction(CreateFunction),
/// CREATE TRIGGER /// CREATE TRIGGER
/// ///
@ -4061,6 +4087,12 @@ pub enum Statement {
/// ///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql> /// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
Print(PrintStatement), Print(PrintStatement),
/// ```sql
/// RETURN [ expression ]
/// ```
///
/// See [ReturnStatement]
Return(ReturnStatement),
} }
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@ -5753,6 +5785,7 @@ impl fmt::Display for Statement {
Ok(()) Ok(())
} }
Statement::Print(s) => write!(f, "{s}"), Statement::Print(s) => write!(f, "{s}"),
Statement::Return(r) => write!(f, "{r}"),
Statement::List(command) => write!(f, "LIST {command}"), Statement::List(command) => write!(f, "LIST {command}"),
Statement::Remove(command) => write!(f, "REMOVE {command}"), Statement::Remove(command) => write!(f, "REMOVE {command}"),
} }
@ -8355,6 +8388,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
/// ///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11 /// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html /// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -8383,6 +8417,22 @@ pub enum CreateFunctionBody {
/// ///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11 /// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
AsAfterOptions(Expr), AsAfterOptions(Expr),
/// Function body with statements before the `RETURN` keyword.
///
/// Example:
/// ```sql
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
/// RETURNS INT
/// AS
/// BEGIN
/// DECLARE c INT;
/// SET c = a + b;
/// RETURN c;
/// END
/// ```
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
AsBeginEnd(BeginEndStatements),
/// Function body expression using the 'RETURN' keyword. /// Function body expression using the 'RETURN' keyword.
/// ///
/// Example: /// Example:
@ -9231,6 +9281,34 @@ impl fmt::Display for PrintStatement {
} }
} }
/// Represents a `Return` statement.
///
/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql)
/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ReturnStatement {
pub value: Option<ReturnStatementValue>,
}
impl fmt::Display for ReturnStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.value {
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
None => write!(f, "RETURN"),
}
}
}
/// Variants of a `RETURN` statement
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReturnStatementValue {
Expr(Expr),
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -23,8 +23,8 @@ use crate::tokenizer::Span;
use super::{ use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation, dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken, AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption,
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy, ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
@ -520,6 +520,7 @@ impl Spanned for Statement {
Statement::RenameTable { .. } => Span::empty(), Statement::RenameTable { .. } => Span::empty(),
Statement::RaisError { .. } => Span::empty(), Statement::RaisError { .. } => Span::empty(),
Statement::Print { .. } => Span::empty(), Statement::Print { .. } => Span::empty(),
Statement::Return { .. } => Span::empty(),
Statement::List(..) | Statement::Remove(..) => Span::empty(), Statement::List(..) | Statement::Remove(..) => Span::empty(),
} }
} }
@ -778,11 +779,7 @@ impl Spanned for ConditionalStatements {
ConditionalStatements::Sequence { statements } => { ConditionalStatements::Sequence { statements } => {
union_spans(statements.iter().map(|s| s.span())) union_spans(statements.iter().map(|s| s.span()))
} }
ConditionalStatements::BeginEnd { ConditionalStatements::BeginEnd(bes) => bes.span(),
begin_token: AttachedToken(start),
statements: _,
end_token: AttachedToken(end),
} => union_spans([start.span, end.span].into_iter()),
} }
} }
} }
@ -2282,6 +2279,21 @@ impl Spanned for TableObject {
} }
} }
impl Spanned for BeginEndStatements {
fn span(&self) -> Span {
let BeginEndStatements {
begin_token,
statements,
end_token,
} = self;
union_spans(
core::iter::once(begin_token.0.span)
.chain(statements.iter().map(|i| i.span()))
.chain(core::iter::once(end_token.0.span)),
)
}
}
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect}; use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};

View file

@ -16,7 +16,9 @@
// under the License. // under the License.
use crate::ast::helpers::attached_token::AttachedToken; use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement}; use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
};
use crate::dialect::Dialect; use crate::dialect::Dialect;
use crate::keywords::{self, Keyword}; use crate::keywords::{self, Keyword};
use crate::parser::{Parser, ParserError}; use crate::parser::{Parser, ParserError};
@ -149,11 +151,11 @@ impl MsSqlDialect {
start_token: AttachedToken(if_token), start_token: AttachedToken(if_token),
condition: Some(condition), condition: Some(condition),
then_token: None, then_token: None,
conditional_statements: ConditionalStatements::BeginEnd { conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token), begin_token: AttachedToken(begin_token),
statements, statements,
end_token: AttachedToken(end_token), end_token: AttachedToken(end_token),
}, }),
} }
} else { } else {
let stmt = parser.parse_statement()?; let stmt = parser.parse_statement()?;
@ -167,8 +169,10 @@ impl MsSqlDialect {
} }
}; };
let mut prior_statement_ended_with_semi_colon = false;
while let Token::SemiColon = parser.peek_token_ref().token { while let Token::SemiColon = parser.peek_token_ref().token {
parser.advance_token(); parser.advance_token();
prior_statement_ended_with_semi_colon = true;
} }
let mut else_block = None; let mut else_block = None;
@ -182,11 +186,11 @@ impl MsSqlDialect {
start_token: AttachedToken(else_token), start_token: AttachedToken(else_token),
condition: None, condition: None,
then_token: None, then_token: None,
conditional_statements: ConditionalStatements::BeginEnd { conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token), begin_token: AttachedToken(begin_token),
statements, statements,
end_token: AttachedToken(end_token), end_token: AttachedToken(end_token),
}, }),
}); });
} else { } else {
let stmt = parser.parse_statement()?; let stmt = parser.parse_statement()?;
@ -199,6 +203,8 @@ impl MsSqlDialect {
}, },
}); });
} }
} else if prior_statement_ended_with_semi_colon {
parser.prev_token();
} }
Ok(Statement::If(IfStatement { Ok(Statement::If(IfStatement {

View file

@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
Keyword::GRANT => self.parse_grant(), Keyword::GRANT => self.parse_grant(),
Keyword::REVOKE => self.parse_revoke(), Keyword::REVOKE => self.parse_revoke(),
Keyword::START => self.parse_start_transaction(), Keyword::START => self.parse_start_transaction(),
// `BEGIN` is a nonstandard but common alias for the
// standard `START TRANSACTION` statement. It is supported
// by at least PostgreSQL and MySQL.
Keyword::BEGIN => self.parse_begin(), Keyword::BEGIN => self.parse_begin(),
// `END` is a nonstandard but common alias for the
// standard `COMMIT TRANSACTION` statement. It is supported
// by PostgreSQL.
Keyword::END => self.parse_end(), Keyword::END => self.parse_end(),
Keyword::SAVEPOINT => self.parse_savepoint(), Keyword::SAVEPOINT => self.parse_savepoint(),
Keyword::RELEASE => self.parse_release(), Keyword::RELEASE => self.parse_release(),
@ -618,6 +612,7 @@ impl<'a> Parser<'a> {
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment // `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(), Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
Keyword::PRINT => self.parse_print(), Keyword::PRINT => self.parse_print(),
Keyword::RETURN => self.parse_return(),
_ => self.expected("an SQL statement", next_token), _ => self.expected("an SQL statement", next_token),
}, },
Token::LParen => { Token::LParen => {
@ -4458,7 +4453,6 @@ impl<'a> Parser<'a> {
break; break;
} }
} }
values.push(self.parse_statement()?); values.push(self.parse_statement()?);
self.expect_token(&Token::SemiColon)?; self.expect_token(&Token::SemiColon)?;
} }
@ -4560,7 +4554,7 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::EXTERNAL) { } else if self.parse_keyword(Keyword::EXTERNAL) {
self.parse_create_external_table(or_replace) self.parse_create_external_table(or_replace)
} else if self.parse_keyword(Keyword::FUNCTION) { } else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_replace, temporary) self.parse_create_function(or_alter, or_replace, temporary)
} else if self.parse_keyword(Keyword::TRIGGER) { } else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_replace, false) self.parse_create_trigger(or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) { } else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
@ -4869,6 +4863,7 @@ impl<'a> Parser<'a> {
pub fn parse_create_function( pub fn parse_create_function(
&mut self, &mut self,
or_alter: bool,
or_replace: bool, or_replace: bool,
temporary: bool, temporary: bool,
) -> Result<Statement, ParserError> { ) -> Result<Statement, ParserError> {
@ -4880,6 +4875,8 @@ impl<'a> Parser<'a> {
self.parse_create_macro(or_replace, temporary) self.parse_create_macro(or_replace, temporary)
} else if dialect_of!(self is BigQueryDialect) { } else if dialect_of!(self is BigQueryDialect) {
self.parse_bigquery_create_function(or_replace, temporary) self.parse_bigquery_create_function(or_replace, temporary)
} else if dialect_of!(self is MsSqlDialect) {
self.parse_mssql_create_function(or_alter, or_replace, temporary)
} else { } else {
self.prev_token(); self.prev_token();
self.expected("an object type after CREATE", self.peek_token()) self.expected("an object type after CREATE", self.peek_token())
@ -4994,6 +4991,7 @@ impl<'a> Parser<'a> {
} }
Ok(Statement::CreateFunction(CreateFunction { Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace, or_replace,
temporary, temporary,
name, name,
@ -5027,6 +5025,7 @@ impl<'a> Parser<'a> {
let using = self.parse_optional_create_function_using()?; let using = self.parse_optional_create_function_using()?;
Ok(Statement::CreateFunction(CreateFunction { Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace, or_replace,
temporary, temporary,
name, name,
@ -5054,22 +5053,7 @@ impl<'a> Parser<'a> {
temporary: bool, temporary: bool,
) -> Result<Statement, ParserError> { ) -> Result<Statement, ParserError> {
let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]);
let name = self.parse_object_name(false)?; let (name, args) = self.parse_create_function_name_and_params()?;
let parse_function_param =
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
let name = parser.parse_identifier()?;
let data_type = parser.parse_data_type()?;
Ok(OperateFunctionArg {
mode: None,
name: Some(name),
data_type,
default_expr: None,
})
};
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
self.expect_token(&Token::RParen)?;
let return_type = if self.parse_keyword(Keyword::RETURNS) { let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?) Some(self.parse_data_type()?)
@ -5116,6 +5100,7 @@ impl<'a> Parser<'a> {
}; };
Ok(Statement::CreateFunction(CreateFunction { Ok(Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace, or_replace,
temporary, temporary,
if_not_exists, if_not_exists,
@ -5134,6 +5119,73 @@ impl<'a> Parser<'a> {
})) }))
} }
/// Parse `CREATE FUNCTION` for [MsSql]
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
fn parse_mssql_create_function(
&mut self,
or_alter: bool,
or_replace: bool,
temporary: bool,
) -> Result<Statement, ParserError> {
let (name, args) = self.parse_create_function_name_and_params()?;
self.expect_keyword(Keyword::RETURNS)?;
let return_type = Some(self.parse_data_type()?);
self.expect_keyword_is(Keyword::AS)?;
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
let statements = self.parse_statement_list(&[Keyword::END])?;
let end_token = self.expect_keyword(Keyword::END)?;
let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
}));
Ok(Statement::CreateFunction(CreateFunction {
or_alter,
or_replace,
temporary,
if_not_exists: false,
name,
args: Some(args),
return_type,
function_body,
language: None,
determinism_specifier: None,
options: None,
remote_connection: None,
using: None,
behavior: None,
called_on_null: None,
parallel: None,
}))
}
fn parse_create_function_name_and_params(
&mut self,
) -> Result<(ObjectName, Vec<OperateFunctionArg>), ParserError> {
let name = self.parse_object_name(false)?;
let parse_function_param =
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
let name = parser.parse_identifier()?;
let data_type = parser.parse_data_type()?;
Ok(OperateFunctionArg {
mode: None,
name: Some(name),
data_type,
default_expr: None,
})
};
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
self.expect_token(&Token::RParen)?;
Ok((name, args))
}
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> { fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
let mode = if self.parse_keyword(Keyword::IN) { let mode = if self.parse_keyword(Keyword::IN) {
Some(ArgMode::In) Some(ArgMode::In)
@ -15161,6 +15213,16 @@ impl<'a> Parser<'a> {
})) }))
} }
/// Parse [Statement::Return]
fn parse_return(&mut self) -> Result<Statement, ParserError> {
match self.maybe_parse(|p| p.parse_expr())? {
Some(expr) => Ok(Statement::Return(ReturnStatement {
value: Some(ReturnStatementValue::Expr(expr)),
})),
None => Ok(Statement::Return(ReturnStatement { value: None })),
}
}
/// Consume the parser and return its underlying token buffer /// Consume the parser and return its underlying token buffer
pub fn into_tokens(self) -> Vec<TokenWithSpan> { pub fn into_tokens(self) -> Vec<TokenWithSpan> {
self.tokens self.tokens

View file

@ -2134,6 +2134,7 @@ fn test_bigquery_create_function() {
assert_eq!( assert_eq!(
stmt, stmt,
Statement::CreateFunction(CreateFunction { Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: true, or_replace: true,
temporary: true, temporary: true,
if_not_exists: false, if_not_exists: false,

View file

@ -15079,3 +15079,11 @@ fn parse_set_time_zone_alias() {
_ => unreachable!(), _ => unreachable!(),
} }
} }
#[test]
fn parse_return() {
let stmt = all_dialects().verified_stmt("RETURN");
assert_eq!(stmt, Statement::Return(ReturnStatement { value: None }));
let _ = all_dialects().verified_stmt("RETURN 1");
}

View file

@ -25,7 +25,7 @@ use sqlparser::ast::{
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr, Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value, OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
}; };
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect}; use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
use sqlparser::parser::ParserError; use sqlparser::parser::ParserError;
use sqlparser::test_utils::*; use sqlparser::test_utils::*;
@ -423,7 +423,7 @@ fn parse_create_function() {
} }
// Test error in dialect that doesn't support parsing CREATE FUNCTION // Test error in dialect that doesn't support parsing CREATE FUNCTION
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]); let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
assert_eq!( assert_eq!(
unsupported_dialects.parse_sql_statements(sql).unwrap_err(), unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

View file

@ -187,6 +187,92 @@ fn parse_mssql_create_procedure() {
let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END"); let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END");
} }
#[test]
fn parse_create_function() {
let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1; END";
assert_eq!(
ms().verified_stmt(return_expression_function),
sqlparser::ast::Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false,
temporary: false,
if_not_exists: false,
name: ObjectName::from(vec![Ident::new("some_scalar_udf")]),
args: Some(vec![
OperateFunctionArg {
mode: None,
name: Some(Ident::new("@foo")),
data_type: DataType::Int(None),
default_expr: None,
},
OperateFunctionArg {
mode: None,
name: Some(Ident::new("@bar")),
data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
length: 256,
unit: None
})),
default_expr: None,
},
]),
return_type: Some(DataType::Int(None)),
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken::empty(),
statements: vec![Statement::Return(ReturnStatement {
value: Some(ReturnStatementValue::Expr(Expr::Value(
(number("1")).with_empty_span()
))),
}),],
end_token: AttachedToken::empty(),
})),
behavior: None,
called_on_null: None,
parallel: None,
using: None,
language: None,
determinism_specifier: None,
options: None,
remote_connection: None,
}),
);
let multi_statement_function = "\
CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
RETURNS INT \
AS \
BEGIN \
SET @foo = @foo + 1; \
RETURN @foo; \
END\
";
let _ = ms().verified_stmt(multi_statement_function);
let create_function_with_conditional = "\
CREATE FUNCTION some_scalar_udf() \
RETURNS INT \
AS \
BEGIN \
IF 1 = 2 \
BEGIN \
RETURN 1; \
END; \
RETURN 0; \
END\
";
let _ = ms().verified_stmt(create_function_with_conditional);
let create_or_alter_function = "\
CREATE OR ALTER FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
RETURNS INT \
AS \
BEGIN \
SET @foo = @foo + 1; \
RETURN @foo; \
END\
";
let _ = ms().verified_stmt(create_or_alter_function);
}
#[test] #[test]
fn parse_mssql_apply_join() { fn parse_mssql_apply_join() {
let _ = ms_and_generic().verified_only_select( let _ = ms_and_generic().verified_only_select(

View file

@ -4104,6 +4104,7 @@ fn parse_create_function() {
assert_eq!( assert_eq!(
pg_and_generic().verified_stmt(sql), pg_and_generic().verified_stmt(sql),
Statement::CreateFunction(CreateFunction { Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false, or_replace: false,
temporary: false, temporary: false,
name: ObjectName::from(vec![Ident::new("add")]), name: ObjectName::from(vec![Ident::new("add")]),
@ -5485,6 +5486,7 @@ fn parse_trigger_related_functions() {
assert_eq!( assert_eq!(
create_function, create_function,
Statement::CreateFunction(CreateFunction { Statement::CreateFunction(CreateFunction {
or_alter: false,
or_replace: false, or_replace: false,
temporary: false, temporary: false,
if_not_exists: false, if_not_exists: false,