mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-07 17:04:59 +00:00
Add support for aggregate expressions with filters (#585)
This commit is contained in:
parent
0bb49cea99
commit
303f80f168
6 changed files with 67 additions and 6 deletions
|
@ -375,6 +375,8 @@ pub enum Expr {
|
|||
MapAccess { column: Box<Expr>, keys: Vec<Expr> },
|
||||
/// Scalar function call e.g. `LEFT(foo, 5)`
|
||||
Function(Function),
|
||||
/// Aggregate function with filter
|
||||
AggregateExpressionWithFilter { expr: Box<Expr>, filter: Box<Expr> },
|
||||
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
|
||||
///
|
||||
/// Note we only recognize a complete single expression as `<condition>`,
|
||||
|
@ -571,6 +573,9 @@ impl fmt::Display for Expr {
|
|||
write!(f, " '{}'", &value::escape_single_quote_string(value))
|
||||
}
|
||||
Expr::Function(fun) => write!(f, "{}", fun),
|
||||
Expr::AggregateExpressionWithFilter { expr, filter } => {
|
||||
write!(f, "{} FILTER (WHERE {})", expr, filter)
|
||||
}
|
||||
Expr::Case {
|
||||
operand,
|
||||
conditions,
|
||||
|
|
|
@ -36,4 +36,8 @@ impl Dialect for HiveDialect {
|
|||
|| ch == '{'
|
||||
|| ch == '}'
|
||||
}
|
||||
|
||||
fn supports_filter_during_aggregation(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,6 +67,10 @@ pub trait Dialect: Debug + Any {
|
|||
fn is_identifier_start(&self, ch: char) -> bool;
|
||||
/// Determine if a character is a valid unquoted identifier character
|
||||
fn is_identifier_part(&self, ch: char) -> bool;
|
||||
/// Does the dialect support `FILTER (WHERE expr)` for aggregate queries?
|
||||
fn supports_filter_during_aggregation(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// Dialect-specific prefix parser override
|
||||
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
|
|
|
@ -42,6 +42,10 @@ impl Dialect for PostgreSqlDialect {
|
|||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_filter_during_aggregation(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
|
||||
|
|
|
@ -4542,12 +4542,31 @@ impl<'a> Parser<'a> {
|
|||
/// Parse a comma-delimited list of projections after SELECT
|
||||
pub fn parse_select_item(&mut self) -> Result<SelectItem, ParserError> {
|
||||
match self.parse_wildcard_expr()? {
|
||||
WildcardExpr::Expr(expr) => self
|
||||
.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
|
||||
.map(|alias| match alias {
|
||||
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
|
||||
None => SelectItem::UnnamedExpr(expr),
|
||||
}),
|
||||
WildcardExpr::Expr(expr) => {
|
||||
let expr: Expr = if self.dialect.supports_filter_during_aggregation()
|
||||
&& self.parse_keyword(Keyword::FILTER)
|
||||
{
|
||||
let i = self.index - 1;
|
||||
if self.consume_token(&Token::LParen) && self.parse_keyword(Keyword::WHERE) {
|
||||
let filter = self.parse_expr()?;
|
||||
self.expect_token(&Token::RParen)?;
|
||||
Expr::AggregateExpressionWithFilter {
|
||||
expr: Box::new(expr),
|
||||
filter: Box::new(filter),
|
||||
}
|
||||
} else {
|
||||
self.index = i;
|
||||
expr
|
||||
}
|
||||
} else {
|
||||
expr
|
||||
};
|
||||
self.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
|
||||
.map(|alias| match alias {
|
||||
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
|
||||
None => SelectItem::UnnamedExpr(expr),
|
||||
})
|
||||
}
|
||||
WildcardExpr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(prefix)),
|
||||
WildcardExpr::Wildcard => Ok(SelectItem::Wildcard),
|
||||
}
|
||||
|
|
|
@ -276,6 +276,31 @@ fn parse_create_function() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filtering_during_aggregation() {
|
||||
let rename = "SELECT \
|
||||
array_agg(name) FILTER (WHERE name IS NOT NULL), \
|
||||
array_agg(name) FILTER (WHERE name LIKE 'a%') \
|
||||
FROM region";
|
||||
println!("{}", hive().verified_stmt(rename));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filtering_during_aggregation_aliased() {
|
||||
let rename = "SELECT \
|
||||
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
|
||||
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
|
||||
FROM region";
|
||||
println!("{}", hive().verified_stmt(rename));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_as_alias() {
|
||||
let sql = "SELECT name filter FROM region";
|
||||
let expected = "SELECT name AS filter FROM region";
|
||||
println!("{}", hive().one_statement_parses_to(sql, expected));
|
||||
}
|
||||
|
||||
fn hive() -> TestedDialects {
|
||||
TestedDialects {
|
||||
dialects: vec![Box::new(HiveDialect {})],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue