mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-07 17:04:59 +00:00
Add CREATE FUNCTION
support for SQL Server (#1808)
This commit is contained in:
parent
945f8e0534
commit
2eb1e7bdd4
10 changed files with 313 additions and 50 deletions
|
@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy {
|
|||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
pub struct CreateFunction {
|
||||
/// True if this is a `CREATE OR ALTER FUNCTION` statement
|
||||
///
|
||||
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter)
|
||||
pub or_alter: bool,
|
||||
pub or_replace: bool,
|
||||
pub temporary: bool,
|
||||
pub if_not_exists: bool,
|
||||
|
@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}",
|
||||
"CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}",
|
||||
name = self.name,
|
||||
temp = if self.temporary { "TEMPORARY " } else { "" },
|
||||
or_alter = if self.or_alter { "OR ALTER " } else { "" },
|
||||
or_replace = if self.or_replace { "OR REPLACE " } else { "" },
|
||||
if_not_exists = if self.if_not_exists {
|
||||
"IF NOT EXISTS "
|
||||
|
@ -2272,6 +2277,9 @@ impl fmt::Display for CreateFunction {
|
|||
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
|
||||
write!(f, " AS {function_body}")?;
|
||||
}
|
||||
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
|
||||
write!(f, " AS {bes}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
100
src/ast/mod.rs
100
src/ast/mod.rs
|
@ -2293,18 +2293,14 @@ pub enum ConditionalStatements {
|
|||
/// SELECT 1; SELECT 2; SELECT 3; ...
|
||||
Sequence { statements: Vec<Statement> },
|
||||
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
|
||||
BeginEnd {
|
||||
begin_token: AttachedToken,
|
||||
statements: Vec<Statement>,
|
||||
end_token: AttachedToken,
|
||||
},
|
||||
BeginEnd(BeginEndStatements),
|
||||
}
|
||||
|
||||
impl ConditionalStatements {
|
||||
pub fn statements(&self) -> &Vec<Statement> {
|
||||
match self {
|
||||
ConditionalStatements::Sequence { statements } => statements,
|
||||
ConditionalStatements::BeginEnd { statements, .. } => statements,
|
||||
ConditionalStatements::BeginEnd(bes) => &bes.statements,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2318,15 +2314,44 @@ impl fmt::Display for ConditionalStatements {
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
ConditionalStatements::BeginEnd { statements, .. } => {
|
||||
write!(f, "BEGIN ")?;
|
||||
format_statement_list(f, statements)?;
|
||||
write!(f, " END")
|
||||
}
|
||||
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a list of statements enclosed within `BEGIN` and `END` keywords.
|
||||
/// Example:
|
||||
/// ```sql
|
||||
/// BEGIN
|
||||
/// SELECT 1;
|
||||
/// SELECT 2;
|
||||
/// END
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
pub struct BeginEndStatements {
|
||||
pub begin_token: AttachedToken,
|
||||
pub statements: Vec<Statement>,
|
||||
pub end_token: AttachedToken,
|
||||
}
|
||||
|
||||
impl fmt::Display for BeginEndStatements {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let BeginEndStatements {
|
||||
begin_token: AttachedToken(begin_token),
|
||||
statements,
|
||||
end_token: AttachedToken(end_token),
|
||||
} = self;
|
||||
|
||||
write!(f, "{begin_token} ")?;
|
||||
if !statements.is_empty() {
|
||||
format_statement_list(f, statements)?;
|
||||
}
|
||||
write!(f, " {end_token}")
|
||||
}
|
||||
}
|
||||
|
||||
/// A `RAISE` statement.
|
||||
///
|
||||
/// Examples:
|
||||
|
@ -3615,6 +3640,7 @@ pub enum Statement {
|
|||
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
|
||||
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
|
||||
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
|
||||
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
|
||||
CreateFunction(CreateFunction),
|
||||
/// CREATE TRIGGER
|
||||
///
|
||||
|
@ -4061,6 +4087,12 @@ pub enum Statement {
|
|||
///
|
||||
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
|
||||
Print(PrintStatement),
|
||||
/// ```sql
|
||||
/// RETURN [ expression ]
|
||||
/// ```
|
||||
///
|
||||
/// See [ReturnStatement]
|
||||
Return(ReturnStatement),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
|
@ -5753,6 +5785,7 @@ impl fmt::Display for Statement {
|
|||
Ok(())
|
||||
}
|
||||
Statement::Print(s) => write!(f, "{s}"),
|
||||
Statement::Return(r) => write!(f, "{r}"),
|
||||
Statement::List(command) => write!(f, "LIST {command}"),
|
||||
Statement::Remove(command) => write!(f, "REMOVE {command}"),
|
||||
}
|
||||
|
@ -8355,6 +8388,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
|
|||
///
|
||||
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
|
||||
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
|
||||
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
|
@ -8383,6 +8417,22 @@ pub enum CreateFunctionBody {
|
|||
///
|
||||
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
|
||||
AsAfterOptions(Expr),
|
||||
/// Function body with statements before the `RETURN` keyword.
|
||||
///
|
||||
/// Example:
|
||||
/// ```sql
|
||||
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
|
||||
/// RETURNS INT
|
||||
/// AS
|
||||
/// BEGIN
|
||||
/// DECLARE c INT;
|
||||
/// SET c = a + b;
|
||||
/// RETURN c;
|
||||
/// END
|
||||
/// ```
|
||||
///
|
||||
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
|
||||
AsBeginEnd(BeginEndStatements),
|
||||
/// Function body expression using the 'RETURN' keyword.
|
||||
///
|
||||
/// Example:
|
||||
|
@ -9231,6 +9281,34 @@ impl fmt::Display for PrintStatement {
|
|||
}
|
||||
}
|
||||
|
||||
/// Represents a `Return` statement.
|
||||
///
|
||||
/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql)
|
||||
/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
pub struct ReturnStatement {
|
||||
pub value: Option<ReturnStatementValue>,
|
||||
}
|
||||
|
||||
impl fmt::Display for ReturnStatement {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match &self.value {
|
||||
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
|
||||
None => write!(f, "RETURN"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Variants of a `RETURN` statement
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||
pub enum ReturnStatementValue {
|
||||
Expr(Expr),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -23,8 +23,8 @@ use crate::tokenizer::Span;
|
|||
use super::{
|
||||
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
|
||||
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
|
||||
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
|
||||
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
|
||||
BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption,
|
||||
ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
|
||||
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
|
||||
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
|
||||
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
|
||||
|
@ -520,6 +520,7 @@ impl Spanned for Statement {
|
|||
Statement::RenameTable { .. } => Span::empty(),
|
||||
Statement::RaisError { .. } => Span::empty(),
|
||||
Statement::Print { .. } => Span::empty(),
|
||||
Statement::Return { .. } => Span::empty(),
|
||||
Statement::List(..) | Statement::Remove(..) => Span::empty(),
|
||||
}
|
||||
}
|
||||
|
@ -778,11 +779,7 @@ impl Spanned for ConditionalStatements {
|
|||
ConditionalStatements::Sequence { statements } => {
|
||||
union_spans(statements.iter().map(|s| s.span()))
|
||||
}
|
||||
ConditionalStatements::BeginEnd {
|
||||
begin_token: AttachedToken(start),
|
||||
statements: _,
|
||||
end_token: AttachedToken(end),
|
||||
} => union_spans([start.span, end.span].into_iter()),
|
||||
ConditionalStatements::BeginEnd(bes) => bes.span(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2282,6 +2279,21 @@ impl Spanned for TableObject {
|
|||
}
|
||||
}
|
||||
|
||||
impl Spanned for BeginEndStatements {
|
||||
fn span(&self) -> Span {
|
||||
let BeginEndStatements {
|
||||
begin_token,
|
||||
statements,
|
||||
end_token,
|
||||
} = self;
|
||||
union_spans(
|
||||
core::iter::once(begin_token.0.span)
|
||||
.chain(statements.iter().map(|i| i.span()))
|
||||
.chain(core::iter::once(end_token.0.span)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
// under the License.
|
||||
|
||||
use crate::ast::helpers::attached_token::AttachedToken;
|
||||
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
|
||||
use crate::ast::{
|
||||
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
|
||||
};
|
||||
use crate::dialect::Dialect;
|
||||
use crate::keywords::{self, Keyword};
|
||||
use crate::parser::{Parser, ParserError};
|
||||
|
@ -149,11 +151,11 @@ impl MsSqlDialect {
|
|||
start_token: AttachedToken(if_token),
|
||||
condition: Some(condition),
|
||||
then_token: None,
|
||||
conditional_statements: ConditionalStatements::BeginEnd {
|
||||
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
|
||||
begin_token: AttachedToken(begin_token),
|
||||
statements,
|
||||
end_token: AttachedToken(end_token),
|
||||
},
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
let stmt = parser.parse_statement()?;
|
||||
|
@ -167,8 +169,10 @@ impl MsSqlDialect {
|
|||
}
|
||||
};
|
||||
|
||||
let mut prior_statement_ended_with_semi_colon = false;
|
||||
while let Token::SemiColon = parser.peek_token_ref().token {
|
||||
parser.advance_token();
|
||||
prior_statement_ended_with_semi_colon = true;
|
||||
}
|
||||
|
||||
let mut else_block = None;
|
||||
|
@ -182,11 +186,11 @@ impl MsSqlDialect {
|
|||
start_token: AttachedToken(else_token),
|
||||
condition: None,
|
||||
then_token: None,
|
||||
conditional_statements: ConditionalStatements::BeginEnd {
|
||||
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
|
||||
begin_token: AttachedToken(begin_token),
|
||||
statements,
|
||||
end_token: AttachedToken(end_token),
|
||||
},
|
||||
}),
|
||||
});
|
||||
} else {
|
||||
let stmt = parser.parse_statement()?;
|
||||
|
@ -199,6 +203,8 @@ impl MsSqlDialect {
|
|||
},
|
||||
});
|
||||
}
|
||||
} else if prior_statement_ended_with_semi_colon {
|
||||
parser.prev_token();
|
||||
}
|
||||
|
||||
Ok(Statement::If(IfStatement {
|
||||
|
|
|
@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
|
|||
Keyword::GRANT => self.parse_grant(),
|
||||
Keyword::REVOKE => self.parse_revoke(),
|
||||
Keyword::START => self.parse_start_transaction(),
|
||||
// `BEGIN` is a nonstandard but common alias for the
|
||||
// standard `START TRANSACTION` statement. It is supported
|
||||
// by at least PostgreSQL and MySQL.
|
||||
Keyword::BEGIN => self.parse_begin(),
|
||||
// `END` is a nonstandard but common alias for the
|
||||
// standard `COMMIT TRANSACTION` statement. It is supported
|
||||
// by PostgreSQL.
|
||||
Keyword::END => self.parse_end(),
|
||||
Keyword::SAVEPOINT => self.parse_savepoint(),
|
||||
Keyword::RELEASE => self.parse_release(),
|
||||
|
@ -618,6 +612,7 @@ impl<'a> Parser<'a> {
|
|||
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
|
||||
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
|
||||
Keyword::PRINT => self.parse_print(),
|
||||
Keyword::RETURN => self.parse_return(),
|
||||
_ => self.expected("an SQL statement", next_token),
|
||||
},
|
||||
Token::LParen => {
|
||||
|
@ -4458,7 +4453,6 @@ impl<'a> Parser<'a> {
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
values.push(self.parse_statement()?);
|
||||
self.expect_token(&Token::SemiColon)?;
|
||||
}
|
||||
|
@ -4560,7 +4554,7 @@ impl<'a> Parser<'a> {
|
|||
} else if self.parse_keyword(Keyword::EXTERNAL) {
|
||||
self.parse_create_external_table(or_replace)
|
||||
} else if self.parse_keyword(Keyword::FUNCTION) {
|
||||
self.parse_create_function(or_replace, temporary)
|
||||
self.parse_create_function(or_alter, or_replace, temporary)
|
||||
} else if self.parse_keyword(Keyword::TRIGGER) {
|
||||
self.parse_create_trigger(or_replace, false)
|
||||
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
|
||||
|
@ -4869,6 +4863,7 @@ impl<'a> Parser<'a> {
|
|||
|
||||
pub fn parse_create_function(
|
||||
&mut self,
|
||||
or_alter: bool,
|
||||
or_replace: bool,
|
||||
temporary: bool,
|
||||
) -> Result<Statement, ParserError> {
|
||||
|
@ -4880,6 +4875,8 @@ impl<'a> Parser<'a> {
|
|||
self.parse_create_macro(or_replace, temporary)
|
||||
} else if dialect_of!(self is BigQueryDialect) {
|
||||
self.parse_bigquery_create_function(or_replace, temporary)
|
||||
} else if dialect_of!(self is MsSqlDialect) {
|
||||
self.parse_mssql_create_function(or_alter, or_replace, temporary)
|
||||
} else {
|
||||
self.prev_token();
|
||||
self.expected("an object type after CREATE", self.peek_token())
|
||||
|
@ -4994,6 +4991,7 @@ impl<'a> Parser<'a> {
|
|||
}
|
||||
|
||||
Ok(Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace,
|
||||
temporary,
|
||||
name,
|
||||
|
@ -5027,6 +5025,7 @@ impl<'a> Parser<'a> {
|
|||
let using = self.parse_optional_create_function_using()?;
|
||||
|
||||
Ok(Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace,
|
||||
temporary,
|
||||
name,
|
||||
|
@ -5054,22 +5053,7 @@ impl<'a> Parser<'a> {
|
|||
temporary: bool,
|
||||
) -> Result<Statement, ParserError> {
|
||||
let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]);
|
||||
let name = self.parse_object_name(false)?;
|
||||
|
||||
let parse_function_param =
|
||||
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
|
||||
let name = parser.parse_identifier()?;
|
||||
let data_type = parser.parse_data_type()?;
|
||||
Ok(OperateFunctionArg {
|
||||
mode: None,
|
||||
name: Some(name),
|
||||
data_type,
|
||||
default_expr: None,
|
||||
})
|
||||
};
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
|
||||
self.expect_token(&Token::RParen)?;
|
||||
let (name, args) = self.parse_create_function_name_and_params()?;
|
||||
|
||||
let return_type = if self.parse_keyword(Keyword::RETURNS) {
|
||||
Some(self.parse_data_type()?)
|
||||
|
@ -5116,6 +5100,7 @@ impl<'a> Parser<'a> {
|
|||
};
|
||||
|
||||
Ok(Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace,
|
||||
temporary,
|
||||
if_not_exists,
|
||||
|
@ -5134,6 +5119,73 @@ impl<'a> Parser<'a> {
|
|||
}))
|
||||
}
|
||||
|
||||
/// Parse `CREATE FUNCTION` for [MsSql]
|
||||
///
|
||||
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
|
||||
fn parse_mssql_create_function(
|
||||
&mut self,
|
||||
or_alter: bool,
|
||||
or_replace: bool,
|
||||
temporary: bool,
|
||||
) -> Result<Statement, ParserError> {
|
||||
let (name, args) = self.parse_create_function_name_and_params()?;
|
||||
|
||||
self.expect_keyword(Keyword::RETURNS)?;
|
||||
let return_type = Some(self.parse_data_type()?);
|
||||
|
||||
self.expect_keyword_is(Keyword::AS)?;
|
||||
|
||||
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
|
||||
let statements = self.parse_statement_list(&[Keyword::END])?;
|
||||
let end_token = self.expect_keyword(Keyword::END)?;
|
||||
|
||||
let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
|
||||
begin_token: AttachedToken(begin_token),
|
||||
statements,
|
||||
end_token: AttachedToken(end_token),
|
||||
}));
|
||||
|
||||
Ok(Statement::CreateFunction(CreateFunction {
|
||||
or_alter,
|
||||
or_replace,
|
||||
temporary,
|
||||
if_not_exists: false,
|
||||
name,
|
||||
args: Some(args),
|
||||
return_type,
|
||||
function_body,
|
||||
language: None,
|
||||
determinism_specifier: None,
|
||||
options: None,
|
||||
remote_connection: None,
|
||||
using: None,
|
||||
behavior: None,
|
||||
called_on_null: None,
|
||||
parallel: None,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_create_function_name_and_params(
|
||||
&mut self,
|
||||
) -> Result<(ObjectName, Vec<OperateFunctionArg>), ParserError> {
|
||||
let name = self.parse_object_name(false)?;
|
||||
let parse_function_param =
|
||||
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
|
||||
let name = parser.parse_identifier()?;
|
||||
let data_type = parser.parse_data_type()?;
|
||||
Ok(OperateFunctionArg {
|
||||
mode: None,
|
||||
name: Some(name),
|
||||
data_type,
|
||||
default_expr: None,
|
||||
})
|
||||
};
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
|
||||
self.expect_token(&Token::RParen)?;
|
||||
Ok((name, args))
|
||||
}
|
||||
|
||||
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
|
||||
let mode = if self.parse_keyword(Keyword::IN) {
|
||||
Some(ArgMode::In)
|
||||
|
@ -15161,6 +15213,16 @@ impl<'a> Parser<'a> {
|
|||
}))
|
||||
}
|
||||
|
||||
/// Parse [Statement::Return]
|
||||
fn parse_return(&mut self) -> Result<Statement, ParserError> {
|
||||
match self.maybe_parse(|p| p.parse_expr())? {
|
||||
Some(expr) => Ok(Statement::Return(ReturnStatement {
|
||||
value: Some(ReturnStatementValue::Expr(expr)),
|
||||
})),
|
||||
None => Ok(Statement::Return(ReturnStatement { value: None })),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the parser and return its underlying token buffer
|
||||
pub fn into_tokens(self) -> Vec<TokenWithSpan> {
|
||||
self.tokens
|
||||
|
|
|
@ -2134,6 +2134,7 @@ fn test_bigquery_create_function() {
|
|||
assert_eq!(
|
||||
stmt,
|
||||
Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace: true,
|
||||
temporary: true,
|
||||
if_not_exists: false,
|
||||
|
|
|
@ -15079,3 +15079,11 @@ fn parse_set_time_zone_alias() {
|
|||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_return() {
|
||||
let stmt = all_dialects().verified_stmt("RETURN");
|
||||
assert_eq!(stmt, Statement::Return(ReturnStatement { value: None }));
|
||||
|
||||
let _ = all_dialects().verified_stmt("RETURN 1");
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ use sqlparser::ast::{
|
|||
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
|
||||
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
|
||||
};
|
||||
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
|
||||
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
|
||||
use sqlparser::parser::ParserError;
|
||||
use sqlparser::test_utils::*;
|
||||
|
||||
|
@ -423,7 +423,7 @@ fn parse_create_function() {
|
|||
}
|
||||
|
||||
// Test error in dialect that doesn't support parsing CREATE FUNCTION
|
||||
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
|
||||
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
|
||||
|
||||
assert_eq!(
|
||||
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),
|
||||
|
|
|
@ -187,6 +187,92 @@ fn parse_mssql_create_procedure() {
|
|||
let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_create_function() {
|
||||
let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1; END";
|
||||
assert_eq!(
|
||||
ms().verified_stmt(return_expression_function),
|
||||
sqlparser::ast::Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace: false,
|
||||
temporary: false,
|
||||
if_not_exists: false,
|
||||
name: ObjectName::from(vec![Ident::new("some_scalar_udf")]),
|
||||
args: Some(vec![
|
||||
OperateFunctionArg {
|
||||
mode: None,
|
||||
name: Some(Ident::new("@foo")),
|
||||
data_type: DataType::Int(None),
|
||||
default_expr: None,
|
||||
},
|
||||
OperateFunctionArg {
|
||||
mode: None,
|
||||
name: Some(Ident::new("@bar")),
|
||||
data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
|
||||
length: 256,
|
||||
unit: None
|
||||
})),
|
||||
default_expr: None,
|
||||
},
|
||||
]),
|
||||
return_type: Some(DataType::Int(None)),
|
||||
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
|
||||
begin_token: AttachedToken::empty(),
|
||||
statements: vec![Statement::Return(ReturnStatement {
|
||||
value: Some(ReturnStatementValue::Expr(Expr::Value(
|
||||
(number("1")).with_empty_span()
|
||||
))),
|
||||
}),],
|
||||
end_token: AttachedToken::empty(),
|
||||
})),
|
||||
behavior: None,
|
||||
called_on_null: None,
|
||||
parallel: None,
|
||||
using: None,
|
||||
language: None,
|
||||
determinism_specifier: None,
|
||||
options: None,
|
||||
remote_connection: None,
|
||||
}),
|
||||
);
|
||||
|
||||
let multi_statement_function = "\
|
||||
CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
|
||||
RETURNS INT \
|
||||
AS \
|
||||
BEGIN \
|
||||
SET @foo = @foo + 1; \
|
||||
RETURN @foo; \
|
||||
END\
|
||||
";
|
||||
let _ = ms().verified_stmt(multi_statement_function);
|
||||
|
||||
let create_function_with_conditional = "\
|
||||
CREATE FUNCTION some_scalar_udf() \
|
||||
RETURNS INT \
|
||||
AS \
|
||||
BEGIN \
|
||||
IF 1 = 2 \
|
||||
BEGIN \
|
||||
RETURN 1; \
|
||||
END; \
|
||||
RETURN 0; \
|
||||
END\
|
||||
";
|
||||
let _ = ms().verified_stmt(create_function_with_conditional);
|
||||
|
||||
let create_or_alter_function = "\
|
||||
CREATE OR ALTER FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) \
|
||||
RETURNS INT \
|
||||
AS \
|
||||
BEGIN \
|
||||
SET @foo = @foo + 1; \
|
||||
RETURN @foo; \
|
||||
END\
|
||||
";
|
||||
let _ = ms().verified_stmt(create_or_alter_function);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mssql_apply_join() {
|
||||
let _ = ms_and_generic().verified_only_select(
|
||||
|
|
|
@ -4104,6 +4104,7 @@ fn parse_create_function() {
|
|||
assert_eq!(
|
||||
pg_and_generic().verified_stmt(sql),
|
||||
Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace: false,
|
||||
temporary: false,
|
||||
name: ObjectName::from(vec![Ident::new("add")]),
|
||||
|
@ -5485,6 +5486,7 @@ fn parse_trigger_related_functions() {
|
|||
assert_eq!(
|
||||
create_function,
|
||||
Statement::CreateFunction(CreateFunction {
|
||||
or_alter: false,
|
||||
or_replace: false,
|
||||
temporary: false,
|
||||
if_not_exists: false,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue