fix: mysql backslash escaping (#373)

* fix: mysql backslash escaping

* fixes
This commit is contained in:
Alex Vasilev 2021-12-23 16:50:33 +03:00 committed by GitHub
parent 60ad78dafc
commit ea0eb1be61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 6 deletions

1
.gitignore vendored
View file

@ -12,3 +12,4 @@ Cargo.lock
# IDEs # IDEs
.idea .idea
.vscode

View file

@ -11,7 +11,6 @@
// limitations under the License. // limitations under the License.
//! SQL Abstract Syntax Tree (AST) types //! SQL Abstract Syntax Tree (AST) types
mod data_type; mod data_type;
mod ddl; mod ddl;
mod operator; mod operator;
@ -1874,6 +1873,7 @@ pub enum HiveRowFormat {
DELIMITED, DELIMITED,
} }
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]

View file

@ -35,7 +35,7 @@ pub use self::sqlite::SQLiteDialect;
pub use crate::keywords; pub use crate::keywords;
/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` iff `parser.dialect` is one of the `Dialect`s specified. /// to `true` if `parser.dialect` is one of the `Dialect`s specified.
macro_rules! dialect_of { macro_rules! dialect_of {
( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => { ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
($($parsed_dialect.dialect.is::<$dialect_type>())||+) ($($parsed_dialect.dialect.is::<$dialect_type>())||+)

View file

@ -31,8 +31,8 @@ use core::str::Chars;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::dialect::Dialect;
use crate::dialect::SnowflakeDialect; use crate::dialect::SnowflakeDialect;
use crate::dialect::{Dialect, MySqlDialect};
use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX}; use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX};
/// SQL Token enumeration /// SQL Token enumeration
@ -411,6 +411,7 @@ impl<'a> Tokenizer<'a> {
// string // string
'\'' => { '\'' => {
let s = self.tokenize_single_quoted_string(chars)?; let s = self.tokenize_single_quoted_string(chars)?;
Ok(Some(Token::SingleQuotedString(s))) Ok(Some(Token::SingleQuotedString(s)))
} }
// delimited (quoted) identifier // delimited (quoted) identifier
@ -636,18 +637,31 @@ impl<'a> Tokenizer<'a> {
) -> Result<String, TokenizerError> { ) -> Result<String, TokenizerError> {
let mut s = String::new(); let mut s = String::new();
chars.next(); // consume the opening quote chars.next(); // consume the opening quote
// slash escaping is specific to MySQL dialect
let mut is_escaped = false;
while let Some(&ch) = chars.peek() { while let Some(&ch) = chars.peek() {
match ch { match ch {
'\'' => { '\'' => {
chars.next(); // consume chars.next(); // consume
let escaped_quote = chars.peek().map(|c| *c == '\'').unwrap_or(false); if is_escaped {
if escaped_quote { s.push(ch);
s.push('\''); is_escaped = false;
} else if chars.peek().map(|c| *c == '\'').unwrap_or(false) {
s.push(ch);
chars.next(); chars.next();
} else { } else {
return Ok(s); return Ok(s);
} }
} }
'\\' => {
if dialect_of!(self is MySqlDialect) {
is_escaped = !is_escaped;
} else {
s.push(ch);
}
chars.next();
}
_ => { _ => {
chars.next(); // consume chars.next(); // consume
s.push(ch); s.push(ch);

View file

@ -19,6 +19,8 @@ mod test_utils;
use test_utils::*; use test_utils::*;
use sqlparser::ast::Expr;
use sqlparser::ast::Value;
use sqlparser::ast::*; use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, MySqlDialect}; use sqlparser::dialect::{GenericDialect, MySqlDialect};
use sqlparser::tokenizer::Token; use sqlparser::tokenizer::Token;
@ -176,6 +178,50 @@ fn parse_quote_identifiers() {
} }
} }
#[test]
fn parse_unterminated_escape() {
let sql = r#"SELECT 'I\'m not fine\'"#;
let result = std::panic::catch_unwind(|| mysql().one_statement_parses_to(sql, ""));
assert!(result.is_err());
let sql = r#"SELECT 'I\\'m not fine'"#;
let result = std::panic::catch_unwind(|| mysql().one_statement_parses_to(sql, ""));
assert!(result.is_err());
}
#[test]
fn parse_escaped_string() {
let sql = r#"SELECT 'I\'m fine'"#;
let stmt = mysql().one_statement_parses_to(sql, "");
match stmt {
Statement::Query(query) => match query.body {
SetExpr::Select(value) => {
let expr = expr_from_projection(only(&value.projection));
assert_eq!(
*expr,
Expr::Value(Value::SingleQuotedString("I'm fine".to_string()))
);
}
_ => unreachable!(),
},
_ => unreachable!(),
};
let sql = r#"SELECT 'I''m fine'"#;
let projection = mysql().verified_only_select(sql).projection;
let item = projection.get(0).unwrap();
match &item {
SelectItem::UnnamedExpr(Expr::Value(value)) => {
assert_eq!(*value, Value::SingleQuotedString("I'm fine".to_string()));
}
_ => unreachable!(),
}
}
#[test] #[test]
fn parse_create_table_with_minimum_display_width() { fn parse_create_table_with_minimum_display_width() {
let sql = "CREATE TABLE foo (bar_tinyint TINYINT(3), bar_smallint SMALLINT(5), bar_int INT(11), bar_bigint BIGINT(20))"; let sql = "CREATE TABLE foo (bar_tinyint TINYINT(3), bar_smallint SMALLINT(5), bar_int INT(11), bar_bigint BIGINT(20))";