feat: Support expression in SET statement (#574)

Co-authored-by: Alex Vasilev <vaspiring@gmail.com>
This commit is contained in:
Dmitry Patsura 2022-08-18 20:29:55 +03:00 committed by GitHub
parent eb7f1b005e
commit 6d8aacd85b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 63 additions and 52 deletions

View file

@ -545,8 +545,10 @@ impl fmt::Display for Expr {
Expr::UnaryOp { op, expr } => { Expr::UnaryOp { op, expr } => {
if op == &UnaryOperator::PGPostfixFactorial { if op == &UnaryOperator::PGPostfixFactorial {
write!(f, "{}{}", expr, op) write!(f, "{}{}", expr, op)
} else { } else if op == &UnaryOperator::Not {
write!(f, "{} {}", op, expr) write!(f, "{} {}", op, expr)
} else {
write!(f, "{}{}", op, expr)
} }
} }
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type), Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
@ -1100,7 +1102,7 @@ pub enum Statement {
local: bool, local: bool,
hivevar: bool, hivevar: bool,
variable: ObjectName, variable: ObjectName,
value: Vec<SetVariableValue>, value: Vec<Expr>,
}, },
/// SET NAMES 'charset_name' [COLLATE 'collation_name'] /// SET NAMES 'charset_name' [COLLATE 'collation_name']
/// ///
@ -2745,23 +2747,6 @@ impl fmt::Display for ShowStatementFilter {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SetVariableValue {
Ident(Ident),
Literal(Value),
}
impl fmt::Display for SetVariableValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use SetVariableValue::*;
match self {
Ident(ident) => write!(f, "{}", ident),
Literal(literal) => write!(f, "{}", literal),
}
}
}
/// Sqlite specific syntax /// Sqlite specific syntax
/// ///
/// https://sqlite.org/lang_conflict.html /// https://sqlite.org/lang_conflict.html

View file

@ -24,6 +24,7 @@ impl Dialect for MySqlDialect {
|| ('A'..='Z').contains(&ch) || ('A'..='Z').contains(&ch)
|| ch == '_' || ch == '_'
|| ch == '$' || ch == '$'
|| ch == '@'
|| ('\u{0080}'..='\u{ffff}').contains(&ch) || ('\u{0080}'..='\u{ffff}').contains(&ch)
} }

View file

@ -3751,22 +3751,12 @@ impl<'a> Parser<'a> {
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { } else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let mut values = vec![]; let mut values = vec![];
loop { loop {
let token = self.peek_token(); let value = if let Ok(expr) = self.parse_expr() {
let value = match (self.parse_value(), token) { expr
(Ok(value), _) => SetVariableValue::Literal(value), } else {
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()), self.expected("variable value", self.peek_token())?
(Err(_), Token::Minus) => {
let next_token = self.next_token();
match next_token {
Token::Word(ident) => SetVariableValue::Ident(Ident {
quote_style: ident.quote_style,
value: format!("-{}", ident.value),
}),
_ => self.expected("word", next_token)?,
}
}
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
}; };
values.push(value); values.push(value);
if self.consume_token(&Token::Comma) { if self.consume_token(&Token::Comma) {
continue; continue;

View file

@ -580,7 +580,7 @@ fn parse_select_count_wildcard() {
#[test] #[test]
fn parse_select_count_distinct() { fn parse_select_count_distinct() {
let sql = "SELECT COUNT(DISTINCT + x) FROM customer"; let sql = "SELECT COUNT(DISTINCT +x) FROM customer";
let select = verified_only_select(sql); let select = verified_only_select(sql);
assert_eq!( assert_eq!(
&Expr::Function(Function { &Expr::Function(Function {
@ -597,8 +597,8 @@ fn parse_select_count_distinct() {
); );
one_statement_parses_to( one_statement_parses_to(
"SELECT COUNT(ALL + x) FROM customer", "SELECT COUNT(ALL +x) FROM customer",
"SELECT COUNT(+ x) FROM customer", "SELECT COUNT(+x) FROM customer",
); );
let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer"; let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
@ -754,7 +754,7 @@ fn parse_compound_expr_2() {
#[test] #[test]
fn parse_unary_math() { fn parse_unary_math() {
use self::Expr::*; use self::Expr::*;
let sql = "- a + - b"; let sql = "-a + -b";
assert_eq!( assert_eq!(
BinaryOp { BinaryOp {
left: Box::new(UnaryOp { left: Box::new(UnaryOp {

View file

@ -15,7 +15,7 @@
//! Test SQL syntax specific to Hive. The parser based on the generic dialect //! Test SQL syntax specific to Hive. The parser based on the generic dialect
//! is also tested (on the inputs it can handle). //! is also tested (on the inputs it can handle).
use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement}; use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator};
use sqlparser::dialect::{GenericDialect, HiveDialect}; use sqlparser::dialect::{GenericDialect, HiveDialect};
use sqlparser::parser::ParserError; use sqlparser::parser::ParserError;
use sqlparser::test_utils::*; use sqlparser::test_utils::*;
@ -220,14 +220,17 @@ fn set_statement_with_minus() {
Ident::new("java"), Ident::new("java"),
Ident::new("opts") Ident::new("opts")
]), ]),
value: vec![SetVariableValue::Ident("-Xmx4g".into())], value: vec![Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))
}],
} }
); );
assert_eq!( assert_eq!(
hive().parse_sql_statements("SET hive.tez.java.opts = -"), hive().parse_sql_statements("SET hive.tez.java.opts = -"),
Err(ParserError::ParserError( Err(ParserError::ParserError(
"Expected word, found: EOF".to_string() "Expected variable value, found: EOF".to_string()
)) ))
) )
} }

View file

@ -251,6 +251,26 @@ fn parse_use() {
); );
} }
#[test]
fn parse_set_variables() {
mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')");
assert_eq!(
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
Statement::SetVariable {
local: true,
hivevar: false,
variable: ObjectName(vec!["autocommit".into()]),
value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"1".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(1),
false
))],
}
);
}
#[test] #[test]
fn parse_create_table_auto_increment() { fn parse_create_table_auto_increment() {
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)"; let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";

View file

@ -18,7 +18,6 @@
mod test_utils; mod test_utils;
use test_utils::*; use test_utils::*;
use sqlparser::ast::Value::Boolean;
use sqlparser::ast::*; use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect}; use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
use sqlparser::parser::ParserError; use sqlparser::parser::ParserError;
@ -782,7 +781,10 @@ fn parse_set() {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())], value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
} }
); );
@ -793,9 +795,7 @@ fn parse_set() {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(Value::SingleQuotedString( value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
"b".into()
))],
} }
); );
@ -806,7 +806,13 @@ fn parse_set() {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(number("0"))], value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"0".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(0),
false,
))],
} }
); );
@ -817,7 +823,10 @@ fn parse_set() {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("DEFAULT".into())], value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(),
quote_style: None
})],
} }
); );
@ -828,7 +837,7 @@ fn parse_set() {
local: true, local: true,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a")]), variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())], value: vec![Expr::Identifier("b".into())],
} }
); );
@ -839,7 +848,10 @@ fn parse_set() {
local: false, local: false,
hivevar: false, hivevar: false,
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
value: vec![SetVariableValue::Ident("b".into())], value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
} }
); );
@ -859,7 +871,7 @@ fn parse_set() {
Ident::new("reducer"), Ident::new("reducer"),
Ident::new("parallelism") Ident::new("parallelism")
]), ]),
value: vec![SetVariableValue::Literal(Boolean(false))], value: vec![Expr::Value(Value::Boolean(false))],
} }
); );
@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() {
]; ];
for (str_op, op) in pg_unary_ops { for (str_op, op) in pg_unary_ops {
let select = pg().verified_only_select(&format!("SELECT {} a", &str_op)); let select = pg().verified_only_select(&format!("SELECT {}a", &str_op));
assert_eq!( assert_eq!(
SelectItem::UnnamedExpr(Expr::UnaryOp { SelectItem::UnnamedExpr(Expr::UnaryOp {
op: op.clone(), op: op.clone(),