mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-07-19 06:15:00 +00:00
Parse ARRAY_AGG for Bigquery and Snowflake (#662)
This commit is contained in:
parent
0428ac742b
commit
87b4a168cb
8 changed files with 156 additions and 4 deletions
|
@ -416,6 +416,8 @@ pub enum Expr {
|
|||
ArraySubquery(Box<Query>),
|
||||
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
|
||||
ListAgg(ListAgg),
|
||||
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
|
||||
ArrayAgg(ArrayAgg),
|
||||
/// The `GROUPING SETS` expr.
|
||||
GroupingSets(Vec<Vec<Expr>>),
|
||||
/// The `CUBE` expr.
|
||||
|
@ -655,6 +657,7 @@ impl fmt::Display for Expr {
|
|||
Expr::Subquery(s) => write!(f, "({})", s),
|
||||
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
|
||||
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
|
||||
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
|
||||
Expr::GroupingSets(sets) => {
|
||||
write!(f, "GROUPING SETS (")?;
|
||||
let mut sep = "";
|
||||
|
@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
|
|||
}
|
||||
}
|
||||
|
||||
/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
|
||||
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
|
||||
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub struct ArrayAgg {
|
||||
pub distinct: bool,
|
||||
pub expr: Box<Expr>,
|
||||
pub order_by: Option<Box<OrderByExpr>>,
|
||||
pub limit: Option<Box<Expr>>,
|
||||
pub within_group: bool, // order by is used inside a within group or not
|
||||
}
|
||||
|
||||
impl fmt::Display for ArrayAgg {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"ARRAY_AGG({}{}",
|
||||
if self.distinct { "DISTINCT " } else { "" },
|
||||
self.expr
|
||||
)?;
|
||||
if !self.within_group {
|
||||
if let Some(order_by) = &self.order_by {
|
||||
write!(f, " ORDER BY {}", order_by)?;
|
||||
}
|
||||
if let Some(limit) = &self.limit {
|
||||
write!(f, " LIMIT {}", limit)?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")?;
|
||||
if self.within_group {
|
||||
if let Some(order_by) = &self.order_by {
|
||||
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum ObjectType {
|
||||
|
|
|
@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
|
|||
fn supports_filter_during_aggregation(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
|
||||
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
|
||||
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
|
||||
fn supports_within_after_array_aggregation(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// Dialect-specific prefix parser override
|
||||
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
|
||||
// return None to fall back to the default behavior
|
||||
|
|
|
@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
|
|||
|| ch == '$'
|
||||
|| ch == '_'
|
||||
}
|
||||
|
||||
fn supports_within_after_array_aggregation(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
|
|||
self.expect_token(&Token::LParen)?;
|
||||
self.parse_array_subquery()
|
||||
}
|
||||
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
|
||||
Keyword::NOT => self.parse_not(),
|
||||
// Here `w` is a word, check if it's a part of a multi-part
|
||||
// identifier, a function call, or a simple identifier:
|
||||
|
@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
|
|||
}))
|
||||
}
|
||||
|
||||
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let distinct = self.parse_keyword(Keyword::DISTINCT);
|
||||
let expr = Box::new(self.parse_expr()?);
|
||||
// ANSI SQL and BigQuery define ORDER BY inside function.
|
||||
if !self.dialect.supports_within_after_array_aggregation() {
|
||||
let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
|
||||
let order_by_expr = self.parse_order_by_expr()?;
|
||||
Some(Box::new(order_by_expr))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let limit = if self.parse_keyword(Keyword::LIMIT) {
|
||||
self.parse_limit()?.map(Box::new)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.expect_token(&Token::RParen)?;
|
||||
return Ok(Expr::ArrayAgg(ArrayAgg {
|
||||
distinct,
|
||||
expr,
|
||||
order_by,
|
||||
limit,
|
||||
within_group: false,
|
||||
}));
|
||||
}
|
||||
// Snowflake defines ORDERY BY in within group instead of inside the function like
|
||||
// ANSI SQL.
|
||||
self.expect_token(&Token::RParen)?;
|
||||
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
|
||||
self.expect_token(&Token::LParen)?;
|
||||
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
|
||||
let order_by_expr = self.parse_order_by_expr()?;
|
||||
self.expect_token(&Token::RParen)?;
|
||||
Some(Box::new(order_by_expr))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Expr::ArrayAgg(ArrayAgg {
|
||||
distinct,
|
||||
expr,
|
||||
order_by: within_group,
|
||||
limit: None,
|
||||
within_group: true,
|
||||
}))
|
||||
}
|
||||
|
||||
// This function parses date/time fields for the EXTRACT function-like
|
||||
// operator, interval qualifiers, and the ceil/floor operations.
|
||||
// EXTRACT supports a wider set of date/time fields than interval qualifiers,
|
||||
|
|
|
@ -224,6 +224,17 @@ fn parse_similar_to() {
|
|||
chk(true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_array_agg_func() {
|
||||
for sql in [
|
||||
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
|
||||
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
|
||||
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
|
||||
] {
|
||||
bigquery().verified_stmt(sql);
|
||||
}
|
||||
}
|
||||
|
||||
fn bigquery() -> TestedDialects {
|
||||
TestedDialects {
|
||||
dialects: vec![Box::new(BigQueryDialect {})],
|
||||
|
|
|
@ -1777,6 +1777,27 @@ fn parse_listagg() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_array_agg_func() {
|
||||
let supported_dialects = TestedDialects {
|
||||
dialects: vec![
|
||||
Box::new(GenericDialect {}),
|
||||
Box::new(PostgreSqlDialect {}),
|
||||
Box::new(MsSqlDialect {}),
|
||||
Box::new(AnsiDialect {}),
|
||||
Box::new(HiveDialect {}),
|
||||
],
|
||||
};
|
||||
|
||||
for sql in [
|
||||
"SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
|
||||
"SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
|
||||
"SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
|
||||
] {
|
||||
supported_dialects.verified_stmt(sql);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_create_table() {
|
||||
let sql = "CREATE TABLE uk_cities (\
|
||||
|
|
|
@ -281,8 +281,8 @@ fn parse_create_function() {
|
|||
#[test]
|
||||
fn filtering_during_aggregation() {
|
||||
let rename = "SELECT \
|
||||
array_agg(name) FILTER (WHERE name IS NOT NULL), \
|
||||
array_agg(name) FILTER (WHERE name LIKE 'a%') \
|
||||
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
|
||||
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
|
||||
FROM region";
|
||||
println!("{}", hive().verified_stmt(rename));
|
||||
}
|
||||
|
@ -290,8 +290,8 @@ fn filtering_during_aggregation() {
|
|||
#[test]
|
||||
fn filtering_during_aggregation_aliased() {
|
||||
let rename = "SELECT \
|
||||
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
|
||||
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
|
||||
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
|
||||
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
|
||||
FROM region";
|
||||
println!("{}", hive().verified_stmt(rename));
|
||||
}
|
||||
|
|
|
@ -334,6 +334,25 @@ fn parse_similar_to() {
|
|||
chk(true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_array_agg_func() {
|
||||
for sql in [
|
||||
"SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
|
||||
"SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
|
||||
] {
|
||||
snowflake().verified_stmt(sql);
|
||||
}
|
||||
|
||||
let sql = "select array_agg(x order by x) as a from T";
|
||||
let result = snowflake().parse_sql_statements(sql);
|
||||
assert_eq!(
|
||||
result,
|
||||
Err(ParserError::ParserError(String::from(
|
||||
"Expected ), found: order"
|
||||
)))
|
||||
)
|
||||
}
|
||||
|
||||
fn snowflake() -> TestedDialects {
|
||||
TestedDialects {
|
||||
dialects: vec![Box::new(SnowflakeDialect {})],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue