Databricks: support for lambda functions (#1257)

This commit is contained in:
Joey Hain 2024-05-06 10:37:48 -07:00 committed by GitHub
parent a86c58b1c9
commit d9d69a2192
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 210 additions and 5 deletions

View file

@ -793,6 +793,59 @@ pub enum Expr {
OuterJoin(Box<Expr>),
/// A reference to the prior level in a CONNECT BY clause.
Prior(Box<Expr>),
/// A lambda function.
///
/// Syntax:
/// ```plaintext
/// param -> expr | (param1, ...) -> expr
/// ```
///
/// See <https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html>.
Lambda(LambdaFunction),
}
/// A lambda function: a parameter list and a body expression separated by `->`.
///
/// Its `Display` implementation renders it as `<params> -> <body>`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LambdaFunction {
/// The parameters to the lambda function: either a single bare identifier
/// (`x -> ...`) or a parenthesized list of identifiers (`(x, y) -> ...`).
pub params: OneOrManyWithParens<Ident>,
/// The body of the lambda function, i.e. the expression to the right of `->`.
pub body: Box<Expr>,
}
/// Renders the lambda as `<params> -> <body>`.
impl fmt::Display for LambdaFunction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pull both fields out by reference so they can be captured inline below.
        let Self { params, body } = self;
        write!(f, "{params} -> {body}")
    }
}
/// Encapsulates the common pattern in SQL where either one unparenthesized item
/// such as an identifier or expression is permitted, or multiple of the same
/// item in a parenthesized list.
///
/// When `T` implements `Display`, `One` is written as the bare item and `Many`
/// is written as the items wrapped in a single pair of parentheses.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum OneOrManyWithParens<T> {
/// A single `T`, unparenthesized, e.g. `x`.
One(T),
/// One or more `T`s, parenthesized, e.g. `(x)` or `(x, y)`.
///
/// NOTE(review): the `Vec` itself does not enforce the "one or more" part;
/// an empty list is representable, so producers must uphold it.
Many(Vec<T>),
}
/// Writes `One` as the bare item and `Many` as a parenthesized list produced
/// by `display_comma_separated`.
impl<T: fmt::Display> fmt::Display for OneOrManyWithParens<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let items = match self {
            // A lone item is emitted as-is, without surrounding parentheses.
            Self::One(item) => return write!(f, "{item}"),
            Self::Many(items) => items,
        };
        write!(f, "({})", display_comma_separated(items))
    }
}
impl fmt::Display for CastFormat {
@ -1241,6 +1294,7 @@ impl fmt::Display for Expr {
write!(f, "{expr} (+)")
}
Expr::Prior(expr) => write!(f, "PRIOR {expr}"),
Expr::Lambda(lambda) => write!(f, "{lambda}"),
}
}
}

View file

@ -29,4 +29,8 @@ impl Dialect for DatabricksDialect {
fn supports_group_by_expr(&self) -> bool {
true
}
/// Databricks supports lambda functions such as `x -> x + 1`.
///
/// See <https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html>.
fn supports_lambda_functions(&self) -> bool {
true
}
}

View file

@ -209,6 +209,14 @@ pub trait Dialect: Debug + Any {
fn supports_dictionary_syntax(&self) -> bool {
false
}
/// Returns true if the dialect supports lambda functions, for example:
///
/// ```sql
/// SELECT transform(array(1, 2, 3), x -> x + 1); -- returns [2,3,4]
/// ```
///
/// Both the single-parameter form (`x -> expr`) and the parenthesized
/// multi-parameter form (`(x, y) -> expr`) are gated on this flag.
///
/// Defaults to `false`; a dialect opts in by overriding this method.
fn supports_lambda_functions(&self) -> bool {
false
}
/// Returns true if the dialect has a CONVERT function which accepts a type first
/// and an expression second, e.g. `CONVERT(varchar, 1)`
fn convert_type_before_value(&self) -> bool {

View file

@ -1018,7 +1018,19 @@ impl<'a> Parser<'a> {
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
Keyword::EXISTS => self.parse_exists_expr(false),
Keyword::EXISTS
// Databricks has a function named `exists`, so for that dialect only parse
// an EXISTS expression when a SELECT or WITH keyword follows (a subquery).
if !dialect_of!(self is DatabricksDialect)
|| matches!(
self.peek_nth_token(1).token,
Token::Word(Word {
keyword: Keyword::SELECT | Keyword::WITH,
..
})
) =>
{
self.parse_exists_expr(false)
}
Keyword::EXTRACT => self.parse_extract_expr(),
Keyword::CEIL => self.parse_ceil_floor_expr(true),
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
@ -1036,7 +1048,7 @@ impl<'a> Parser<'a> {
}
Keyword::ARRAY
if self.peek_token() == Token::LParen
&& !dialect_of!(self is ClickHouseDialect) =>
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
{
self.expect_token(&Token::LParen)?;
let query = self.parse_boxed_query()?;
@ -1124,6 +1136,13 @@ impl<'a> Parser<'a> {
value: self.parse_introduced_string_value()?,
})
}
Token::Arrow if self.dialect.supports_lambda_functions() => {
self.expect_token(&Token::Arrow)?;
return Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(w.to_ident()),
body: Box::new(self.parse_expr()?),
}));
}
_ => Ok(Expr::Identifier(w.to_ident())),
},
}, // End of Token::Word
@ -1182,6 +1201,8 @@ impl<'a> Parser<'a> {
if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) {
self.prev_token();
Expr::Subquery(self.parse_boxed_query()?)
} else if let Some(lambda) = self.try_parse_lambda() {
return Ok(lambda);
} else {
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
match exprs.len() {
@ -1231,6 +1252,22 @@ impl<'a> Parser<'a> {
}
}
/// Attempts to parse the tail of a parenthesized lambda function,
/// `param1, ...) -> expr`, producing an `Expr::Lambda` whose parameters are
/// stored as `OneOrManyWithParens::Many`.
///
/// Returns `None` when the dialect does not support lambda functions or when
/// the speculative parse does not succeed.
///
/// This function does not consume an opening `(` itself, so the caller is
/// expected to have consumed it before calling.
fn try_parse_lambda(&mut self) -> Option<Expr> {
    if self.dialect.supports_lambda_functions() {
        // NOTE(review): `maybe_parse` is presumed to rewind the parser when the
        // closure fails, so a non-lambda parenthesized expression can still be
        // parsed by the caller afterwards — confirm against its definition.
        self.maybe_parse(|parser| {
            let idents = parser.parse_comma_separated(|inner| inner.parse_identifier(false))?;
            parser.expect_token(&Token::RParen)?;
            parser.expect_token(&Token::Arrow)?;
            let body = Box::new(parser.parse_expr()?);
            let params = OneOrManyWithParens::Many(idents);
            Ok(Expr::Lambda(LambdaFunction { params, body }))
        })
    } else {
        None
    }
}
pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;