Mirror of https://github.com/apache/datafusion-sqlparser-rs.git, synced 2025-07-19 06:15:00 +00:00
Parse ARRAY_AGG for Bigquery and Snowflake (#662)
This commit is contained in: parent 0428ac742b, commit 87b4a168cb
8 changed files with 156 additions and 4 deletions
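A quick orientation before the hunks (a usage sketch, not part of the diff): the commit teaches the parser both placements of the aggregate's ORDER BY. Assuming the crate's public Parser::parse_sql entry point, exercising the two forms might look like the following; the SQL strings are taken from the tests added below.

use sqlparser::dialect::{BigQueryDialect, SnowflakeDialect};
use sqlparser::parser::Parser;

fn main() -> Result<(), sqlparser::parser::ParserError> {
    // BigQuery / ANSI style: ORDER BY (and LIMIT) inside the call.
    let stmts = Parser::parse_sql(
        &BigQueryDialect {},
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    )?;
    println!("{}", stmts[0]); // Display is expected to round-trip the input.

    // Snowflake style: ORDER BY in a trailing WITHIN GROUP clause.
    let stmts = Parser::parse_sql(
        &SnowflakeDialect {},
        "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
    )?;
    println!("{}", stmts[0]);

    Ok(())
}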
@@ -416,6 +416,8 @@ pub enum Expr {
     ArraySubquery(Box<Query>),
     /// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
     ListAgg(ListAgg),
+    /// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
+    ArrayAgg(ArrayAgg),
     /// The `GROUPING SETS` expr.
     GroupingSets(Vec<Vec<Expr>>),
     /// The `CUBE` expr.
@@ -655,6 +657,7 @@ impl fmt::Display for Expr {
             Expr::Subquery(s) => write!(f, "({})", s),
             Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
             Expr::ListAgg(listagg) => write!(f, "{}", listagg),
+            Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
             Expr::GroupingSets(sets) => {
                 write!(f, "GROUPING SETS (")?;
                 let mut sep = "";
@@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
     }
 }
 
+/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
+/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
+/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub struct ArrayAgg {
+    pub distinct: bool,
+    pub expr: Box<Expr>,
+    pub order_by: Option<Box<OrderByExpr>>,
+    pub limit: Option<Box<Expr>>,
+    pub within_group: bool, // order by is used inside a within group or not
+}
+
+impl fmt::Display for ArrayAgg {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "ARRAY_AGG({}{}",
+            if self.distinct { "DISTINCT " } else { "" },
+            self.expr
+        )?;
+        if !self.within_group {
+            if let Some(order_by) = &self.order_by {
+                write!(f, " ORDER BY {}", order_by)?;
+            }
+            if let Some(limit) = &self.limit {
+                write!(f, " LIMIT {}", limit)?;
+            }
+        }
+        write!(f, ")")?;
+        if self.within_group {
+            if let Some(order_by) = &self.order_by {
+                write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
+            }
+        }
+        Ok(())
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub enum ObjectType {
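A hedged sketch of how downstream code might consume the new AST node (describe_array_agg is purely illustrative; the field names are those of the ArrayAgg struct above):

use sqlparser::ast::Expr;

fn describe_array_agg(expr: &Expr) {
    if let Expr::ArrayAgg(agg) = expr {
        // `agg` is the ArrayAgg struct defined above.
        println!(
            "distinct={}, within_group={}, has_order_by={}, has_limit={}",
            agg.distinct,
            agg.within_group,
            agg.order_by.is_some(),
            agg.limit.is_some()
        );
    }
}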
@@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
     fn supports_filter_during_aggregation(&self) -> bool {
         false
     }
+    /// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
+    /// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
+    /// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
+    fn supports_within_after_array_aggregation(&self) -> bool {
+        false
+    }
     /// Dialect-specific prefix parser override
     fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
         // return None to fall back to the default behavior
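The default keeps existing dialects on the inside-the-call placement; a dialect opts into WITHIN GROUP by overriding the new method, as SnowflakeDialect does in the next hunk. A hypothetical third-party dialect (MyDialect is illustrative only; the two identifier methods are the trait's existing required methods) might do the same:

use sqlparser::dialect::Dialect;

#[derive(Debug)]
struct MyDialect;

impl Dialect for MyDialect {
    fn is_identifier_start(&self, ch: char) -> bool {
        ch.is_ascii_alphabetic() || ch == '_'
    }
    fn is_identifier_part(&self, ch: char) -> bool {
        ch.is_ascii_alphanumeric() || ch == '_'
    }
    // Opt into `ARRAY_AGG(...) WITHIN GROUP (ORDER BY ...)` parsing.
    fn supports_within_after_array_aggregation(&self) -> bool {
        true
    }
}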
@@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
             || ch == '$'
             || ch == '_'
     }
+
+    fn supports_within_after_array_aggregation(&self) -> bool {
+        true
+    }
 }
@@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
                 self.expect_token(&Token::LParen)?;
                 self.parse_array_subquery()
             }
+            Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
             Keyword::NOT => self.parse_not(),
             // Here `w` is a word, check if it's a part of a multi-part
             // identifier, a function call, or a simple identifier:
@@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
         }))
     }
 
+    pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
+        self.expect_token(&Token::LParen)?;
+        let distinct = self.parse_keyword(Keyword::DISTINCT);
+        let expr = Box::new(self.parse_expr()?);
+        // ANSI SQL and BigQuery define ORDER BY inside function.
+        if !self.dialect.supports_within_after_array_aggregation() {
+            let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
+                let order_by_expr = self.parse_order_by_expr()?;
+                Some(Box::new(order_by_expr))
+            } else {
+                None
+            };
+            let limit = if self.parse_keyword(Keyword::LIMIT) {
+                self.parse_limit()?.map(Box::new)
+            } else {
+                None
+            };
+            self.expect_token(&Token::RParen)?;
+            return Ok(Expr::ArrayAgg(ArrayAgg {
+                distinct,
+                expr,
+                order_by,
+                limit,
+                within_group: false,
+            }));
+        }
+        // Snowflake defines ORDER BY in within group instead of inside the function like
+        // ANSI SQL.
+        self.expect_token(&Token::RParen)?;
+        let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
+            self.expect_token(&Token::LParen)?;
+            self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
+            let order_by_expr = self.parse_order_by_expr()?;
+            self.expect_token(&Token::RParen)?;
+            Some(Box::new(order_by_expr))
+        } else {
+            None
+        };
+
+        Ok(Expr::ArrayAgg(ArrayAgg {
+            distinct,
+            expr,
+            order_by: within_group,
+            limit: None,
+            within_group: true,
+        }))
+    }
+
     // This function parses date/time fields for the EXTRACT function-like
     // operator, interval qualifiers, and the ceil/floor operations.
     // EXTRACT supports a wider set of date/time fields than interval qualifiers,
@@ -224,6 +224,17 @@ fn parse_similar_to() {
     chk(true);
 }
 
+#[test]
+fn parse_array_agg_func() {
+    for sql in [
+        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
+        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
+        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
+    ] {
+        bigquery().verified_stmt(sql);
+    }
+}
+
 fn bigquery() -> TestedDialects {
     TestedDialects {
         dialects: vec![Box::new(BigQueryDialect {})],
@@ -1777,6 +1777,27 @@ fn parse_listagg() {
     );
 }
 
+#[test]
+fn parse_array_agg_func() {
+    let supported_dialects = TestedDialects {
+        dialects: vec![
+            Box::new(GenericDialect {}),
+            Box::new(PostgreSqlDialect {}),
+            Box::new(MsSqlDialect {}),
+            Box::new(AnsiDialect {}),
+            Box::new(HiveDialect {}),
+        ],
+    };
+
+    for sql in [
+        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
+        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
+        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
+    ] {
+        supported_dialects.verified_stmt(sql);
+    }
+}
+
 #[test]
 fn parse_create_table() {
     let sql = "CREATE TABLE uk_cities (\
@@ -281,8 +281,8 @@ fn parse_create_function() {
 #[test]
 fn filtering_during_aggregation() {
     let rename = "SELECT \
-    array_agg(name) FILTER (WHERE name IS NOT NULL), \
-    array_agg(name) FILTER (WHERE name LIKE 'a%') \
+    ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
+    ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
     FROM region";
     println!("{}", hive().verified_stmt(rename));
 }
@@ -290,8 +290,8 @@ fn filtering_during_aggregation() {
 #[test]
 fn filtering_during_aggregation_aliased() {
     let rename = "SELECT \
-    array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
-    array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
+    ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
+    ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
     FROM region";
     println!("{}", hive().verified_stmt(rename));
 }
@@ -334,6 +334,25 @@ fn parse_similar_to() {
     chk(true);
 }
 
+#[test]
+fn test_array_agg_func() {
+    for sql in [
+        "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
+        "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
+    ] {
+        snowflake().verified_stmt(sql);
+    }
+
+    let sql = "select array_agg(x order by x) as a from T";
+    let result = snowflake().parse_sql_statements(sql);
+    assert_eq!(
+        result,
+        Err(ParserError::ParserError(String::from(
+            "Expected ), found: order"
+        )))
+    )
+}
+
 fn snowflake() -> TestedDialects {
     TestedDialects {
         dialects: vec![Box::new(SnowflakeDialect {})],