mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-07 17:04:59 +00:00
Support postgres CREATE FUNCTION
(#722)
* support basic pg CREATE FUNCTION
  Signed-off-by: Runji Wang <wangrunji0408@163.com>
* support function argument
  Signed-off-by: Runji Wang <wangrunji0408@163.com>
* fix display and use verify in test
  Signed-off-by: Runji Wang <wangrunji0408@163.com>
* support OR REPLACE
  Signed-off-by: Runji Wang <wangrunji0408@163.com>
* fix compile error in bigdecimal
  Signed-off-by: Runji Wang <wangrunji0408@163.com>
* unify all `CreateFunctionBody` to a structure
  Signed-off-by: Runji Wang <wangrunji0408@163.com>

Signed-off-by: Runji Wang <wangrunji0408@163.com>
This commit is contained in:
parent
f621142f89
commit
5b53df97c4
5 changed files with 331 additions and 29 deletions
150
src/ast/mod.rs
150
src/ast/mod.rs
|
@ -1405,11 +1405,15 @@ pub enum Statement {
|
|||
/// CREATE FUNCTION
|
||||
///
|
||||
/// Hive: https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction
|
||||
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
|
||||
CreateFunction {
|
||||
or_replace: bool,
|
||||
temporary: bool,
|
||||
name: ObjectName,
|
||||
class_name: String,
|
||||
using: Option<CreateFunctionUsing>,
|
||||
args: Option<Vec<CreateFunctionArg>>,
|
||||
return_type: Option<DataType>,
|
||||
/// Optional parameters.
|
||||
params: CreateFunctionBody,
|
||||
},
|
||||
/// `ASSERT <condition> [AS <message>]`
|
||||
Assert {
|
||||
|
@ -1866,19 +1870,26 @@ impl fmt::Display for Statement {
|
|||
Ok(())
|
||||
}
|
||||
Statement::CreateFunction {
|
||||
or_replace,
|
||||
temporary,
|
||||
name,
|
||||
class_name,
|
||||
using,
|
||||
args,
|
||||
return_type,
|
||||
params,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"CREATE {temp}FUNCTION {name} AS '{class_name}'",
|
||||
"CREATE {or_replace}{temp}FUNCTION {name}",
|
||||
temp = if *temporary { "TEMPORARY " } else { "" },
|
||||
or_replace = if *or_replace { "OR REPLACE " } else { "" },
|
||||
)?;
|
||||
if let Some(u) = using {
|
||||
write!(f, " {}", u)?;
|
||||
if let Some(args) = args {
|
||||
write!(f, "({})", display_comma_separated(args))?;
|
||||
}
|
||||
if let Some(return_type) = return_type {
|
||||
write!(f, " RETURNS {}", return_type)?;
|
||||
}
|
||||
write!(f, "{params}")?;
|
||||
Ok(())
|
||||
}
|
||||
Statement::CreateView {
|
||||
|
@ -3679,6 +3690,131 @@ impl fmt::Display for ContextModifier {
|
|||
}
|
||||
}
|
||||
|
||||
/// Function argument in CREATE FUNCTION.
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub struct CreateFunctionArg {
|
||||
pub mode: Option<ArgMode>,
|
||||
pub name: Option<Ident>,
|
||||
pub data_type: DataType,
|
||||
pub default_expr: Option<Expr>,
|
||||
}
|
||||
|
||||
impl CreateFunctionArg {
|
||||
/// Returns an unnamed argument.
|
||||
pub fn unnamed(data_type: DataType) -> Self {
|
||||
Self {
|
||||
mode: None,
|
||||
name: None,
|
||||
data_type,
|
||||
default_expr: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an argument with name.
|
||||
pub fn with_name(name: &str, data_type: DataType) -> Self {
|
||||
Self {
|
||||
mode: None,
|
||||
name: Some(name.into()),
|
||||
data_type,
|
||||
default_expr: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CreateFunctionArg {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if let Some(mode) = &self.mode {
|
||||
write!(f, "{} ", mode)?;
|
||||
}
|
||||
if let Some(name) = &self.name {
|
||||
write!(f, "{} ", name)?;
|
||||
}
|
||||
write!(f, "{}", self.data_type)?;
|
||||
if let Some(default_expr) = &self.default_expr {
|
||||
write!(f, " = {}", default_expr)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// The mode of an argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArgMode {
    /// `IN`
    In,
    /// `OUT`
    Out,
    /// `INOUT`
    InOut,
}
|
||||
|
||||
impl fmt::Display for ArgMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
ArgMode::In => write!(f, "IN"),
|
||||
ArgMode::Out => write!(f, "OUT"),
|
||||
ArgMode::InOut => write!(f, "INOUT"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// These attributes inform the query optimizer about the behavior of the function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FunctionBehavior {
    /// `IMMUTABLE`
    Immutable,
    /// `STABLE`
    Stable,
    /// `VOLATILE`
    Volatile,
}
|
||||
|
||||
impl fmt::Display for FunctionBehavior {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
FunctionBehavior::Immutable => write!(f, "IMMUTABLE"),
|
||||
FunctionBehavior::Stable => write!(f, "STABLE"),
|
||||
FunctionBehavior::Volatile => write!(f, "VOLATILE"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
|
||||
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub struct CreateFunctionBody {
|
||||
/// LANGUAGE lang_name
|
||||
pub language: Option<Ident>,
|
||||
/// IMMUTABLE | STABLE | VOLATILE
|
||||
pub behavior: Option<FunctionBehavior>,
|
||||
/// AS 'definition'
|
||||
///
|
||||
/// Note that Hive's `AS class_name` is also parsed here.
|
||||
pub as_: Option<String>,
|
||||
/// RETURN expression
|
||||
pub return_: Option<Expr>,
|
||||
/// USING ... (Hive only)
|
||||
pub using: Option<CreateFunctionUsing>,
|
||||
}
|
||||
|
||||
impl fmt::Display for CreateFunctionBody {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if let Some(language) = &self.language {
|
||||
write!(f, " LANGUAGE {language}")?;
|
||||
}
|
||||
if let Some(behavior) = &self.behavior {
|
||||
write!(f, " {behavior}")?;
|
||||
}
|
||||
if let Some(definition) = &self.as_ {
|
||||
write!(f, " AS '{definition}'")?;
|
||||
}
|
||||
if let Some(expr) = &self.return_ {
|
||||
write!(f, " RETURN {expr}")?;
|
||||
}
|
||||
if let Some(using) = &self.using {
|
||||
write!(f, " {using}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum CreateFunctionUsing {
|
||||
|
|
|
@ -284,6 +284,7 @@ define_keywords!(
|
|||
IF,
|
||||
IGNORE,
|
||||
ILIKE,
|
||||
IMMUTABLE,
|
||||
IN,
|
||||
INCREMENT,
|
||||
INDEX,
|
||||
|
@ -518,6 +519,7 @@ define_keywords!(
|
|||
SQLSTATE,
|
||||
SQLWARNING,
|
||||
SQRT,
|
||||
STABLE,
|
||||
START,
|
||||
STATIC,
|
||||
STATISTICS,
|
||||
|
@ -604,6 +606,7 @@ define_keywords!(
|
|||
VERSIONING,
|
||||
VIEW,
|
||||
VIRTUAL,
|
||||
VOLATILE,
|
||||
WEEK,
|
||||
WHEN,
|
||||
WHENEVER,
|
||||
|
|
130
src/parser.rs
130
src/parser.rs
|
@ -2026,9 +2026,11 @@ impl<'a> Parser<'a> {
|
|||
self.parse_create_view(or_replace)
|
||||
} else if self.parse_keyword(Keyword::EXTERNAL) {
|
||||
self.parse_create_external_table(or_replace)
|
||||
} else if self.parse_keyword(Keyword::FUNCTION) {
|
||||
self.parse_create_function(or_replace, temporary)
|
||||
} else if or_replace {
|
||||
self.expected(
|
||||
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW after CREATE OR REPLACE",
|
||||
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION after CREATE OR REPLACE",
|
||||
self.peek_token(),
|
||||
)
|
||||
} else if self.parse_keyword(Keyword::INDEX) {
|
||||
|
@ -2041,8 +2043,6 @@ impl<'a> Parser<'a> {
|
|||
self.parse_create_schema()
|
||||
} else if self.parse_keyword(Keyword::DATABASE) {
|
||||
self.parse_create_database()
|
||||
} else if dialect_of!(self is HiveDialect) && self.parse_keyword(Keyword::FUNCTION) {
|
||||
self.parse_create_function(temporary)
|
||||
} else if self.parse_keyword(Keyword::ROLE) {
|
||||
self.parse_create_role()
|
||||
} else if self.parse_keyword(Keyword::SEQUENCE) {
|
||||
|
@ -2253,20 +2253,126 @@ impl<'a> Parser<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_create_function(&mut self, temporary: bool) -> Result<Statement, ParserError> {
|
||||
let name = self.parse_object_name()?;
|
||||
self.expect_keyword(Keyword::AS)?;
|
||||
let class_name = self.parse_literal_string()?;
|
||||
let using = self.parse_optional_create_function_using()?;
|
||||
pub fn parse_create_function(
|
||||
&mut self,
|
||||
or_replace: bool,
|
||||
temporary: bool,
|
||||
) -> Result<Statement, ParserError> {
|
||||
if dialect_of!(self is HiveDialect) {
|
||||
let name = self.parse_object_name()?;
|
||||
self.expect_keyword(Keyword::AS)?;
|
||||
let class_name = self.parse_literal_string()?;
|
||||
let params = CreateFunctionBody {
|
||||
as_: Some(class_name),
|
||||
using: self.parse_optional_create_function_using()?,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(Statement::CreateFunction {
|
||||
temporary,
|
||||
Ok(Statement::CreateFunction {
|
||||
or_replace,
|
||||
temporary,
|
||||
name,
|
||||
args: None,
|
||||
return_type: None,
|
||||
params,
|
||||
})
|
||||
} else if dialect_of!(self is PostgreSqlDialect) {
|
||||
let name = self.parse_object_name()?;
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let args = self.parse_comma_separated(Parser::parse_create_function_arg)?;
|
||||
self.expect_token(&Token::RParen)?;
|
||||
|
||||
let return_type = if self.parse_keyword(Keyword::RETURNS) {
|
||||
Some(self.parse_data_type()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let params = self.parse_create_function_body()?;
|
||||
|
||||
Ok(Statement::CreateFunction {
|
||||
or_replace,
|
||||
temporary,
|
||||
name,
|
||||
args: Some(args),
|
||||
return_type,
|
||||
params,
|
||||
})
|
||||
} else {
|
||||
self.prev_token();
|
||||
self.expected("an object type after CREATE", self.peek_token())
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_create_function_arg(&mut self) -> Result<CreateFunctionArg, ParserError> {
|
||||
let mode = if self.parse_keyword(Keyword::IN) {
|
||||
Some(ArgMode::In)
|
||||
} else if self.parse_keyword(Keyword::OUT) {
|
||||
Some(ArgMode::Out)
|
||||
} else if self.parse_keyword(Keyword::INOUT) {
|
||||
Some(ArgMode::InOut)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// parse: [ argname ] argtype
|
||||
let mut name = None;
|
||||
let mut data_type = self.parse_data_type()?;
|
||||
if let DataType::Custom(n, _) = &data_type {
|
||||
// the first token is actually a name
|
||||
name = Some(n.0[0].clone());
|
||||
data_type = self.parse_data_type()?;
|
||||
}
|
||||
|
||||
let default_expr = if self.parse_keyword(Keyword::DEFAULT) || self.consume_token(&Token::Eq)
|
||||
{
|
||||
Some(self.parse_expr()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(CreateFunctionArg {
|
||||
mode,
|
||||
name,
|
||||
class_name,
|
||||
using,
|
||||
data_type,
|
||||
default_expr,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_create_function_body(&mut self) -> Result<CreateFunctionBody, ParserError> {
|
||||
let mut body = CreateFunctionBody::default();
|
||||
loop {
|
||||
fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
|
||||
if field.is_some() {
|
||||
return Err(ParserError::ParserError(format!(
|
||||
"{name} specified more than once",
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
if self.parse_keyword(Keyword::AS) {
|
||||
ensure_not_set(&body.as_, "AS")?;
|
||||
body.as_ = Some(self.parse_literal_string()?);
|
||||
} else if self.parse_keyword(Keyword::LANGUAGE) {
|
||||
ensure_not_set(&body.language, "LANGUAGE")?;
|
||||
body.language = Some(self.parse_identifier()?);
|
||||
} else if self.parse_keyword(Keyword::IMMUTABLE) {
|
||||
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
|
||||
body.behavior = Some(FunctionBehavior::Immutable);
|
||||
} else if self.parse_keyword(Keyword::STABLE) {
|
||||
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
|
||||
body.behavior = Some(FunctionBehavior::Stable);
|
||||
} else if self.parse_keyword(Keyword::VOLATILE) {
|
||||
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
|
||||
body.behavior = Some(FunctionBehavior::Volatile);
|
||||
} else if self.parse_keyword(Keyword::RETURN) {
|
||||
ensure_not_set(&body.return_, "RETURN")?;
|
||||
body.return_ = Some(self.parse_expr()?);
|
||||
} else {
|
||||
return Ok(body);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_create_external_table(
|
||||
&mut self,
|
||||
or_replace: bool,
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
//! is also tested (on the inputs it can handle).
|
||||
|
||||
use sqlparser::ast::{
|
||||
CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem, Statement, TableFactor,
|
||||
UnaryOperator, Value,
|
||||
CreateFunctionBody, CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem,
|
||||
Statement, TableFactor, UnaryOperator, Value,
|
||||
};
|
||||
use sqlparser::dialect::{GenericDialect, HiveDialect};
|
||||
use sqlparser::parser::ParserError;
|
||||
|
@ -244,17 +244,20 @@ fn parse_create_function() {
|
|||
Statement::CreateFunction {
|
||||
temporary,
|
||||
name,
|
||||
class_name,
|
||||
using,
|
||||
params,
|
||||
..
|
||||
} => {
|
||||
assert!(temporary);
|
||||
assert_eq!("mydb.myfunc", name.to_string());
|
||||
assert_eq!("org.random.class.Name", class_name);
|
||||
assert_eq!(name.to_string(), "mydb.myfunc");
|
||||
assert_eq!(
|
||||
using,
|
||||
Some(CreateFunctionUsing::Jar(
|
||||
"hdfs://somewhere.com:8020/very/far".to_string()
|
||||
))
|
||||
params,
|
||||
CreateFunctionBody {
|
||||
as_: Some("org.random.class.Name".to_string()),
|
||||
using: Some(CreateFunctionUsing::Jar(
|
||||
"hdfs://somewhere.com:8020/very/far".to_string()
|
||||
)),
|
||||
..Default::default()
|
||||
}
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
|
|
@ -2234,3 +2234,57 @@ fn parse_similar_to() {
|
|||
chk(false);
|
||||
chk(true);
|
||||
}
|
||||
|
||||
#[test]
fn parse_create_function() {
    // Both statements use INTEGER everywhere and share the same clauses.
    let integer = || DataType::Integer(None);

    // Unnamed arguments with a string body.
    let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'";
    let expected = Statement::CreateFunction {
        or_replace: false,
        temporary: false,
        name: ObjectName(vec![Ident::new("add")]),
        args: Some(vec![
            CreateFunctionArg::unnamed(integer()),
            CreateFunctionArg::unnamed(integer()),
        ]),
        return_type: Some(integer()),
        params: CreateFunctionBody {
            language: Some("SQL".into()),
            behavior: Some(FunctionBehavior::Immutable),
            as_: Some("select $1 + $2;".into()),
            ..Default::default()
        },
    };
    assert_eq!(pg().verified_stmt(sql), expected);

    // OR REPLACE, named arguments, an argument mode, a default and a RETURN body.
    let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b";
    let arg_b = CreateFunctionArg {
        mode: Some(ArgMode::In),
        name: Some("b".into()),
        data_type: integer(),
        default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))),
    };
    let sum = Expr::BinaryOp {
        left: Box::new(Expr::Identifier("a".into())),
        op: BinaryOperator::Plus,
        right: Box::new(Expr::Identifier("b".into())),
    };
    let expected = Statement::CreateFunction {
        or_replace: true,
        temporary: false,
        name: ObjectName(vec![Ident::new("add")]),
        args: Some(vec![CreateFunctionArg::with_name("a", integer()), arg_b]),
        return_type: Some(integer()),
        params: CreateFunctionBody {
            language: Some("SQL".into()),
            behavior: Some(FunctionBehavior::Immutable),
            return_: Some(sum),
            ..Default::default()
        },
    };
    assert_eq!(pg().verified_stmt(sql), expected);
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue