diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8076d82c..8222b34d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -13,11 +13,16 @@ //! SQL Abstract Syntax Tree (AST) types #[cfg(not(feature = "std"))] use alloc::{ + borrow::Cow, boxed::Box, format, string::{String, ToString}, vec::Vec, }; + +#[cfg(feature = "std")] +use std::borrow::Cow; + use core::fmt::{self, Display}; #[cfg(feature = "serde")] @@ -1406,6 +1411,35 @@ impl fmt::Display for NullTreatment { } } +/// Specifies Ignore / Respect NULL within window functions. +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum NullTreatmentType { + /// The declaration is part of the function's arguments. + /// + /// ```sql + /// FIRST_VALUE(x IGNORE NULLS) OVER () + /// ``` + FunctionArg(NullTreatment), + /// The declaration occurs after the function call. + /// + /// ```sql + /// FIRST_VALUE(x) IGNORE NULLS OVER () + /// ``` + AfterFunction(NullTreatment), +} + +impl Display for NullTreatmentType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let null_treatment = match self { + NullTreatmentType::FunctionArg(n) => n, + NullTreatmentType::AfterFunction(n) => n, + }; + write!(f, "{null_treatment}") + } +} + /// Specifies [WindowFrame]'s `start_bound` and `end_bound` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -4787,15 +4821,18 @@ pub struct Function { pub args: Vec, /// e.g. `x > 5` in `COUNT(x) FILTER (WHERE x > 5)` pub filter: Option>, - // Snowflake/MSSQL supports different options for null treatment in rank functions - pub null_treatment: Option, + /// Specifies Ignore / Respect NULL within window functions. + /// + /// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/navigation_functions#first_value) + /// [Snowflake](https://docs.snowflake.com/en/sql-reference/functions/first_value) + pub null_treatment: Option, pub over: Option, - // aggregate functions may specify eg `COUNT(DISTINCT x)` + /// aggregate functions may specify eg `COUNT(DISTINCT x)` pub distinct: bool, - // Some functions must be called without trailing parentheses, for example Postgres - // do it for current_catalog, current_schema, etc. This flags is used for formatting. + /// Some functions must be called without trailing parentheses, for example Postgres + /// do it for current_catalog, current_schema, etc. This flags is used for formatting. pub special: bool, - // Required ordering for the function (if empty, there is no requirement). + /// Required ordering for the function (if empty, there is no requirement). pub order_by: Vec, } @@ -4830,19 +4867,25 @@ impl fmt::Display for Function { }; write!( f, - "{}({}{}{order_by}{})", + "{}({}{}{order_by}{}{})", self.name, if self.distinct { "DISTINCT " } else { "" }, display_comma_separated(&self.args), display_comma_separated(&self.order_by), + match self.null_treatment { + Some(NullTreatmentType::FunctionArg(null_treatment)) => { + Cow::from(format!(" {null_treatment}")) + } + _ => Cow::from(""), + } )?; if let Some(filter_cond) = &self.filter { write!(f, " FILTER (WHERE {filter_cond})")?; } - if let Some(o) = &self.null_treatment { - write!(f, " {o}")?; + if let Some(NullTreatmentType::AfterFunction(null_treatment)) = &self.null_treatment { + write!(f, " {null_treatment}")?; } if let Some(o) = &self.over { diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index 7475b6a5..b945587c 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -30,6 +30,11 @@ impl Dialect for BigQueryDialect { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_' } + /// See [doc](https://cloud.google.com/bigquery/docs/reference/standard-sql/navigation_functions#first_value) + fn supports_window_function_null_treatment_arg(&self) -> bool { + true + } + // See https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#escape_sequences fn supports_string_literal_backslash_escape(&self) -> bool { true diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index fd5e1dcf..b5a5ae21 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -51,6 +51,10 @@ impl Dialect for GenericDialect { true } + fn supports_window_function_null_treatment_arg(&self) -> bool { + true + } + fn supports_dictionary_syntax(&self) -> bool { true } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 8089d66b..d3257aba 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -185,6 +185,20 @@ pub trait Dialect: Debug + Any { fn supports_named_fn_args_with_eq_operator(&self) -> bool { false } + /// Returns true if the dialects supports specifying null treatment + /// as part of a window function's parameter list. As opposed + /// to after the parameter list. + /// i.e The following syntax returns true + /// ```sql + /// FIRST_VALUE(a IGNORE NULLS) OVER () + /// ``` + /// while the following syntax returns false + /// ```sql + /// FIRST_VALUE(a) IGNORE NULLS OVER () + /// ``` + fn supports_window_function_null_treatment_arg(&self) -> bool { + false + } /// Returns true if the dialect supports defining structs or objects using a /// syntax like `{'x': 1, 'y': 2, 'z': 3}`. fn supports_dictionary_syntax(&self) -> bool { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 18c1bf73..0c5c3d12 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -208,6 +208,13 @@ impl From for MatchedTrailingBracket { } } +/// Output of the [`Parser::parse_window_function_args`] function. +struct ParseWindowFunctionArgsOutput { + args: Vec, + order_by: Vec, + null_treatment: Option, +} + /// Options that control how the [`Parser`] parses SQL text #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParserOptions { @@ -1229,7 +1236,11 @@ impl<'a> Parser<'a> { pub fn parse_function(&mut self, name: ObjectName) -> Result { self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?.is_some(); - let (args, order_by) = self.parse_optional_args_with_orderby()?; + let ParseWindowFunctionArgsOutput { + args, + order_by, + null_treatment, + } = self.parse_window_function_args()?; let filter = if self.dialect.supports_filter_during_aggregation() && self.parse_keyword(Keyword::FILTER) && self.consume_token(&Token::LParen) @@ -1241,19 +1252,15 @@ impl<'a> Parser<'a> { } else { None }; - let null_treatment = match self.parse_one_of_keywords(&[Keyword::RESPECT, Keyword::IGNORE]) - { - Some(keyword) => { - self.expect_keyword(Keyword::NULLS)?; - match keyword { - Keyword::RESPECT => Some(NullTreatment::RespectNulls), - Keyword::IGNORE => Some(NullTreatment::IgnoreNulls), - _ => None, - } - } - None => None, - }; + // Syntax for null treatment shows up either in the args list + // or after the function call, but not both. + let mut null_treatment = null_treatment.map(NullTreatmentType::FunctionArg); + if null_treatment.is_none() { + null_treatment = self + .parse_null_treatment()? + .map(NullTreatmentType::AfterFunction); + } let over = if self.parse_keyword(Keyword::OVER) { if self.consume_token(&Token::LParen) { let window_spec = self.parse_window_spec()?; @@ -1276,17 +1283,37 @@ impl<'a> Parser<'a> { })) } + /// Optionally parses a null treatment clause. + fn parse_null_treatment(&mut self) -> Result, ParserError> { + match self.parse_one_of_keywords(&[Keyword::RESPECT, Keyword::IGNORE]) { + Some(keyword) => { + self.expect_keyword(Keyword::NULLS)?; + + Ok(match keyword { + Keyword::RESPECT => Some(NullTreatment::RespectNulls), + Keyword::IGNORE => Some(NullTreatment::IgnoreNulls), + _ => None, + }) + } + None => Ok(None), + } + } + pub fn parse_time_functions(&mut self, name: ObjectName) -> Result { - let (args, order_by, special) = if self.consume_token(&Token::LParen) { - let (args, order_by) = self.parse_optional_args_with_orderby()?; - (args, order_by, false) + let (args, order_by, null_treatment, special) = if self.consume_token(&Token::LParen) { + let ParseWindowFunctionArgsOutput { + args, + order_by, + null_treatment, + } = self.parse_window_function_args()?; + (args, order_by, null_treatment, false) } else { - (vec![], vec![], true) + (vec![], vec![], None, true) }; Ok(Expr::Function(Function { name, args, - null_treatment: None, + null_treatment: null_treatment.map(NullTreatmentType::FunctionArg), filter: None, over: None, distinct: false, @@ -9326,11 +9353,21 @@ impl<'a> Parser<'a> { } } - pub fn parse_optional_args_with_orderby( - &mut self, - ) -> Result<(Vec, Vec), ParserError> { + /// Parses a potentially empty list of arguments to a window function + /// (including the closing parenthesis). + /// + /// Examples: + /// ```sql + /// FIRST_VALUE(x ORDER BY 1,2,3); + /// FIRST_VALUE(x IGNORE NULL); + /// ``` + fn parse_window_function_args(&mut self) -> Result { if self.consume_token(&Token::RParen) { - Ok((vec![], vec![])) + Ok(ParseWindowFunctionArgsOutput { + args: vec![], + order_by: vec![], + null_treatment: None, + }) } else { // Snowflake permits a subquery to be passed as an argument without // an enclosing set of parens if it's the only argument. @@ -9342,22 +9379,34 @@ impl<'a> Parser<'a> { self.prev_token(); let subquery = self.parse_boxed_query()?; self.expect_token(&Token::RParen)?; - return Ok(( - vec![FunctionArg::Unnamed(FunctionArgExpr::from(Expr::Subquery( + return Ok(ParseWindowFunctionArgsOutput { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::from(Expr::Subquery( subquery, )))], - vec![], - )); + order_by: vec![], + null_treatment: None, + }); } let args = self.parse_comma_separated(Parser::parse_function_args)?; let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_order_by_expr)? } else { - vec![] + Default::default() }; + + let null_treatment = if self.dialect.supports_window_function_null_treatment_arg() { + self.parse_null_treatment()? + } else { + None + }; + self.expect_token(&Token::RParen)?; - Ok((args, order_by)) + Ok(ParseWindowFunctionArgsOutput { + args, + order_by, + null_treatment, + }) } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 852e6947..8040899c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2644,6 +2644,58 @@ fn parse_window_rank_function() { } } +#[test] +fn parse_window_function_null_treatment_arg() { + let dialects = all_dialects_where(|d| d.supports_window_function_null_treatment_arg()); + let sql = "SELECT \ + FIRST_VALUE(a IGNORE NULLS) OVER (), \ + FIRST_VALUE(b RESPECT NULLS) OVER () \ + FROM mytable"; + let Select { projection, .. } = dialects.verified_only_select(sql); + for (i, (expected_expr, expected_null_treatment)) in [ + ("a", NullTreatment::IgnoreNulls), + ("b", NullTreatment::RespectNulls), + ] + .into_iter() + .enumerate() + { + let SelectItem::UnnamedExpr(Expr::Function(actual)) = &projection[i] else { + unreachable!() + }; + assert_eq!(ObjectName(vec![Ident::new("FIRST_VALUE")]), actual.name); + assert!(actual.order_by.is_empty()); + assert_eq!(1, actual.args.len()); + let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(actual_expr))) = + &actual.args[0] + else { + unreachable!() + }; + assert_eq!(&Ident::new(expected_expr), actual_expr); + let Some(NullTreatmentType::FunctionArg(actual_null_treatment)) = actual.null_treatment + else { + unreachable!() + }; + assert_eq!(expected_null_treatment, actual_null_treatment); + } + + let sql = "SELECT FIRST_VALUE(a ORDER BY b IGNORE NULLS) OVER () FROM t1"; + dialects.verified_stmt(sql); + + let sql = "SELECT LAG(1 IGNORE NULLS) IGNORE NULLS OVER () FROM t1"; + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected end of statement, found: NULLS".to_string()) + ); + + let sql = "SELECT LAG(1 IGNORE NULLS) IGNORE NULLS OVER () FROM t1"; + assert_eq!( + all_dialects_where(|d| !d.supports_window_function_null_treatment_arg()) + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError("Expected ), found: IGNORE".to_string()) + ); +} + #[test] fn parse_create_table() { let sql = "CREATE TABLE uk_cities (\