diff --git a/src/ast/mod.rs b/src/ast/mod.rs index a378b58b..9df0b5de 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -5683,6 +5683,46 @@ impl fmt::Display for FunctionBehavior { } } +/// These attributes describe the behavior of the function when called with a null argument. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum FunctionCalledOnNull { + CalledOnNullInput, + ReturnsNullOnNullInput, + Strict, +} + +impl fmt::Display for FunctionCalledOnNull { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionCalledOnNull::CalledOnNullInput => write!(f, "CALLED ON NULL INPUT"), + FunctionCalledOnNull::ReturnsNullOnNullInput => write!(f, "RETURNS NULL ON NULL INPUT"), + FunctionCalledOnNull::Strict => write!(f, "STRICT"), + } + } +} + +/// If it is safe for PostgreSQL to call the function from multiple threads at once +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum FunctionParallel { + Unsafe, + Restricted, + Safe, +} + +impl fmt::Display for FunctionParallel { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionParallel::Unsafe => write!(f, "PARALLEL UNSAFE"), + FunctionParallel::Restricted => write!(f, "PARALLEL RESTRICTED"), + FunctionParallel::Safe => write!(f, "PARALLEL SAFE"), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -5703,7 +5743,7 @@ impl fmt::Display for FunctionDefinition { /// Postgres specific feature. /// -/// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html) +/// See [Postgres docs](https://www.postgresql.org/docs/15/sql-createfunction.html) /// for more details #[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -5713,6 +5753,10 @@ pub struct CreateFunctionBody { pub language: Option, /// IMMUTABLE | STABLE | VOLATILE pub behavior: Option, + /// CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT + pub called_on_null: Option, + /// PARALLEL { UNSAFE | RESTRICTED | SAFE } + pub parallel: Option, /// AS 'definition' /// /// Note that Hive's `AS class_name` is also parsed here. @@ -5731,6 +5775,12 @@ impl fmt::Display for CreateFunctionBody { if let Some(behavior) = &self.behavior { write!(f, " {behavior}")?; } + if let Some(called_on_null) = &self.called_on_null { + write!(f, " {called_on_null}")?; + } + if let Some(parallel) = &self.parallel { + write!(f, " {parallel}")?; + } if let Some(definition) = &self.as_ { write!(f, " AS {definition}")?; } diff --git a/src/keywords.rs b/src/keywords.rs index c94a6227..fa7d133e 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -353,6 +353,7 @@ define_keywords!( INITIALLY, INNER, INOUT, + INPUT, INPUTFORMAT, INSENSITIVE, INSERT, @@ -498,6 +499,7 @@ define_keywords!( OVERLAY, OVERWRITE, OWNED, + PARALLEL, PARAMETER, PARQUET, PARTITION, @@ -570,6 +572,7 @@ define_keywords!( RESPECT, RESTART, RESTRICT, + RESTRICTED, RESULT, RESULTSET, RETAIN, @@ -589,6 +592,7 @@ define_keywords!( ROW_NUMBER, RULE, RUN, + SAFE, SAFE_CAST, SAVEPOINT, SCHEMA, @@ -704,6 +708,7 @@ define_keywords!( UNLOGGED, UNNEST, UNPIVOT, + UNSAFE, UNSIGNED, UNTIL, UPDATE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a3d7a7cf..235c1f1d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3437,6 +3437,46 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::VOLATILE) { ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; body.behavior = Some(FunctionBehavior::Volatile); + } else if self.parse_keywords(&[ + Keyword::CALLED, + Keyword::ON, + Keyword::NULL, + Keyword::INPUT, + ]) { + ensure_not_set( + &body.called_on_null, + "CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT", + )?; + body.called_on_null = Some(FunctionCalledOnNull::CalledOnNullInput); + } else if self.parse_keywords(&[ + Keyword::RETURNS, + Keyword::NULL, + Keyword::ON, + Keyword::NULL, + Keyword::INPUT, + ]) { + ensure_not_set( + &body.called_on_null, + "CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT", + )?; + body.called_on_null = Some(FunctionCalledOnNull::ReturnsNullOnNullInput); + } else if self.parse_keyword(Keyword::STRICT) { + ensure_not_set( + &body.called_on_null, + "CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT", + )?; + body.called_on_null = Some(FunctionCalledOnNull::Strict); + } else if self.parse_keyword(Keyword::PARALLEL) { + ensure_not_set(&body.parallel, "PARALLEL { UNSAFE | RESTRICTED | SAFE }")?; + if self.parse_keyword(Keyword::UNSAFE) { + body.parallel = Some(FunctionParallel::Unsafe); + } else if self.parse_keyword(Keyword::RESTRICTED) { + body.parallel = Some(FunctionParallel::Restricted); + } else if self.parse_keyword(Keyword::SAFE) { + body.parallel = Some(FunctionParallel::Safe); + } else { + return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token()); + } } else if self.parse_keyword(Keyword::RETURN) { ensure_not_set(&body.return_, "RETURN")?; body.return_ = Some(self.parse_expr()?); diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 4a92cd45..8515956f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -3280,7 +3280,7 @@ fn parse_similar_to() { #[test] fn parse_create_function() { - let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'"; + let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'"; assert_eq!( pg_and_generic().verified_stmt(sql), Statement::CreateFunction { @@ -3295,6 +3295,8 @@ fn parse_create_function() { params: CreateFunctionBody { language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), + called_on_null: Some(FunctionCalledOnNull::Strict), + parallel: Some(FunctionParallel::Safe), as_: Some(FunctionDefinition::SingleQuotedDef( "select $1 + $2;".into() )), @@ -3303,7 +3305,7 @@ fn parse_create_function() { } ); - let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b"; + let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT PARALLEL RESTRICTED RETURN a + b"; assert_eq!( pg_and_generic().verified_stmt(sql), Statement::CreateFunction { @@ -3323,6 +3325,40 @@ fn parse_create_function() { params: CreateFunctionBody { language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), + called_on_null: Some(FunctionCalledOnNull::ReturnsNullOnNullInput), + parallel: Some(FunctionParallel::Restricted), + return_: Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier("a".into())), + op: BinaryOperator::Plus, + right: Box::new(Expr::Identifier("b".into())), + }), + ..Default::default() + }, + } + ); + + let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL STABLE CALLED ON NULL INPUT PARALLEL UNSAFE RETURN a + b"; + assert_eq!( + pg_and_generic().verified_stmt(sql), + Statement::CreateFunction { + or_replace: true, + temporary: false, + name: ObjectName(vec![Ident::new("add")]), + args: Some(vec![ + OperateFunctionArg::with_name("a", DataType::Integer(None)), + OperateFunctionArg { + mode: Some(ArgMode::In), + name: Some("b".into()), + data_type: DataType::Integer(None), + default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))), + } + ]), + return_type: Some(DataType::Integer(None)), + params: CreateFunctionBody { + language: Some("SQL".into()), + behavior: Some(FunctionBehavior::Stable), + called_on_null: Some(FunctionCalledOnNull::CalledOnNullInput), + parallel: Some(FunctionParallel::Unsafe), return_: Some(Expr::BinaryOp { left: Box::new(Expr::Identifier("a".into())), op: BinaryOperator::Plus, @@ -3348,6 +3384,8 @@ fn parse_create_function() { params: CreateFunctionBody { language: Some("plpgsql".into()), behavior: None, + called_on_null: None, + parallel: None, return_: None, as_: Some(FunctionDefinition::DoubleDollarDef( " BEGIN RETURN i + 1; END; ".into() @@ -3358,6 +3396,12 @@ fn parse_create_function() { ); } +#[test] +fn parse_incorrect_create_function_parallel() { + let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'"; + assert!(pg().parse_sql_statements(sql).is_err()); +} + #[test] fn parse_drop_function() { let sql = "DROP FUNCTION IF EXISTS test_func";