support create function definition with $$ (#755)

* support create function definition using '2700775'

* fix warn
This commit is contained in:
zidaye 2022-12-14 06:15:33 +08:00 committed by GitHub
parent d420001c37
commit 6c545195e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 10 deletions

View file

@ -3777,6 +3777,23 @@ impl fmt::Display for FunctionBehavior {
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FunctionDefinition {
SingleQuotedDef(String),
DoubleDollarDef(String),
}
impl fmt::Display for FunctionDefinition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionDefinition::SingleQuotedDef(s) => write!(f, "'{s}'")?,
FunctionDefinition::DoubleDollarDef(s) => write!(f, "$${s}$$")?,
}
Ok(())
}
}
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -3788,7 +3805,7 @@ pub struct CreateFunctionBody {
/// AS 'definition'
///
/// Note that Hive's `AS class_name` is also parsed here.
pub as_: Option<String>,
pub as_: Option<FunctionDefinition>,
/// RETURN expression
pub return_: Option<Expr>,
/// USING ... (Hive only)
@ -3804,7 +3821,7 @@ impl fmt::Display for CreateFunctionBody {
write!(f, " {behavior}")?;
}
if let Some(definition) = &self.as_ {
write!(f, " AS '{definition}'")?;
write!(f, " AS {definition}")?;
}
if let Some(expr) = &self.return_ {
write!(f, " RETURN {expr}")?;

View file

@ -2310,7 +2310,7 @@ impl<'a> Parser<'a> {
if dialect_of!(self is HiveDialect) {
let name = self.parse_object_name()?;
self.expect_keyword(Keyword::AS)?;
let class_name = self.parse_literal_string()?;
let class_name = self.parse_function_definition()?;
let params = CreateFunctionBody {
as_: Some(class_name),
using: self.parse_optional_create_function_using()?,
@ -2400,7 +2400,7 @@ impl<'a> Parser<'a> {
}
if self.parse_keyword(Keyword::AS) {
ensure_not_set(&body.as_, "AS")?;
body.as_ = Some(self.parse_literal_string()?);
body.as_ = Some(self.parse_function_definition()?);
} else if self.parse_keyword(Keyword::LANGUAGE) {
ensure_not_set(&body.language, "LANGUAGE")?;
body.language = Some(self.parse_identifier()?);
@ -3883,6 +3883,33 @@ impl<'a> Parser<'a> {
}
}
pub fn parse_function_definition(&mut self) -> Result<FunctionDefinition, ParserError> {
let peek_token = self.peek_token();
match peek_token.token {
Token::DoubleDollarQuoting if dialect_of!(self is PostgreSqlDialect) => {
self.next_token();
let mut func_desc = String::new();
loop {
if let Some(next_token) = self.next_token_no_skip() {
match &next_token.token {
Token::DoubleDollarQuoting => break,
Token::EOF => {
return self.expected(
"literal string",
TokenWithLocation::wrap(Token::EOF),
);
}
token => func_desc.push_str(token.to_string().as_str()),
}
}
}
Ok(FunctionDefinition::DoubleDollarDef(func_desc))
}
_ => Ok(FunctionDefinition::SingleQuotedDef(
self.parse_literal_string()?,
)),
}
}
/// Parse a literal string
pub fn parse_literal_string(&mut self) -> Result<String, ParserError> {
let next_token = self.next_token();

View file

@ -145,6 +145,8 @@ pub enum Token {
PGCubeRoot,
/// `?` or `$` , a prepared statement arg placeholder
Placeholder(String),
/// `$$`, used for PostgreSQL create function definition
DoubleDollarQuoting,
/// ->, used as a operator to extract json field in PostgreSQL
Arrow,
/// ->>, used as a operator to extract json field as text in PostgreSQL
@ -215,6 +217,7 @@ impl fmt::Display for Token {
Token::LongArrow => write!(f, "->>"),
Token::HashArrow => write!(f, "#>"),
Token::HashLongArrow => write!(f, "#>>"),
Token::DoubleDollarQuoting => write!(f, "$$"),
}
}
}
@ -770,8 +773,14 @@ impl<'a> Tokenizer<'a> {
}
'$' => {
chars.next();
let s = peeking_take_while(chars, |ch| ch.is_alphanumeric() || ch == '_');
Ok(Some(Token::Placeholder(String::from("$") + &s)))
match chars.peek() {
Some('$') => self.consume_and_return(chars, Token::DoubleDollarQuoting),
_ => {
let s =
peeking_take_while(chars, |ch| ch.is_alphanumeric() || ch == '_');
Ok(Some(Token::Placeholder(String::from("$") + &s)))
}
}
}
//whitespace check (including unicode chars) should be last as it covers some of the chars above
ch if ch.is_whitespace() => {

View file

@ -16,8 +16,8 @@
//! is also tested (on the inputs it can handle).
use sqlparser::ast::{
CreateFunctionBody, CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem,
Statement, TableFactor, UnaryOperator, Value,
CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionDefinition, Ident, ObjectName,
SelectItem, Statement, TableFactor, UnaryOperator, Value,
};
use sqlparser::dialect::{GenericDialect, HiveDialect};
use sqlparser::parser::ParserError;
@ -252,7 +252,9 @@ fn parse_create_function() {
assert_eq!(
params,
CreateFunctionBody {
as_: Some("org.random.class.Name".to_string()),
as_: Some(FunctionDefinition::SingleQuotedDef(
"org.random.class.Name".to_string()
)),
using: Some(CreateFunctionUsing::Jar(
"hdfs://somewhere.com:8020/very/far".to_string()
)),

View file

@ -2257,7 +2257,9 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
as_: Some("select $1 + $2;".into()),
as_: Some(FunctionDefinition::SingleQuotedDef(
"select $1 + $2;".into()
)),
..Default::default()
},
}
@ -2292,4 +2294,28 @@ fn parse_create_function() {
},
}
);
let sql = r#"CREATE OR REPLACE FUNCTION increment(i INTEGER) RETURNS INTEGER LANGUAGE plpgsql AS $$ BEGIN RETURN i + 1; END; $$"#;
assert_eq!(
pg().verified_stmt(sql),
Statement::CreateFunction {
or_replace: true,
temporary: false,
name: ObjectName(vec![Ident::new("increment")]),
args: Some(vec![CreateFunctionArg::with_name(
"i",
DataType::Integer(None)
)]),
return_type: Some(DataType::Integer(None)),
params: CreateFunctionBody {
language: Some("plpgsql".into()),
behavior: None,
return_: None,
as_: Some(FunctionDefinition::DoubleDollarDef(
" BEGIN RETURN i + 1; END; ".into()
)),
using: None
},
}
);
}