Support PARALLEL ... and for ..ON NULL INPUT ... to CREATE FUNCTION` (#1202)

This commit is contained in:
Daniel Imfeld 2024-04-06 07:03:00 -10:00 committed by GitHub
parent 14b33ac493
commit 2bf93a470c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 142 additions and 3 deletions

View file

@ -5683,6 +5683,46 @@ impl fmt::Display for FunctionBehavior {
} }
} }
/// These attributes describe the behavior of the function when called with a null argument.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionCalledOnNull {
CalledOnNullInput,
ReturnsNullOnNullInput,
Strict,
}
impl fmt::Display for FunctionCalledOnNull {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionCalledOnNull::CalledOnNullInput => write!(f, "CALLED ON NULL INPUT"),
FunctionCalledOnNull::ReturnsNullOnNullInput => write!(f, "RETURNS NULL ON NULL INPUT"),
FunctionCalledOnNull::Strict => write!(f, "STRICT"),
}
}
}
/// If it is safe for PostgreSQL to call the function from multiple threads at once
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionParallel {
Unsafe,
Restricted,
Safe,
}
impl fmt::Display for FunctionParallel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionParallel::Unsafe => write!(f, "PARALLEL UNSAFE"),
FunctionParallel::Restricted => write!(f, "PARALLEL RESTRICTED"),
FunctionParallel::Safe => write!(f, "PARALLEL SAFE"),
}
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -5703,7 +5743,7 @@ impl fmt::Display for FunctionDefinition {
/// Postgres specific feature. /// Postgres specific feature.
/// ///
/// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html) /// See [Postgres docs](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// for more details /// for more details
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -5713,6 +5753,10 @@ pub struct CreateFunctionBody {
pub language: Option<Ident>, pub language: Option<Ident>,
/// IMMUTABLE | STABLE | VOLATILE /// IMMUTABLE | STABLE | VOLATILE
pub behavior: Option<FunctionBehavior>, pub behavior: Option<FunctionBehavior>,
/// CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT
pub called_on_null: Option<FunctionCalledOnNull>,
/// PARALLEL { UNSAFE | RESTRICTED | SAFE }
pub parallel: Option<FunctionParallel>,
/// AS 'definition' /// AS 'definition'
/// ///
/// Note that Hive's `AS class_name` is also parsed here. /// Note that Hive's `AS class_name` is also parsed here.
@ -5731,6 +5775,12 @@ impl fmt::Display for CreateFunctionBody {
if let Some(behavior) = &self.behavior { if let Some(behavior) = &self.behavior {
write!(f, " {behavior}")?; write!(f, " {behavior}")?;
} }
if let Some(called_on_null) = &self.called_on_null {
write!(f, " {called_on_null}")?;
}
if let Some(parallel) = &self.parallel {
write!(f, " {parallel}")?;
}
if let Some(definition) = &self.as_ { if let Some(definition) = &self.as_ {
write!(f, " AS {definition}")?; write!(f, " AS {definition}")?;
} }

View file

@ -353,6 +353,7 @@ define_keywords!(
INITIALLY, INITIALLY,
INNER, INNER,
INOUT, INOUT,
INPUT,
INPUTFORMAT, INPUTFORMAT,
INSENSITIVE, INSENSITIVE,
INSERT, INSERT,
@ -498,6 +499,7 @@ define_keywords!(
OVERLAY, OVERLAY,
OVERWRITE, OVERWRITE,
OWNED, OWNED,
PARALLEL,
PARAMETER, PARAMETER,
PARQUET, PARQUET,
PARTITION, PARTITION,
@ -570,6 +572,7 @@ define_keywords!(
RESPECT, RESPECT,
RESTART, RESTART,
RESTRICT, RESTRICT,
RESTRICTED,
RESULT, RESULT,
RESULTSET, RESULTSET,
RETAIN, RETAIN,
@ -589,6 +592,7 @@ define_keywords!(
ROW_NUMBER, ROW_NUMBER,
RULE, RULE,
RUN, RUN,
SAFE,
SAFE_CAST, SAFE_CAST,
SAVEPOINT, SAVEPOINT,
SCHEMA, SCHEMA,
@ -704,6 +708,7 @@ define_keywords!(
UNLOGGED, UNLOGGED,
UNNEST, UNNEST,
UNPIVOT, UNPIVOT,
UNSAFE,
UNSIGNED, UNSIGNED,
UNTIL, UNTIL,
UPDATE, UPDATE,

View file

@ -3437,6 +3437,46 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::VOLATILE) { } else if self.parse_keyword(Keyword::VOLATILE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Volatile); body.behavior = Some(FunctionBehavior::Volatile);
} else if self.parse_keywords(&[
Keyword::CALLED,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::CalledOnNullInput);
} else if self.parse_keywords(&[
Keyword::RETURNS,
Keyword::NULL,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::ReturnsNullOnNullInput);
} else if self.parse_keyword(Keyword::STRICT) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::Strict);
} else if self.parse_keyword(Keyword::PARALLEL) {
ensure_not_set(&body.parallel, "PARALLEL { UNSAFE | RESTRICTED | SAFE }")?;
if self.parse_keyword(Keyword::UNSAFE) {
body.parallel = Some(FunctionParallel::Unsafe);
} else if self.parse_keyword(Keyword::RESTRICTED) {
body.parallel = Some(FunctionParallel::Restricted);
} else if self.parse_keyword(Keyword::SAFE) {
body.parallel = Some(FunctionParallel::Safe);
} else {
return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token());
}
} else if self.parse_keyword(Keyword::RETURN) { } else if self.parse_keyword(Keyword::RETURN) {
ensure_not_set(&body.return_, "RETURN")?; ensure_not_set(&body.return_, "RETURN")?;
body.return_ = Some(self.parse_expr()?); body.return_ = Some(self.parse_expr()?);

View file

@ -3280,7 +3280,7 @@ fn parse_similar_to() {
#[test] #[test]
fn parse_create_function() { fn parse_create_function() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'"; let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
assert_eq!( assert_eq!(
pg_and_generic().verified_stmt(sql), pg_and_generic().verified_stmt(sql),
Statement::CreateFunction { Statement::CreateFunction {
@ -3295,6 +3295,8 @@ fn parse_create_function() {
params: CreateFunctionBody { params: CreateFunctionBody {
language: Some("SQL".into()), language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable), behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::Strict),
parallel: Some(FunctionParallel::Safe),
as_: Some(FunctionDefinition::SingleQuotedDef( as_: Some(FunctionDefinition::SingleQuotedDef(
"select $1 + $2;".into() "select $1 + $2;".into()
)), )),
@ -3303,7 +3305,7 @@ fn parse_create_function() {
} }
); );
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b"; let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT PARALLEL RESTRICTED RETURN a + b";
assert_eq!( assert_eq!(
pg_and_generic().verified_stmt(sql), pg_and_generic().verified_stmt(sql),
Statement::CreateFunction { Statement::CreateFunction {
@ -3323,6 +3325,40 @@ fn parse_create_function() {
params: CreateFunctionBody { params: CreateFunctionBody {
language: Some("SQL".into()), language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable), behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::ReturnsNullOnNullInput),
parallel: Some(FunctionParallel::Restricted),
return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus,
right: Box::new(Expr::Identifier("b".into())),
}),
..Default::default()
},
}
);
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL STABLE CALLED ON NULL INPUT PARALLEL UNSAFE RETURN a + b";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
or_replace: true,
temporary: false,
name: ObjectName(vec![Ident::new("add")]),
args: Some(vec![
OperateFunctionArg::with_name("a", DataType::Integer(None)),
OperateFunctionArg {
mode: Some(ArgMode::In),
name: Some("b".into()),
data_type: DataType::Integer(None),
default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))),
}
]),
return_type: Some(DataType::Integer(None)),
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Stable),
called_on_null: Some(FunctionCalledOnNull::CalledOnNullInput),
parallel: Some(FunctionParallel::Unsafe),
return_: Some(Expr::BinaryOp { return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())), left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus, op: BinaryOperator::Plus,
@ -3348,6 +3384,8 @@ fn parse_create_function() {
params: CreateFunctionBody { params: CreateFunctionBody {
language: Some("plpgsql".into()), language: Some("plpgsql".into()),
behavior: None, behavior: None,
called_on_null: None,
parallel: None,
return_: None, return_: None,
as_: Some(FunctionDefinition::DoubleDollarDef( as_: Some(FunctionDefinition::DoubleDollarDef(
" BEGIN RETURN i + 1; END; ".into() " BEGIN RETURN i + 1; END; ".into()
@ -3358,6 +3396,12 @@ fn parse_create_function() {
); );
} }
#[test]
fn parse_incorrect_create_function_parallel() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
assert!(pg().parse_sql_statements(sql).is_err());
}
#[test] #[test]
fn parse_drop_function() { fn parse_drop_function() {
let sql = "DROP FUNCTION IF EXISTS test_func"; let sql = "DROP FUNCTION IF EXISTS test_func";