mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-09-10 07:56:20 +00:00
Support OVER clause for window/analytic functions
Since this changes SQLFunction anyway, changed its `id` field to `name`, as we don't seem to use "id" to mean "name" anywhere else.
This commit is contained in:
parent
c8e7c3b343
commit
d4de248c73
4 changed files with 257 additions and 15 deletions
|
@ -139,6 +139,7 @@ keyword!(
|
|||
FIRST_VALUE,
|
||||
FLOAT,
|
||||
FLOOR,
|
||||
FOLLOWING,
|
||||
FOR,
|
||||
FOREIGN,
|
||||
FRAME_ROW,
|
||||
|
@ -246,6 +247,7 @@ keyword!(
|
|||
POSITION_REGEX,
|
||||
POWER,
|
||||
PRECEDES,
|
||||
PRECEDING,
|
||||
PRECISION,
|
||||
PREPARE,
|
||||
PRIMARY,
|
||||
|
@ -333,6 +335,7 @@ keyword!(
|
|||
TRIM_ARRAY,
|
||||
TRUE,
|
||||
UESCAPE,
|
||||
UNBOUNDED,
|
||||
UNION,
|
||||
UNIQUE,
|
||||
UNKNOWN,
|
||||
|
@ -488,6 +491,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
|
|||
FIRST_VALUE,
|
||||
FLOAT,
|
||||
FLOOR,
|
||||
FOLLOWING,
|
||||
FOR,
|
||||
FOREIGN,
|
||||
FRAME_ROW,
|
||||
|
@ -595,6 +599,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
|
|||
POSITION_REGEX,
|
||||
POWER,
|
||||
PRECEDES,
|
||||
PRECEDING,
|
||||
PRECISION,
|
||||
PREPARE,
|
||||
PRIMARY,
|
||||
|
@ -682,6 +687,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
|
|||
TRIM_ARRAY,
|
||||
TRUE,
|
||||
UESCAPE,
|
||||
UNBOUNDED,
|
||||
UNION,
|
||||
UNIQUE,
|
||||
UNKNOWN,
|
||||
|
|
|
@ -101,7 +101,11 @@ pub enum ASTNode {
|
|||
SQLValue(Value),
|
||||
/// Scalar function call e.g. `LEFT(foo, 5)`
|
||||
/// TODO: this can be a compound SQLObjectName as well (for UDFs)
|
||||
SQLFunction { id: SQLIdent, args: Vec<ASTNode> },
|
||||
SQLFunction {
|
||||
name: SQLIdent,
|
||||
args: Vec<ASTNode>,
|
||||
over: Option<SQLWindowSpec>,
|
||||
},
|
||||
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
|
||||
SQLCase {
|
||||
// TODO: support optional operand for "simple case"
|
||||
|
@ -171,14 +175,13 @@ impl ToString for ASTNode {
|
|||
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
|
||||
}
|
||||
ASTNode::SQLValue(v) => v.to_string(),
|
||||
ASTNode::SQLFunction { id, args } => format!(
|
||||
"{}({})",
|
||||
id,
|
||||
args.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(", ")
|
||||
),
|
||||
ASTNode::SQLFunction { name, args, over } => {
|
||||
let mut s = format!("{}({})", name, comma_separated_string(args));
|
||||
if let Some(o) = over {
|
||||
s += &format!(" OVER ({})", o.to_string())
|
||||
}
|
||||
s
|
||||
}
|
||||
ASTNode::SQLCase {
|
||||
conditions,
|
||||
results,
|
||||
|
@ -203,6 +206,116 @@ impl ToString for ASTNode {
|
|||
}
|
||||
}
|
||||
|
||||
/// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`)
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SQLWindowSpec {
|
||||
pub partition_by: Vec<ASTNode>,
|
||||
pub order_by: Vec<SQLOrderByExpr>,
|
||||
pub window_frame: Option<SQLWindowFrame>,
|
||||
}
|
||||
|
||||
impl ToString for SQLWindowSpec {
|
||||
fn to_string(&self) -> String {
|
||||
let mut clauses = vec![];
|
||||
if !self.partition_by.is_empty() {
|
||||
clauses.push(format!(
|
||||
"PARTITION BY {}",
|
||||
comma_separated_string(&self.partition_by)
|
||||
))
|
||||
};
|
||||
if !self.order_by.is_empty() {
|
||||
clauses.push(format!(
|
||||
"ORDER BY {}",
|
||||
comma_separated_string(&self.order_by)
|
||||
))
|
||||
};
|
||||
if let Some(window_frame) = &self.window_frame {
|
||||
if let Some(end_bound) = &window_frame.end_bound {
|
||||
clauses.push(format!(
|
||||
"{} BETWEEN {} AND {}",
|
||||
window_frame.units.to_string(),
|
||||
window_frame.start_bound.to_string(),
|
||||
end_bound.to_string()
|
||||
));
|
||||
} else {
|
||||
clauses.push(format!(
|
||||
"{} {}",
|
||||
window_frame.units.to_string(),
|
||||
window_frame.start_bound.to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
clauses.join(" ")
|
||||
}
|
||||
}
|
||||
|
||||
/// Specifies the data processed by a window function, e.g.
|
||||
/// `RANGE UNBOUNDED PRECEDING` or `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SQLWindowFrame {
|
||||
pub units: SQLWindowFrameUnits,
|
||||
pub start_bound: SQLWindowFrameBound,
|
||||
/// The right bound of the `BETWEEN .. AND` clause.
|
||||
pub end_bound: Option<SQLWindowFrameBound>,
|
||||
// TBD: EXCLUDE
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SQLWindowFrameUnits {
|
||||
Rows,
|
||||
Range,
|
||||
Groups,
|
||||
}
|
||||
|
||||
impl ToString for SQLWindowFrameUnits {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
SQLWindowFrameUnits::Rows => "ROWS".to_string(),
|
||||
SQLWindowFrameUnits::Range => "RANGE".to_string(),
|
||||
SQLWindowFrameUnits::Groups => "GROUPS".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SQLWindowFrameUnits {
|
||||
type Err = ParserError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"ROWS" => Ok(SQLWindowFrameUnits::Rows),
|
||||
"RANGE" => Ok(SQLWindowFrameUnits::Range),
|
||||
"GROUPS" => Ok(SQLWindowFrameUnits::Groups),
|
||||
_ => Err(ParserError::ParserError(format!(
|
||||
"Expected ROWS, RANGE, or GROUPS, found: {}",
|
||||
s
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SQLWindowFrameBound {
|
||||
/// "CURRENT ROW"
|
||||
CurrentRow,
|
||||
/// "<N> PRECEDING" or "UNBOUNDED PRECEDING"
|
||||
Preceding(Option<u64>),
|
||||
/// "<N> FOLLOWING" or "UNBOUNDED FOLLOWING". This can only appear in
|
||||
/// SQLWindowFrame::end_bound.
|
||||
Following(Option<u64>),
|
||||
}
|
||||
|
||||
impl ToString for SQLWindowFrameBound {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
SQLWindowFrameBound::CurrentRow => "CURRENT ROW".to_string(),
|
||||
SQLWindowFrameBound::Preceding(None) => "UNBOUNDED PRECEDING".to_string(),
|
||||
SQLWindowFrameBound::Following(None) => "UNBOUNDED FOLLOWING".to_string(),
|
||||
SQLWindowFrameBound::Preceding(Some(n)) => format!("{} PRECEDING", n),
|
||||
SQLWindowFrameBound::Following(Some(n)) => format!("{} FOLLOWING", n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A top-level statement (SELECT, INSERT, CREATE, etc.)
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SQLStatement {
|
||||
|
|
|
@ -237,7 +237,7 @@ impl Parser {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_function(&mut self, id: SQLIdent) -> Result<ASTNode, ParserError> {
|
||||
pub fn parse_function(&mut self, name: SQLIdent) -> Result<ASTNode, ParserError> {
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let args = if self.consume_token(&Token::RParen) {
|
||||
vec![]
|
||||
|
@ -246,7 +246,98 @@ impl Parser {
|
|||
self.expect_token(&Token::RParen)?;
|
||||
args
|
||||
};
|
||||
Ok(ASTNode::SQLFunction { id, args })
|
||||
let over = if self.parse_keyword("OVER") {
|
||||
// TBD: support window names (`OVER mywin`) in place of inline specification
|
||||
self.expect_token(&Token::LParen)?;
|
||||
let partition_by = if self.parse_keywords(vec!["PARTITION", "BY"]) {
|
||||
// a list of possibly-qualified column names
|
||||
self.parse_expr_list()?
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) {
|
||||
self.parse_order_by_expr_list()?
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let window_frame = self.parse_window_frame()?;
|
||||
|
||||
Some(SQLWindowSpec {
|
||||
partition_by,
|
||||
order_by,
|
||||
window_frame,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ASTNode::SQLFunction { name, args, over })
|
||||
}
|
||||
|
||||
pub fn parse_window_frame(&mut self) -> Result<Option<SQLWindowFrame>, ParserError> {
|
||||
let window_frame = match self.peek_token() {
|
||||
Some(Token::SQLWord(w)) => {
|
||||
let units = w.keyword.parse::<SQLWindowFrameUnits>()?;
|
||||
self.next_token();
|
||||
if self.parse_keyword("BETWEEN") {
|
||||
let start_bound = self.parse_window_frame_bound()?;
|
||||
self.expect_keyword("AND")?;
|
||||
let end_bound = Some(self.parse_window_frame_bound()?);
|
||||
Some(SQLWindowFrame {
|
||||
units,
|
||||
start_bound,
|
||||
end_bound,
|
||||
})
|
||||
} else {
|
||||
let start_bound = self.parse_window_frame_bound()?;
|
||||
let end_bound = None;
|
||||
Some(SQLWindowFrame {
|
||||
units,
|
||||
start_bound,
|
||||
end_bound,
|
||||
})
|
||||
}
|
||||
}
|
||||
Some(Token::RParen) => None,
|
||||
unexpected => {
|
||||
return parser_err!(format!(
|
||||
"Expected 'ROWS', 'RANGE', 'GROUPS', or ')', got {:?}",
|
||||
unexpected
|
||||
));
|
||||
}
|
||||
};
|
||||
self.expect_token(&Token::RParen)?;
|
||||
Ok(window_frame)
|
||||
}
|
||||
|
||||
/// "CURRENT ROW" | ( (<positive number> | "UNBOUNDED") ("PRECEDING" | FOLLOWING) )
|
||||
pub fn parse_window_frame_bound(&mut self) -> Result<SQLWindowFrameBound, ParserError> {
|
||||
if self.parse_keywords(vec!["CURRENT", "ROW"]) {
|
||||
Ok(SQLWindowFrameBound::CurrentRow)
|
||||
} else {
|
||||
let rows = if self.parse_keyword("UNBOUNDED") {
|
||||
None
|
||||
} else {
|
||||
let rows = self.parse_literal_int()?;
|
||||
if rows < 0 {
|
||||
parser_err!(format!(
|
||||
"The number of rows must be non-negative, got {}",
|
||||
rows
|
||||
))?;
|
||||
}
|
||||
Some(rows as u64)
|
||||
};
|
||||
if self.parse_keyword("PRECEDING") {
|
||||
Ok(SQLWindowFrameBound::Preceding(rows))
|
||||
} else if self.parse_keyword("FOLLOWING") {
|
||||
Ok(SQLWindowFrameBound::Following(rows))
|
||||
} else {
|
||||
parser_err!(format!(
|
||||
"Expected PRECEDING or FOLLOWING, found {:?}",
|
||||
self.peek_token()
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_case_expression(&mut self) -> Result<ASTNode, ParserError> {
|
||||
|
|
|
@ -119,8 +119,9 @@ fn parse_select_count_wildcard() {
|
|||
let select = verified_only_select(sql);
|
||||
assert_eq!(
|
||||
&ASTNode::SQLFunction {
|
||||
id: "COUNT".to_string(),
|
||||
name: "COUNT".to_string(),
|
||||
args: vec![ASTNode::SQLWildcard],
|
||||
over: None,
|
||||
},
|
||||
expr_from_projection(only(&select.projection))
|
||||
);
|
||||
|
@ -532,13 +533,43 @@ fn parse_scalar_function_in_projection() {
|
|||
let select = verified_only_select(sql);
|
||||
assert_eq!(
|
||||
&ASTNode::SQLFunction {
|
||||
id: String::from("sqrt"),
|
||||
name: String::from("sqrt"),
|
||||
args: vec![ASTNode::SQLIdentifier(String::from("id"))],
|
||||
over: None,
|
||||
},
|
||||
expr_from_projection(only(&select.projection))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_window_functions() {
|
||||
let sql = "SELECT row_number() OVER (ORDER BY dt DESC), \
|
||||
sum(foo) OVER (PARTITION BY a, b ORDER BY c, d \
|
||||
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), \
|
||||
avg(bar) OVER (ORDER BY a \
|
||||
RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), \
|
||||
max(baz) OVER (ORDER BY a \
|
||||
ROWS UNBOUNDED PRECEDING) \
|
||||
FROM foo";
|
||||
let select = verified_only_select(sql);
|
||||
assert_eq!(4, select.projection.len());
|
||||
assert_eq!(
|
||||
&ASTNode::SQLFunction {
|
||||
name: "row_number".to_string(),
|
||||
args: vec![],
|
||||
over: Some(SQLWindowSpec {
|
||||
partition_by: vec![],
|
||||
order_by: vec![SQLOrderByExpr {
|
||||
expr: ASTNode::SQLIdentifier("dt".to_string()),
|
||||
asc: Some(false)
|
||||
}],
|
||||
window_frame: None,
|
||||
})
|
||||
},
|
||||
expr_from_projection(&select.projection[0])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_aggregate_with_group_by() {
|
||||
let sql = "SELECT a, COUNT(1), MIN(b), MAX(b) FROM foo GROUP BY a";
|
||||
|
@ -605,8 +636,9 @@ fn parse_delimited_identifiers() {
|
|||
);
|
||||
assert_eq!(
|
||||
&ASTNode::SQLFunction {
|
||||
id: r#""myfun""#.to_string(),
|
||||
args: vec![]
|
||||
name: r#""myfun""#.to_string(),
|
||||
args: vec![],
|
||||
over: None,
|
||||
},
|
||||
expr_from_projection(&select.projection[1]),
|
||||
);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue