Support PARALLEL ... and for ..ON NULL INPUT ... to CREATE FUNCTION` (#1202)

This commit is contained in:
Daniel Imfeld 2024-04-06 07:03:00 -10:00 committed by GitHub
parent 14b33ac493
commit 2bf93a470c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 142 additions and 3 deletions

View file

@ -5683,6 +5683,46 @@ impl fmt::Display for FunctionBehavior {
}
}
/// These attributes describe the behavior of the function when called with a null argument.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionCalledOnNull {
CalledOnNullInput,
ReturnsNullOnNullInput,
Strict,
}
impl fmt::Display for FunctionCalledOnNull {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionCalledOnNull::CalledOnNullInput => write!(f, "CALLED ON NULL INPUT"),
FunctionCalledOnNull::ReturnsNullOnNullInput => write!(f, "RETURNS NULL ON NULL INPUT"),
FunctionCalledOnNull::Strict => write!(f, "STRICT"),
}
}
}
/// If it is safe for PostgreSQL to call the function from multiple threads at once
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionParallel {
Unsafe,
Restricted,
Safe,
}
impl fmt::Display for FunctionParallel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionParallel::Unsafe => write!(f, "PARALLEL UNSAFE"),
FunctionParallel::Restricted => write!(f, "PARALLEL RESTRICTED"),
FunctionParallel::Safe => write!(f, "PARALLEL SAFE"),
}
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@ -5703,7 +5743,7 @@ impl fmt::Display for FunctionDefinition {
/// Postgres specific feature.
///
/// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// See [Postgres docs](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// for more details
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -5713,6 +5753,10 @@ pub struct CreateFunctionBody {
pub language: Option<Ident>,
/// IMMUTABLE | STABLE | VOLATILE
pub behavior: Option<FunctionBehavior>,
/// CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT
pub called_on_null: Option<FunctionCalledOnNull>,
/// PARALLEL { UNSAFE | RESTRICTED | SAFE }
pub parallel: Option<FunctionParallel>,
/// AS 'definition'
///
/// Note that Hive's `AS class_name` is also parsed here.
@ -5731,6 +5775,12 @@ impl fmt::Display for CreateFunctionBody {
if let Some(behavior) = &self.behavior {
write!(f, " {behavior}")?;
}
if let Some(called_on_null) = &self.called_on_null {
write!(f, " {called_on_null}")?;
}
if let Some(parallel) = &self.parallel {
write!(f, " {parallel}")?;
}
if let Some(definition) = &self.as_ {
write!(f, " AS {definition}")?;
}

View file

@ -353,6 +353,7 @@ define_keywords!(
INITIALLY,
INNER,
INOUT,
INPUT,
INPUTFORMAT,
INSENSITIVE,
INSERT,
@ -498,6 +499,7 @@ define_keywords!(
OVERLAY,
OVERWRITE,
OWNED,
PARALLEL,
PARAMETER,
PARQUET,
PARTITION,
@ -570,6 +572,7 @@ define_keywords!(
RESPECT,
RESTART,
RESTRICT,
RESTRICTED,
RESULT,
RESULTSET,
RETAIN,
@ -589,6 +592,7 @@ define_keywords!(
ROW_NUMBER,
RULE,
RUN,
SAFE,
SAFE_CAST,
SAVEPOINT,
SCHEMA,
@ -704,6 +708,7 @@ define_keywords!(
UNLOGGED,
UNNEST,
UNPIVOT,
UNSAFE,
UNSIGNED,
UNTIL,
UPDATE,

View file

@ -3437,6 +3437,46 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::VOLATILE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Volatile);
} else if self.parse_keywords(&[
Keyword::CALLED,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::CalledOnNullInput);
} else if self.parse_keywords(&[
Keyword::RETURNS,
Keyword::NULL,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::ReturnsNullOnNullInput);
} else if self.parse_keyword(Keyword::STRICT) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::Strict);
} else if self.parse_keyword(Keyword::PARALLEL) {
ensure_not_set(&body.parallel, "PARALLEL { UNSAFE | RESTRICTED | SAFE }")?;
if self.parse_keyword(Keyword::UNSAFE) {
body.parallel = Some(FunctionParallel::Unsafe);
} else if self.parse_keyword(Keyword::RESTRICTED) {
body.parallel = Some(FunctionParallel::Restricted);
} else if self.parse_keyword(Keyword::SAFE) {
body.parallel = Some(FunctionParallel::Safe);
} else {
return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token());
}
} else if self.parse_keyword(Keyword::RETURN) {
ensure_not_set(&body.return_, "RETURN")?;
body.return_ = Some(self.parse_expr()?);

View file

@ -3280,7 +3280,7 @@ fn parse_similar_to() {
#[test]
fn parse_create_function() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'";
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
@ -3295,6 +3295,8 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::Strict),
parallel: Some(FunctionParallel::Safe),
as_: Some(FunctionDefinition::SingleQuotedDef(
"select $1 + $2;".into()
)),
@ -3303,7 +3305,7 @@ fn parse_create_function() {
}
);
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b";
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT PARALLEL RESTRICTED RETURN a + b";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
@ -3323,6 +3325,40 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::ReturnsNullOnNullInput),
parallel: Some(FunctionParallel::Restricted),
return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus,
right: Box::new(Expr::Identifier("b".into())),
}),
..Default::default()
},
}
);
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL STABLE CALLED ON NULL INPUT PARALLEL UNSAFE RETURN a + b";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
or_replace: true,
temporary: false,
name: ObjectName(vec![Ident::new("add")]),
args: Some(vec![
OperateFunctionArg::with_name("a", DataType::Integer(None)),
OperateFunctionArg {
mode: Some(ArgMode::In),
name: Some("b".into()),
data_type: DataType::Integer(None),
default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))),
}
]),
return_type: Some(DataType::Integer(None)),
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Stable),
called_on_null: Some(FunctionCalledOnNull::CalledOnNullInput),
parallel: Some(FunctionParallel::Unsafe),
return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus,
@ -3348,6 +3384,8 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
parallel: None,
return_: None,
as_: Some(FunctionDefinition::DoubleDollarDef(
" BEGIN RETURN i + 1; END; ".into()
@ -3358,6 +3396,12 @@ fn parse_create_function() {
);
}
#[test]
fn parse_incorrect_create_function_parallel() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
assert!(pg().parse_sql_statements(sql).is_err());
}
#[test]
fn parse_drop_function() {
let sql = "DROP FUNCTION IF EXISTS test_func";