From aa714e3447ebb7e470d746607a9c6441717f1b46 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 6 Sep 2024 15:16:09 +0100 Subject: [PATCH] Fix `INTERVAL` parsing to support expressions and units via dialect (#1398) --- src/dialect/ansi.rs | 4 + src/dialect/bigquery.rs | 4 + src/dialect/clickhouse.rs | 4 + src/dialect/databricks.rs | 4 + src/dialect/hive.rs | 4 + src/dialect/mod.rs | 21 +++ src/dialect/mysql.rs | 4 + src/parser/mod.rs | 145 +++++++++----------- tests/sqlparser_bigquery.rs | 21 ++- tests/sqlparser_common.rs | 254 ++++++++++++++++++++++++++++++------ 10 files changed, 331 insertions(+), 134 deletions(-) diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs index d07bc07e..61ae5829 100644 --- a/src/dialect/ansi.rs +++ b/src/dialect/ansi.rs @@ -24,4 +24,8 @@ impl Dialect for AnsiDialect { fn is_identifier_part(&self, ch: char) -> bool { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_' } + + fn require_interval_qualifier(&self) -> bool { + true + } } diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index d3673337..3bce6702 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -63,4 +63,8 @@ impl Dialect for BigQueryDialect { fn supports_select_wildcard_except(&self) -> bool { true } + + fn require_interval_qualifier(&self) -> bool { + true + } } diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 34940467..09735cbe 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -37,4 +37,8 @@ impl Dialect for ClickHouseDialect { fn describe_requires_table_keyword(&self) -> bool { true } + + fn require_interval_qualifier(&self) -> bool { + true + } } diff --git a/src/dialect/databricks.rs b/src/dialect/databricks.rs index 42d432d3..d3661444 100644 --- a/src/dialect/databricks.rs +++ b/src/dialect/databricks.rs @@ -38,4 +38,8 @@ impl Dialect for DatabricksDialect { fn supports_select_wildcard_except(&self) -> bool { true } + + fn 
require_interval_qualifier(&self) -> bool { + true + } } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 2340df61..b32d44cb 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -42,4 +42,8 @@ impl Dialect for HiveDialect { fn supports_numeric_prefix(&self) -> bool { true } + + fn require_interval_qualifier(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 2a74d992..0be8c17c 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -515,6 +515,27 @@ pub trait Dialect: Debug + Any { fn supports_create_index_with_clause(&self) -> bool { false } + + /// Whether `INTERVAL` expressions require units (called "qualifiers" in the ANSI SQL spec) to be specified, + /// e.g. `INTERVAL 1 DAY` vs `INTERVAL 1`. + /// + /// Expressions within intervals (e.g. `INTERVAL '1' + '1' DAY`) are only allowed when units are required. + /// + /// See <https://github.com/sqlparser-rs/sqlparser-rs/pull/1398> for more information. + /// + /// When `true`: + /// * `INTERVAL '1' DAY` is VALID + /// * `INTERVAL 1 + 1 DAY` is VALID + /// * `INTERVAL '1' + '1' DAY` is VALID + /// * `INTERVAL '1'` is INVALID + /// + /// When `false`: + /// * `INTERVAL '1'` is VALID + /// * `INTERVAL '1' DAY` is VALID — unit is not required, but still allowed + /// * `INTERVAL 1 + 1 DAY` is INVALID + fn require_interval_qualifier(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 32525658..b8c4631f 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -84,6 +84,10 @@ impl Dialect for MySqlDialect { None } } + + fn require_interval_qualifier(&self) -> bool { + true + } } /// `LOCK TABLES` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 30e77678..26e9e05f 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -56,15 +56,6 @@ macro_rules! parser_err { }; } -// Returns a successful result if the optional expression is some -macro_rules! 
return_ok_if_some { - ($e:expr) => {{ - if let Some(v) = $e { - return Ok(v); - } - }}; -} - #[cfg(feature = "std")] /// Implementation [`RecursionCounter`] if std is available mod recursion { @@ -928,35 +919,6 @@ impl<'a> Parser<'a> { Ok(expr) } - pub fn parse_interval_expr(&mut self) -> Result<Expr, ParserError> { - let precedence = self.dialect.prec_unknown(); - let mut expr = self.parse_prefix()?; - - loop { - let next_precedence = self.get_next_interval_precedence()?; - - if precedence >= next_precedence { - break; - } - - expr = self.parse_infix(expr, next_precedence)?; - } - - Ok(expr) - } - - /// Get the precedence of the next token, with AND, OR, and XOR. - pub fn get_next_interval_precedence(&self) -> Result<u8, ParserError> { - let token = self.peek_token(); - - match token.token { - Token::Word(w) if w.keyword == Keyword::AND => Ok(self.dialect.prec_unknown()), - Token::Word(w) if w.keyword == Keyword::OR => Ok(self.dialect.prec_unknown()), - Token::Word(w) if w.keyword == Keyword::XOR => Ok(self.dialect.prec_unknown()), - _ => self.get_next_precedence(), - } - } - pub fn parse_assert(&mut self) -> Result<Statement, ParserError> { let condition = self.parse_expr()?; let message = if self.parse_keyword(Keyword::AS) { @@ -1004,7 +966,7 @@ impl<'a> Parser<'a> { // name is not followed by a string literal, but in fact in PostgreSQL it is a valid // expression that should parse as the column name "date". let loc = self.peek_token().location; - return_ok_if_some!(self.maybe_parse(|parser| { + let opt_expr = self.maybe_parse(|parser| { match parser.parse_data_type()? { DataType::Interval => parser.parse_interval(), // PostgreSQL allows almost any identifier to be used as custom data type name, @@ -1020,7 +982,11 @@ impl<'a> Parser<'a> { value: parser.parse_literal_string()?, }), } - })); + }); + + if let Some(expr) = opt_expr { + return Ok(expr); + } let next_token = self.next_token(); let expr = match next_token.token { @@ -2110,52 +2076,32 @@ impl<'a> Parser<'a> { // don't currently try to parse it. 
(The sign can instead be included // inside the value string.) - // The first token in an interval is a string literal which specifies - // the duration of the interval. - let value = self.parse_interval_expr()?; + // to match the different flavours of INTERVAL syntax, we only allow expressions + // if the dialect requires an interval qualifier, + // see https://github.com/sqlparser-rs/sqlparser-rs/pull/1398 for more details + let value = if self.dialect.require_interval_qualifier() { + // parse a whole expression so `INTERVAL 1 + 1 DAY` is valid + self.parse_expr()? + } else { + // parse a prefix expression so `INTERVAL 1 DAY` is valid, but `INTERVAL 1 + 1 DAY` is not + // this also means that `INTERVAL '5 days' > INTERVAL '1 day'` is treated properly + self.parse_prefix()? + }; // Following the string literal is a qualifier which indicates the units // of the duration specified in the string literal. // // Note that PostgreSQL allows omitting the qualifier, so we provide // this more general implementation. - let leading_field = match self.peek_token().token { - Token::Word(kw) - if [ - Keyword::YEAR, - Keyword::MONTH, - Keyword::WEEK, - Keyword::DAY, - Keyword::HOUR, - Keyword::MINUTE, - Keyword::SECOND, - Keyword::CENTURY, - Keyword::DECADE, - Keyword::DOW, - Keyword::DOY, - Keyword::EPOCH, - Keyword::ISODOW, - Keyword::ISOYEAR, - Keyword::JULIAN, - Keyword::MICROSECOND, - Keyword::MICROSECONDS, - Keyword::MILLENIUM, - Keyword::MILLENNIUM, - Keyword::MILLISECOND, - Keyword::MILLISECONDS, - Keyword::NANOSECOND, - Keyword::NANOSECONDS, - Keyword::QUARTER, - Keyword::TIMEZONE, - Keyword::TIMEZONE_HOUR, - Keyword::TIMEZONE_MINUTE, - ] - .iter() - .any(|d| kw.keyword == *d) => - { - Some(self.parse_date_time_field()?) - } - _ => None, + let leading_field = if self.next_token_is_temporal_unit() { + Some(self.parse_date_time_field()?) 
+ } else if self.dialect.require_interval_qualifier() { + return parser_err!( + "INTERVAL requires a unit after the literal value", + self.peek_token().location + ); + } else { + None }; let (leading_precision, last_field, fsec_precision) = @@ -2192,6 +2138,45 @@ impl<'a> Parser<'a> { })) } + /// Peek at the next token and determine if it is a temporal unit + /// like `second`. + pub fn next_token_is_temporal_unit(&mut self) -> bool { + if let Token::Word(word) = self.peek_token().token { + matches!( + word.keyword, + Keyword::YEAR + | Keyword::MONTH + | Keyword::WEEK + | Keyword::DAY + | Keyword::HOUR + | Keyword::MINUTE + | Keyword::SECOND + | Keyword::CENTURY + | Keyword::DECADE + | Keyword::DOW + | Keyword::DOY + | Keyword::EPOCH + | Keyword::ISODOW + | Keyword::ISOYEAR + | Keyword::JULIAN + | Keyword::MICROSECOND + | Keyword::MICROSECONDS + | Keyword::MILLENIUM + | Keyword::MILLENNIUM + | Keyword::MILLISECOND + | Keyword::MILLISECONDS + | Keyword::NANOSECOND + | Keyword::NANOSECONDS + | Keyword::QUARTER + | Keyword::TIMEZONE + | Keyword::TIMEZONE_HOUR + | Keyword::TIMEZONE_MINUTE + ) + } else { + false + } + } + /// Bigquery specific: Parse a struct literal /// Syntax /// ```sql diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 57cf9d7f..4f84b376 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -13,7 +13,6 @@ #[macro_use] mod test_utils; -use sqlparser::ast; use std::ops::Deref; use sqlparser::ast::*; @@ -830,16 +829,14 @@ fn parse_typed_struct_syntax_bigquery() { expr_from_projection(&select.projection[3]) ); - let sql = r#"SELECT STRUCT(INTERVAL '1-2 3 4:5:6.789999'), STRUCT(JSON '{"class" : {"students" : [{"name" : "Jane"}]}}')"#; + let sql = r#"SELECT STRUCT(INTERVAL '2' HOUR), STRUCT(JSON '{"class" : {"students" : [{"name" : "Jane"}]}}')"#; let select = bigquery().verified_only_select(sql); assert_eq!(2, select.projection.len()); assert_eq!( &Expr::Struct { - values: vec![Expr::Interval(ast::Interval 
{ - value: Box::new(Expr::Value(Value::SingleQuotedString( - "1-2 3 4:5:6.789999".to_string() - ))), - leading_field: None, + values: vec![Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString("2".to_string()))), + leading_field: Some(DateTimeField::Hour), leading_precision: None, last_field: None, fractional_seconds_precision: None @@ -1141,16 +1138,14 @@ fn parse_typed_struct_syntax_bigquery_and_generic() { expr_from_projection(&select.projection[3]) ); - let sql = r#"SELECT STRUCT(INTERVAL '1-2 3 4:5:6.789999'), STRUCT(JSON '{"class" : {"students" : [{"name" : "Jane"}]}}')"#; + let sql = r#"SELECT STRUCT(INTERVAL '1' MONTH), STRUCT(JSON '{"class" : {"students" : [{"name" : "Jane"}]}}')"#; let select = bigquery_and_generic().verified_only_select(sql); assert_eq!(2, select.projection.len()); assert_eq!( &Expr::Struct { - values: vec![Expr::Interval(ast::Interval { - value: Box::new(Expr::Value(Value::SingleQuotedString( - "1-2 3 4:5:6.789999".to_string() - ))), - leading_field: None, + values: vec![Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString("1".to_string()))), + leading_field: Some(DateTimeField::Month), leading_precision: None, last_field: None, fractional_seconds_precision: None diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index fbe97171..0ed677db 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -4472,7 +4472,8 @@ fn parse_window_functions() { sum(qux) OVER (ORDER BY a \ GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ FROM foo"; - let select = verified_only_select(sql); + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + let select = dialects.verified_only_select(sql); const EXPECTED_PROJ_QTY: usize = 7; assert_eq!(EXPECTED_PROJ_QTY, select.projection.len()); @@ -4886,7 +4887,9 @@ fn parse_literal_timestamp_with_time_zone() { } #[test] -fn parse_interval() { +fn parse_interval_all() { + // these intervals expressions all work 
with both variants of INTERVAL + let sql = "SELECT INTERVAL '1-1' YEAR TO MONTH"; let select = verified_only_select(sql); assert_eq!( @@ -4954,23 +4957,6 @@ fn parse_interval() { expr_from_projection(only(&select.projection)), ); - let sql = "SELECT INTERVAL 1 + 1 DAY"; - let select = verified_only_select(sql); - assert_eq!( - &Expr::Interval(Interval { - value: Box::new(Expr::BinaryOp { - left: Box::new(Expr::Value(number("1"))), - op: BinaryOperator::Plus, - right: Box::new(Expr::Value(number("1"))), - }), - leading_field: Some(DateTimeField::Day), - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }), - expr_from_projection(only(&select.projection)), - ); - let sql = "SELECT INTERVAL '10' HOUR (1)"; let select = verified_only_select(sql); assert_eq!( @@ -4984,21 +4970,6 @@ fn parse_interval() { expr_from_projection(only(&select.projection)), ); - let sql = "SELECT INTERVAL '1 DAY'"; - let select = verified_only_select(sql); - assert_eq!( - &Expr::Interval(Interval { - value: Box::new(Expr::Value(Value::SingleQuotedString(String::from( - "1 DAY" - )))), - leading_field: None, - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }), - expr_from_projection(only(&select.projection)), - ); - let result = parse_sql_statements("SELECT INTERVAL '1' SECOND TO SECOND"); assert_eq!( ParserError::ParserError("Expected: end of statement, found: SECOND".to_string()), @@ -5024,12 +4995,212 @@ fn parse_interval() { verified_only_select("SELECT INTERVAL '1' HOUR TO MINUTE"); verified_only_select("SELECT INTERVAL '1' HOUR TO SECOND"); verified_only_select("SELECT INTERVAL '1' MINUTE TO SECOND"); - verified_only_select("SELECT INTERVAL '1 YEAR'"); - verified_only_select("SELECT INTERVAL '1 YEAR' AS one_year"); - one_statement_parses_to( + verified_only_select("SELECT INTERVAL 1 YEAR"); + verified_only_select("SELECT INTERVAL 1 MONTH"); + verified_only_select("SELECT INTERVAL 1 DAY"); + 
verified_only_select("SELECT INTERVAL 1 HOUR"); + verified_only_select("SELECT INTERVAL 1 MINUTE"); + verified_only_select("SELECT INTERVAL 1 SECOND"); +} + +#[test] +fn parse_interval_dont_require_unit() { + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + + let sql = "SELECT INTERVAL '1 DAY'"; + let select = dialects.verified_only_select(sql); + assert_eq!( + &Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString(String::from( + "1 DAY" + )))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }), + expr_from_projection(only(&select.projection)), + ); + dialects.verified_only_select("SELECT INTERVAL '1 YEAR'"); + dialects.verified_only_select("SELECT INTERVAL '1 MONTH'"); + dialects.verified_only_select("SELECT INTERVAL '1 DAY'"); + dialects.verified_only_select("SELECT INTERVAL '1 HOUR'"); + dialects.verified_only_select("SELECT INTERVAL '1 MINUTE'"); + dialects.verified_only_select("SELECT INTERVAL '1 SECOND'"); +} + +#[test] +fn parse_interval_require_unit() { + let dialects = all_dialects_where(|d| d.require_interval_qualifier()); + + let sql = "SELECT INTERVAL '1 DAY'"; + let err = dialects.parse_sql_statements(sql).unwrap_err(); + assert_eq!( + err.to_string(), + "sql parser error: INTERVAL requires a unit after the literal value" + ) +} + +#[test] +fn parse_interval_require_qualifier() { + let dialects = all_dialects_where(|d| d.require_interval_qualifier()); + + let sql = "SELECT INTERVAL 1 + 1 DAY"; + let select = dialects.verified_only_select(sql); + assert_eq!( + expr_from_projection(only(&select.projection)), + &Expr::Interval(Interval { + value: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Value(number("1"))), + op: BinaryOperator::Plus, + right: Box::new(Expr::Value(number("1"))), + }), + leading_field: Some(DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }), + ); + + let sql 
= "SELECT INTERVAL '1' + '1' DAY"; + let select = dialects.verified_only_select(sql); + assert_eq!( + expr_from_projection(only(&select.projection)), + &Expr::Interval(Interval { + value: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Value(Value::SingleQuotedString("1".to_string()))), + op: BinaryOperator::Plus, + right: Box::new(Expr::Value(Value::SingleQuotedString("1".to_string()))), + }), + leading_field: Some(DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }), + ); + + let sql = "SELECT INTERVAL '1' + '2' - '3' DAY"; + let select = dialects.verified_only_select(sql); + assert_eq!( + expr_from_projection(only(&select.projection)), + &Expr::Interval(Interval { + value: Box::new(Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Value(Value::SingleQuotedString("1".to_string()))), + op: BinaryOperator::Plus, + right: Box::new(Expr::Value(Value::SingleQuotedString("2".to_string()))), + }), + op: BinaryOperator::Minus, + right: Box::new(Expr::Value(Value::SingleQuotedString("3".to_string()))), + }), + leading_field: Some(DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }), + ); +} + +#[test] +fn parse_interval_disallow_interval_expr() { + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + + let sql = "SELECT INTERVAL '1 DAY'"; + let select = dialects.verified_only_select(sql); + assert_eq!( + expr_from_projection(only(&select.projection)), + &Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString(String::from( + "1 DAY" + )))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }), + ); + + dialects.verified_only_select("SELECT INTERVAL '1 YEAR'"); + dialects.verified_only_select("SELECT INTERVAL '1 YEAR' AS one_year"); + dialects.one_statement_parses_to( "SELECT INTERVAL '1 YEAR' one_year", "SELECT INTERVAL '1 YEAR' 
AS one_year", ); + + let sql = "SELECT INTERVAL '1 DAY' > INTERVAL '1 SECOND'"; + let select = dialects.verified_only_select(sql); + assert_eq!( + expr_from_projection(only(&select.projection)), + &Expr::BinaryOp { + left: Box::new(Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString(String::from( + "1 DAY" + )))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + })), + op: BinaryOperator::Gt, + right: Box::new(Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString(String::from( + "1 SECOND" + )))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + })) + } + ); +} + +#[test] +fn interval_disallow_interval_expr_gt() { + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + let expr = dialects.verified_expr("INTERVAL '1 second' > x"); + assert_eq!( + expr, + Expr::BinaryOp { + left: Box::new(Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString( + "1 second".to_string() + ))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + },)), + op: BinaryOperator::Gt, + right: Box::new(Expr::Identifier(Ident { + value: "x".to_string(), + quote_style: None, + })), + } + ) +} + +#[test] +fn interval_disallow_interval_expr_double_colon() { + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + let expr = dialects.verified_expr("INTERVAL '1 second'::TEXT"); + assert_eq!( + expr, + Expr::Cast { + kind: CastKind::DoubleColon, + expr: Box::new(Expr::Interval(Interval { + value: Box::new(Expr::Value(Value::SingleQuotedString( + "1 second".to_string() + ))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + })), + data_type: DataType::Text, + format: None, + } + ) } #[test] @@ -5038,7 +5209,8 @@ fn parse_interval_and_or_xor() { 
WHERE d3_date > d1_date + INTERVAL '5 days' \ AND d2_date > d1_date + INTERVAL '3 days'"; - let actual_ast = Parser::parse_sql(&GenericDialect {}, sql).unwrap(); + let dialects = all_dialects_except(|d| d.require_interval_qualifier()); + let actual_ast = dialects.parse_sql_statements(sql).unwrap(); let expected_ast = vec![Statement::Query(Box::new(Query { with: None, @@ -5140,19 +5312,19 @@ fn parse_interval_and_or_xor() { assert_eq!(actual_ast, expected_ast); - verified_stmt( + dialects.verified_stmt( "SELECT col FROM test \ WHERE d3_date > d1_date + INTERVAL '5 days' \ AND d2_date > d1_date + INTERVAL '3 days'", ); - verified_stmt( + dialects.verified_stmt( "SELECT col FROM test \ WHERE d3_date > d1_date + INTERVAL '5 days' \ OR d2_date > d1_date + INTERVAL '3 days'", ); - verified_stmt( + dialects.verified_stmt( "SELECT col FROM test \ WHERE d3_date > d1_date + INTERVAL '5 days' \ XOR d2_date > d1_date + INTERVAL '3 days'",