diff --git a/src/ast/mod.rs b/src/ast/mod.rs index bcebd5cb..20f504da 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -40,7 +40,7 @@ pub use self::query::{ Query, Select, SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, With, }; -pub use self::value::{DateTimeField, Value}; +pub use self::value::{DateTimeField, TrimWhereField, Value}; struct DisplaySeparated<'a, T> where @@ -231,7 +231,7 @@ pub enum Expr { Trim { expr: Box, // ([BOTH | LEADING | TRAILING], ) - trim_where: Option<(Box, Box)>, + trim_where: Option<(TrimWhereField, Box)>, }, /// `expr COLLATE collation` Collate { diff --git a/src/ast/value.rs b/src/ast/value.rs index 1cf111da..0742fbd0 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -157,3 +157,22 @@ impl<'a> fmt::Display for EscapeSingleQuoteString<'a> { pub fn escape_single_quote_string(s: &str) -> EscapeSingleQuoteString<'_> { EscapeSingleQuoteString(s) } + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum TrimWhereField { + Both, + Leading, + Trailing, +} + +impl fmt::Display for TrimWhereField { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use TrimWhereField::*; + f.write_str(match self { + Both => "BOTH", + Leading => "LEADING", + Trailing => "TRAILING", + }) + } +} diff --git a/src/parser.rs b/src/parser.rs index 8347aa83..c6e306ab 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -666,10 +666,10 @@ impl<'a> Parser<'a> { .iter() .any(|d| word.keyword == *d) { - let ident = self.parse_identifier()?; + let trim_where = self.parse_trim_where()?; let sub_expr = self.parse_expr()?; self.expect_keyword(Keyword::FROM)?; - where_expr = Some((ident, sub_expr)) + where_expr = Some((trim_where, Box::new(sub_expr))); } } let expr = self.parse_expr()?; @@ -677,10 +677,22 @@ impl<'a> Parser<'a> { Ok(Expr::Trim { expr: Box::new(expr), - trim_where: where_expr.map(|(ident, expr)| (Box::new(ident), Box::new(expr))), + trim_where: where_expr, }) } + pub fn parse_trim_where(&mut self) -> Result { + match self.next_token() { + Token::Word(w) => match w.keyword { + Keyword::BOTH => Ok(TrimWhereField::Both), + Keyword::LEADING => Ok(TrimWhereField::Leading), + Keyword::TRAILING => Ok(TrimWhereField::Trailing), + _ => self.expected("trim_where field", Token::Word(w))?, + }, + unexpected => self.expected("trim_where field", unexpected), + } + } + /// Parse a SQL LISTAGG expression, e.g. `LISTAGG(...) WITHIN GROUP (ORDER BY ...)`. pub fn parse_listagg_expr(&mut self) -> Result { self.expect_token(&Token::LParen)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 783b90a5..dbd28d12 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2765,6 +2765,11 @@ fn parse_trim() { ); one_statement_parses_to("SELECT TRIM(' foo ')", "SELECT TRIM(' foo ')"); + + assert_eq!( + ParserError::ParserError("Expected ), found: 'xyz'".to_owned()), + parse_sql_statements("SELECT TRIM(FOO 'xyz' FROM 'xyzfooxyz')").unwrap_err() + ); } #[test]