mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-08-22 06:54:07 +00:00
Databricks: support for lambda functions (#1257)
This commit is contained in:
parent
a86c58b1c9
commit
d9d69a2192
6 changed files with 210 additions and 5 deletions
|
@ -793,6 +793,59 @@ pub enum Expr {
|
||||||
OuterJoin(Box<Expr>),
|
OuterJoin(Box<Expr>),
|
||||||
/// A reference to the prior level in a CONNECT BY clause.
|
/// A reference to the prior level in a CONNECT BY clause.
|
||||||
Prior(Box<Expr>),
|
Prior(Box<Expr>),
|
||||||
|
/// A lambda function.
|
||||||
|
///
|
||||||
|
/// Syntax:
|
||||||
|
/// ```plaintext
|
||||||
|
/// param -> expr | (param1, ...) -> expr
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// See <https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html>.
|
||||||
|
Lambda(LambdaFunction),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A lambda function.
|
||||||
|
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||||
|
pub struct LambdaFunction {
|
||||||
|
/// The parameters to the lambda function.
|
||||||
|
pub params: OneOrManyWithParens<Ident>,
|
||||||
|
/// The body of the lambda function.
|
||||||
|
pub body: Box<Expr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for LambdaFunction {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{} -> {}", self.params, self.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encapsulates the common pattern in SQL where either one unparenthesized item
|
||||||
|
/// such as an identifier or expression is permitted, or multiple of the same
|
||||||
|
/// item in a parenthesized list.
|
||||||
|
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
|
||||||
|
pub enum OneOrManyWithParens<T> {
|
||||||
|
/// A single `T`, unparenthesized.
|
||||||
|
One(T),
|
||||||
|
/// One or more `T`s, parenthesized.
|
||||||
|
Many(Vec<T>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> fmt::Display for OneOrManyWithParens<T>
|
||||||
|
where
|
||||||
|
T: fmt::Display,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
OneOrManyWithParens::One(value) => write!(f, "{value}"),
|
||||||
|
OneOrManyWithParens::Many(values) => {
|
||||||
|
write!(f, "({})", display_comma_separated(values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for CastFormat {
|
impl fmt::Display for CastFormat {
|
||||||
|
@ -1241,6 +1294,7 @@ impl fmt::Display for Expr {
|
||||||
write!(f, "{expr} (+)")
|
write!(f, "{expr} (+)")
|
||||||
}
|
}
|
||||||
Expr::Prior(expr) => write!(f, "PRIOR {expr}"),
|
Expr::Prior(expr) => write!(f, "PRIOR {expr}"),
|
||||||
|
Expr::Lambda(lambda) => write!(f, "{lambda}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,4 +29,8 @@ impl Dialect for DatabricksDialect {
|
||||||
fn supports_group_by_expr(&self) -> bool {
|
fn supports_group_by_expr(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Opts this dialect in to lambda function parsing
/// (e.g. `x -> x + 1` or `(p1, p2) -> ...`).
fn supports_lambda_functions(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -209,6 +209,14 @@ pub trait Dialect: Debug + Any {
|
||||||
fn supports_dictionary_syntax(&self) -> bool {
|
fn supports_dictionary_syntax(&self) -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
/// Returns true if the dialect supports lambda functions, for example:
|
||||||
|
///
|
||||||
|
/// ```sql
|
||||||
|
/// SELECT transform(array(1, 2, 3), x -> x + 1); -- returns [2,3,4]
|
||||||
|
/// ```
///
/// The default implementation returns `false`; a dialect that accepts
/// this syntax overrides the method to return `true`.
|
||||||
|
fn supports_lambda_functions(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
/// Returns true if the dialect has a CONVERT function which accepts a type first
|
/// Returns true if the dialect has a CONVERT function which accepts a type first
|
||||||
/// and an expression second, e.g. `CONVERT(varchar, 1)`
|
/// and an expression second, e.g. `CONVERT(varchar, 1)`
|
||||||
fn convert_type_before_value(&self) -> bool {
|
fn convert_type_before_value(&self) -> bool {
|
||||||
|
|
|
@ -1018,7 +1018,19 @@ impl<'a> Parser<'a> {
|
||||||
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
|
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
|
||||||
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
|
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
|
||||||
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
|
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
|
||||||
Keyword::EXISTS => self.parse_exists_expr(false),
|
Keyword::EXISTS
|
||||||
|
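// Databricks has a function named `exists`, so for that dialect only treat EXISTS as a subquery predicate when the token at lookahead position 1 is SELECT or WITH.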
|
||||||
|
if !dialect_of!(self is DatabricksDialect)
|
||||||
|
|| matches!(
|
||||||
|
self.peek_nth_token(1).token,
|
||||||
|
Token::Word(Word {
|
||||||
|
keyword: Keyword::SELECT | Keyword::WITH,
|
||||||
|
..
|
||||||
|
})
|
||||||
|
) =>
|
||||||
|
{
|
||||||
|
self.parse_exists_expr(false)
|
||||||
|
}
|
||||||
Keyword::EXTRACT => self.parse_extract_expr(),
|
Keyword::EXTRACT => self.parse_extract_expr(),
|
||||||
Keyword::CEIL => self.parse_ceil_floor_expr(true),
|
Keyword::CEIL => self.parse_ceil_floor_expr(true),
|
||||||
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
|
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
|
||||||
|
@ -1036,7 +1048,7 @@ impl<'a> Parser<'a> {
|
||||||
}
|
}
|
||||||
Keyword::ARRAY
|
Keyword::ARRAY
|
||||||
if self.peek_token() == Token::LParen
|
if self.peek_token() == Token::LParen
|
||||||
&& !dialect_of!(self is ClickHouseDialect) =>
|
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
|
||||||
{
|
{
|
||||||
self.expect_token(&Token::LParen)?;
|
self.expect_token(&Token::LParen)?;
|
||||||
let query = self.parse_boxed_query()?;
|
let query = self.parse_boxed_query()?;
|
||||||
|
@ -1124,6 +1136,13 @@ impl<'a> Parser<'a> {
|
||||||
value: self.parse_introduced_string_value()?,
|
value: self.parse_introduced_string_value()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Token::Arrow if self.dialect.supports_lambda_functions() => {
|
||||||
|
self.expect_token(&Token::Arrow)?;
|
||||||
|
return Ok(Expr::Lambda(LambdaFunction {
|
||||||
|
params: OneOrManyWithParens::One(w.to_ident()),
|
||||||
|
body: Box::new(self.parse_expr()?),
|
||||||
|
}));
|
||||||
|
}
|
||||||
_ => Ok(Expr::Identifier(w.to_ident())),
|
_ => Ok(Expr::Identifier(w.to_ident())),
|
||||||
},
|
},
|
||||||
}, // End of Token::Word
|
}, // End of Token::Word
|
||||||
|
@ -1182,6 +1201,8 @@ impl<'a> Parser<'a> {
|
||||||
if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) {
|
if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) {
|
||||||
self.prev_token();
|
self.prev_token();
|
||||||
Expr::Subquery(self.parse_boxed_query()?)
|
Expr::Subquery(self.parse_boxed_query()?)
|
||||||
|
} else if let Some(lambda) = self.try_parse_lambda() {
|
||||||
|
return Ok(lambda);
|
||||||
} else {
|
} else {
|
||||||
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
|
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
|
||||||
match exprs.len() {
|
match exprs.len() {
|
||||||
|
@ -1231,6 +1252,22 @@ impl<'a> Parser<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn try_parse_lambda(&mut self) -> Option<Expr> {
|
||||||
|
if !self.dialect.supports_lambda_functions() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
self.maybe_parse(|p| {
|
||||||
|
let params = p.parse_comma_separated(|p| p.parse_identifier(false))?;
|
||||||
|
p.expect_token(&Token::RParen)?;
|
||||||
|
p.expect_token(&Token::Arrow)?;
|
||||||
|
let expr = p.parse_expr()?;
|
||||||
|
Ok(Expr::Lambda(LambdaFunction {
|
||||||
|
params: OneOrManyWithParens::Many(params),
|
||||||
|
body: Box::new(expr),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
|
pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
|
||||||
self.expect_token(&Token::LParen)?;
|
self.expect_token(&Token::LParen)?;
|
||||||
|
|
||||||
|
|
|
@ -1380,7 +1380,11 @@ fn pg_and_generic() -> TestedDialects {
|
||||||
fn parse_json_ops_without_colon() {
|
fn parse_json_ops_without_colon() {
|
||||||
use self::BinaryOperator::*;
|
use self::BinaryOperator::*;
|
||||||
let binary_ops = [
|
let binary_ops = [
|
||||||
("->", Arrow, all_dialects()),
|
(
|
||||||
|
"->",
|
||||||
|
Arrow,
|
||||||
|
all_dialects_except(|d| d.supports_lambda_functions()),
|
||||||
|
),
|
||||||
("->>", LongArrow, all_dialects()),
|
("->>", LongArrow, all_dialects()),
|
||||||
("#>", HashArrow, pg_and_generic()),
|
("#>", HashArrow, pg_and_generic()),
|
||||||
("#>>", HashLongArrow, pg_and_generic()),
|
("#>>", HashLongArrow, pg_and_generic()),
|
||||||
|
@ -6174,7 +6178,8 @@ fn parse_exists_subquery() {
|
||||||
verified_stmt("SELECT * FROM t WHERE EXISTS (WITH u AS (SELECT 1) SELECT * FROM u)");
|
verified_stmt("SELECT * FROM t WHERE EXISTS (WITH u AS (SELECT 1) SELECT * FROM u)");
|
||||||
verified_stmt("SELECT EXISTS (SELECT 1)");
|
verified_stmt("SELECT EXISTS (SELECT 1)");
|
||||||
|
|
||||||
let res = parse_sql_statements("SELECT EXISTS (");
|
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
|
||||||
|
.parse_sql_statements("SELECT EXISTS (");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ParserError::ParserError(
|
ParserError::ParserError(
|
||||||
"Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string()
|
"Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string()
|
||||||
|
@ -6182,7 +6187,8 @@ fn parse_exists_subquery() {
|
||||||
res.unwrap_err(),
|
res.unwrap_err(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = parse_sql_statements("SELECT EXISTS (NULL)");
|
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
|
||||||
|
.parse_sql_statements("SELECT EXISTS (NULL)");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ParserError::ParserError(
|
ParserError::ParserError(
|
||||||
"Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string()
|
"Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use sqlparser::ast::*;
|
use sqlparser::ast::*;
|
||||||
use sqlparser::dialect::DatabricksDialect;
|
use sqlparser::dialect::DatabricksDialect;
|
||||||
|
use sqlparser::parser::ParserError;
|
||||||
use test_utils::*;
|
use test_utils::*;
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
|
@ -28,3 +29,98 @@ fn test_databricks_identifiers() {
|
||||||
SelectItem::UnnamedExpr(Expr::Value(Value::DoubleQuotedString("Ä".to_owned())))
|
SelectItem::UnnamedExpr(Expr::Value(Value::DoubleQuotedString("Ä".to_owned())))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_databricks_exists() {
    // In Databricks `exists` is a function; here it is called with an array
    // and a single-parameter lambda.
    let array_arg = call("array", ["1", "2", "3"].map(|n| Expr::Value(number(n))));
    let predicate = Expr::Lambda(LambdaFunction {
        params: OneOrManyWithParens::One(Ident::new("x")),
        body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
    });
    assert_eq!(
        databricks().verified_expr("exists(array(1, 2, 3), x -> x IS NULL)"),
        call("exists", [array_arg, predicate]),
    );

    // An unterminated call must fail to parse.
    let err = databricks()
        .parse_sql_statements("SELECT EXISTS (")
        .unwrap_err();
    assert_eq!(
        // TODO: improve this error message...
        ParserError::ParserError("Expected an expression:, found: EOF".to_string()),
        err,
    );
}
|
||||||
|
|
||||||
|
#[test]
fn test_databricks_lambdas() {
    // Builders for the AST nodes that repeat in the expected tree.
    let ident = |name: &str| Expr::Identifier(Ident::new(name));
    let reverse_of = |name: &str| call("reverse", [ident(name)]);

    #[rustfmt::skip]
    let sql = concat!(
        "SELECT array_sort(array('Hello', 'World'), ",
            "(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
                "WHEN reverse(p1) < reverse(p2) THEN -1 ",
                "ELSE 1 END)",
    );

    // First argument: array('Hello', 'World')
    let words = call(
        "array",
        ["Hello", "World"].map(|s| Expr::Value(Value::SingleQuotedString(s.to_owned()))),
    );

    // Second argument: the two-parameter comparator lambda.
    let comparator = Expr::Lambda(LambdaFunction {
        params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
        body: Box::new(Expr::Case {
            operand: None,
            conditions: vec![
                Expr::BinaryOp {
                    left: Box::new(ident("p1")),
                    op: BinaryOperator::Eq,
                    right: Box::new(ident("p2")),
                },
                Expr::BinaryOp {
                    left: Box::new(reverse_of("p1")),
                    op: BinaryOperator::Lt,
                    right: Box::new(reverse_of("p2")),
                },
            ],
            results: vec![
                Expr::Value(number("0")),
                Expr::UnaryOp {
                    op: UnaryOperator::Minus,
                    expr: Box::new(Expr::Value(number("1"))),
                },
            ],
            else_result: Some(Box::new(Expr::Value(number("1")))),
        }),
    });

    let select = databricks().verified_only_select(sql);
    pretty_assertions::assert_eq!(
        SelectItem::UnnamedExpr(call("array_sort", [words, comparator])),
        select.projection[0]
    );

    // Run further lambda forms (three parameters, single bare parameter)
    // through `verified_expr` as well.
    for expr_sql in [
        "map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
        "transform(array(1, 2, 3), x -> x + 1)",
    ] {
        databricks().verified_expr(expr_sql);
    }
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue