Parse ARRAY_AGG for Bigquery and Snowflake (#662)

This commit is contained in:
SuperBo 2022-11-12 03:25:07 +07:00 committed by GitHub
parent 0428ac742b
commit 87b4a168cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 156 additions and 4 deletions

View file

@ -416,6 +416,8 @@ pub enum Expr {
ArraySubquery(Box<Query>),
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
ListAgg(ListAgg),
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
ArrayAgg(ArrayAgg),
/// The `GROUPING SETS` expr.
GroupingSets(Vec<Vec<Expr>>),
/// The `CUBE` expr.
@ -655,6 +657,7 @@ impl fmt::Display for Expr {
Expr::Subquery(s) => write!(f, "({})", s),
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
Expr::GroupingSets(sets) => {
write!(f, "GROUPING SETS (")?;
let mut sep = "";
@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
}
}
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayAgg {
pub distinct: bool,
pub expr: Box<Expr>,
pub order_by: Option<Box<OrderByExpr>>,
pub limit: Option<Box<Expr>>,
pub within_group: bool, // order by is used inside a within group or not
}
impl fmt::Display for ArrayAgg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"ARRAY_AGG({}{}",
if self.distinct { "DISTINCT " } else { "" },
self.expr
)?;
if !self.within_group {
if let Some(order_by) = &self.order_by {
write!(f, " ORDER BY {}", order_by)?;
}
if let Some(limit) = &self.limit {
write!(f, " LIMIT {}", limit)?;
}
}
write!(f, ")")?;
if self.within_group {
if let Some(order_by) = &self.order_by {
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
}
}
Ok(())
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ObjectType {

View file

@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
fn supports_filter_during_aggregation(&self) -> bool {
false
}
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
fn supports_within_after_array_aggregation(&self) -> bool {
false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior

View file

@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
|| ch == '$'
|| ch == '_'
}
fn supports_within_after_array_aggregation(&self) -> bool {
true
}
}

View file

@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::LParen)?;
self.parse_array_subquery()
}
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
Keyword::NOT => self.parse_not(),
// Here `w` is a word, check if it's a part of a multi-part
// identifier, a function call, or a simple identifier:
@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
}))
}
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
let distinct = self.parse_keyword(Keyword::DISTINCT);
let expr = Box::new(self.parse_expr()?);
// ANSI SQL and BigQuery define ORDER BY inside function.
if !self.dialect.supports_within_after_array_aggregation() {
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
let order_by_expr = self.parse_order_by_expr()?;
Some(Box::new(order_by_expr))
} else {
None
};
let limit = if self.parse_keyword(Keyword::LIMIT) {
self.parse_limit()?.map(Box::new)
} else {
None
};
self.expect_token(&Token::RParen)?;
return Ok(Expr::ArrayAgg(ArrayAgg {
distinct,
expr,
order_by,
limit,
within_group: false,
}));
}
// Snowflake defines ORDERY BY in within group instead of inside the function like
// ANSI SQL.
self.expect_token(&Token::RParen)?;
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
self.expect_token(&Token::LParen)?;
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
let order_by_expr = self.parse_order_by_expr()?;
self.expect_token(&Token::RParen)?;
Some(Box::new(order_by_expr))
} else {
None
};
Ok(Expr::ArrayAgg(ArrayAgg {
distinct,
expr,
order_by: within_group,
limit: None,
within_group: true,
}))
}
// This function parses date/time fields for the EXTRACT function-like
// operator, interval qualifiers, and the ceil/floor operations.
// EXTRACT supports a wider set of date/time fields than interval qualifiers,

View file

@ -224,6 +224,17 @@ fn parse_similar_to() {
chk(true);
}
#[test]
fn parse_array_agg_func() {
for sql in [
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
] {
bigquery().verified_stmt(sql);
}
}
fn bigquery() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(BigQueryDialect {})],

View file

@ -1777,6 +1777,27 @@ fn parse_listagg() {
);
}
#[test]
fn parse_array_agg_func() {
let supported_dialects = TestedDialects {
dialects: vec![
Box::new(GenericDialect {}),
Box::new(PostgreSqlDialect {}),
Box::new(MsSqlDialect {}),
Box::new(AnsiDialect {}),
Box::new(HiveDialect {}),
],
};
for sql in [
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
] {
supported_dialects.verified_stmt(sql);
}
}
#[test]
fn parse_create_table() {
let sql = "CREATE TABLE uk_cities (\

View file

@ -281,8 +281,8 @@ fn parse_create_function() {
#[test]
fn filtering_during_aggregation() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL), \
array_agg(name) FILTER (WHERE name LIKE 'a%') \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
FROM region";
println!("{}", hive().verified_stmt(rename));
}
@ -290,8 +290,8 @@ fn filtering_during_aggregation() {
#[test]
fn filtering_during_aggregation_aliased() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

View file

@ -334,6 +334,25 @@ fn parse_similar_to() {
chk(true);
}
#[test]
fn test_array_agg_func() {
for sql in [
"SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
"SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
] {
snowflake().verified_stmt(sql);
}
let sql = "select array_agg(x order by x) as a from T";
let result = snowflake().parse_sql_statements(sql);
assert_eq!(
result,
Err(ParserError::ParserError(String::from(
"Expected ), found: order"
)))
)
}
fn snowflake() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(SnowflakeDialect {})],