Add support for aggregate expressions with filters (#585)

This commit is contained in:
Andy Grove 2022-09-08 13:08:45 -06:00 committed by GitHub
parent 0bb49cea99
commit 303f80f168
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 6 deletions

View file

@ -375,6 +375,8 @@ pub enum Expr {
MapAccess { column: Box<Expr>, keys: Vec<Expr> },
/// Scalar function call e.g. `LEFT(foo, 5)`
Function(Function),
/// Aggregate function with filter, e.g. `array_agg(name) FILTER (WHERE name IS NOT NULL)`
AggregateExpressionWithFilter { expr: Box<Expr>, filter: Box<Expr> },
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
///
/// Note we only recognize a complete single expression as `<condition>`,
@ -571,6 +573,9 @@ impl fmt::Display for Expr {
write!(f, " '{}'", &value::escape_single_quote_string(value))
}
Expr::Function(fun) => write!(f, "{}", fun),
Expr::AggregateExpressionWithFilter { expr, filter } => {
write!(f, "{} FILTER (WHERE {})", expr, filter)
}
Expr::Case {
operand,
conditions,

View file

@ -36,4 +36,8 @@ impl Dialect for HiveDialect {
|| ch == '{'
|| ch == '}'
}
/// Reports that this dialect accepts a `FILTER (WHERE expr)` clause after a
/// select-item expression, overriding the trait default.
fn supports_filter_during_aggregation(&self) -> bool {
true
}
}

View file

@ -67,6 +67,10 @@ pub trait Dialect: Debug + Any {
fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
/// Does the dialect support `FILTER (WHERE expr)` for aggregate queries?
///
/// For example: `array_agg(name) FILTER (WHERE name IS NOT NULL)`.
///
/// Returns `false` by default; a dialect that accepts the clause overrides
/// this method to return `true`.
fn supports_filter_during_aggregation(&self) -> bool {
false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior

View file

@ -42,6 +42,10 @@ impl Dialect for PostgreSqlDialect {
None
}
}
/// Reports that this dialect accepts a `FILTER (WHERE expr)` clause after a
/// select-item expression, overriding the trait default.
fn supports_filter_during_aggregation(&self) -> bool {
true
}
}
pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {

View file

@ -4542,12 +4542,31 @@ impl<'a> Parser<'a> {
/// Parse a comma-delimited list of projections after SELECT
pub fn parse_select_item(&mut self) -> Result<SelectItem, ParserError> {
match self.parse_wildcard_expr()? {
WildcardExpr::Expr(expr) => self
.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
.map(|alias| match alias {
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
None => SelectItem::UnnamedExpr(expr),
}),
WildcardExpr::Expr(expr) => {
let expr: Expr = if self.dialect.supports_filter_during_aggregation()
&& self.parse_keyword(Keyword::FILTER)
{
let i = self.index - 1;
if self.consume_token(&Token::LParen) && self.parse_keyword(Keyword::WHERE) {
let filter = self.parse_expr()?;
self.expect_token(&Token::RParen)?;
Expr::AggregateExpressionWithFilter {
expr: Box::new(expr),
filter: Box::new(filter),
}
} else {
self.index = i;
expr
}
} else {
expr
};
self.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
.map(|alias| match alias {
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
None => SelectItem::UnnamedExpr(expr),
})
}
WildcardExpr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(prefix)),
WildcardExpr::Wildcard => Ok(SelectItem::Wildcard),
}

View file

@ -276,6 +276,31 @@ fn parse_create_function() {
);
}
#[test]
fn filtering_during_aggregation() {
    // Two un-aliased select items, each followed by a `FILTER (WHERE ...)` clause.
    let sql = concat!(
        "SELECT ",
        "array_agg(name) FILTER (WHERE name IS NOT NULL), ",
        "array_agg(name) FILTER (WHERE name LIKE 'a%') ",
        "FROM region",
    );
    let stmt = hive().verified_stmt(sql);
    println!("{}", stmt);
}
#[test]
fn filtering_during_aggregation_aliased() {
    // Same shape as the un-aliased case, but every filtered aggregate carries
    // an explicit `AS` alias after the FILTER clause.
    let sql = concat!(
        "SELECT ",
        "array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, ",
        "array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 ",
        "FROM region",
    );
    let stmt = hive().verified_stmt(sql);
    println!("{}", stmt);
}
#[test]
fn filter_as_alias() {
    // `filter` not followed by `(WHERE ...)` must still be usable as an
    // implicit column alias.
    let input = "SELECT name filter FROM region";
    let canonical = "SELECT name AS filter FROM region";
    let stmt = hive().one_statement_parses_to(input, canonical);
    println!("{}", stmt);
}
fn hive() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(HiveDialect {})],