From ea0eb1be617635b29d49d712be22a9a568d02366 Mon Sep 17 00:00:00 2001 From: Alex Vasilev Date: Thu, 23 Dec 2021 16:50:33 +0300 Subject: [PATCH] fix: mysql backslash escaping (#373) * fix: mysql backslash escaping * fixes --- .gitignore | 1 + src/ast/mod.rs | 2 +- src/dialect/mod.rs | 2 +- src/tokenizer.rs | 22 +++++++++++++++---- tests/sqlparser_mysql.rs | 46 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index dcc3cbd9..6dfd81a6 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ Cargo.lock # IDEs .idea +.vscode diff --git a/src/ast/mod.rs b/src/ast/mod.rs index cf439ef4..477a2289 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -11,7 +11,6 @@ // limitations under the License. //! SQL Abstract Syntax Tree (AST) types - mod data_type; mod ddl; mod operator; @@ -1874,6 +1873,7 @@ pub enum HiveRowFormat { DELIMITED, } +#[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[allow(clippy::large_enum_variant)] diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 5f1fcb0b..008b099d 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -35,7 +35,7 @@ pub use self::sqlite::SQLiteDialect; pub use crate::keywords; /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates -/// to `true` iff `parser.dialect` is one of the `Dialect`s specified. +/// to `true` if `parser.dialect` is one of the `Dialect`s specified. macro_rules! dialect_of { ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => { ($($parsed_dialect.dialect.is::<$dialect_type>())||+) diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 7fc44194..296bcc64 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -31,8 +31,8 @@ use core::str::Chars; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::dialect::Dialect; use crate::dialect::SnowflakeDialect; +use crate::dialect::{Dialect, MySqlDialect}; use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX}; /// SQL Token enumeration @@ -411,6 +411,7 @@ impl<'a> Tokenizer<'a> { // string '\'' => { let s = self.tokenize_single_quoted_string(chars)?; + Ok(Some(Token::SingleQuotedString(s))) } // delimited (quoted) identifier @@ -636,18 +637,31 @@ impl<'a> Tokenizer<'a> { ) -> Result { let mut s = String::new(); chars.next(); // consume the opening quote + + // slash escaping is specific to MySQL dialect + let mut is_escaped = false; while let Some(&ch) = chars.peek() { match ch { '\'' => { chars.next(); // consume - let escaped_quote = chars.peek().map(|c| *c == '\'').unwrap_or(false); - if escaped_quote { - s.push('\''); + if is_escaped { + s.push(ch); + is_escaped = false; + } else if chars.peek().map(|c| *c == '\'').unwrap_or(false) { + s.push(ch); chars.next(); } else { return Ok(s); } } + '\\' => { + if dialect_of!(self is MySqlDialect) { + is_escaped = !is_escaped; + } else { + s.push(ch); + } + chars.next(); + } _ => { chars.next(); // consume s.push(ch); diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 8e32ab49..f67d05c3 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -19,6 +19,8 @@ mod test_utils; use test_utils::*; +use sqlparser::ast::Expr; +use sqlparser::ast::Value; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, MySqlDialect}; use sqlparser::tokenizer::Token; @@ -176,6 +178,50 @@ fn parse_quote_identifiers() { } } +#[test] +fn parse_unterminated_escape() { + let sql = r#"SELECT 'I\'m not fine\'"#; + let result = std::panic::catch_unwind(|| mysql().one_statement_parses_to(sql, "")); + assert!(result.is_err()); + + let sql = r#"SELECT 'I\\'m not fine'"#; + let result = std::panic::catch_unwind(|| mysql().one_statement_parses_to(sql, "")); + assert!(result.is_err()); +} + +#[test] +fn parse_escaped_string() { + let sql = r#"SELECT 'I\'m fine'"#; + + let stmt = mysql().one_statement_parses_to(sql, ""); + + match stmt { + Statement::Query(query) => match query.body { + SetExpr::Select(value) => { + let expr = expr_from_projection(only(&value.projection)); + assert_eq!( + *expr, + Expr::Value(Value::SingleQuotedString("I'm fine".to_string())) + ); + } + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let sql = r#"SELECT 'I''m fine'"#; + + let projection = mysql().verified_only_select(sql).projection; + let item = projection.get(0).unwrap(); + + match &item { + SelectItem::UnnamedExpr(Expr::Value(value)) => { + assert_eq!(*value, Value::SingleQuotedString("I'm fine".to_string())); + } + _ => unreachable!(), + } +} + #[test] fn parse_create_table_with_minimum_display_width() { let sql = "CREATE TABLE foo (bar_tinyint TINYINT(3), bar_smallint SMALLINT(5), bar_int INT(11), bar_bigint BIGINT(20))";