mirror of
https://github.com/apache/datafusion-sqlparser-rs.git
synced 2025-08-04 14:28:22 +00:00
feat: Support expression in SET statement (#574)
Co-authored-by: Alex Vasilev <vaspiring@gmail.com>
This commit is contained in:
parent
eb7f1b005e
commit
6d8aacd85b
7 changed files with 63 additions and 52 deletions
|
@ -545,8 +545,10 @@ impl fmt::Display for Expr {
|
|||
Expr::UnaryOp { op, expr } => {
|
||||
if op == &UnaryOperator::PGPostfixFactorial {
|
||||
write!(f, "{}{}", expr, op)
|
||||
} else {
|
||||
} else if op == &UnaryOperator::Not {
|
||||
write!(f, "{} {}", op, expr)
|
||||
} else {
|
||||
write!(f, "{}{}", op, expr)
|
||||
}
|
||||
}
|
||||
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
|
||||
|
@ -1100,7 +1102,7 @@ pub enum Statement {
|
|||
local: bool,
|
||||
hivevar: bool,
|
||||
variable: ObjectName,
|
||||
value: Vec<SetVariableValue>,
|
||||
value: Vec<Expr>,
|
||||
},
|
||||
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
|
||||
///
|
||||
|
@ -2745,23 +2747,6 @@ impl fmt::Display for ShowStatementFilter {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum SetVariableValue {
|
||||
Ident(Ident),
|
||||
Literal(Value),
|
||||
}
|
||||
|
||||
impl fmt::Display for SetVariableValue {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
use SetVariableValue::*;
|
||||
match self {
|
||||
Ident(ident) => write!(f, "{}", ident),
|
||||
Literal(literal) => write!(f, "{}", literal),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sqlite specific syntax
|
||||
///
|
||||
/// https://sqlite.org/lang_conflict.html
|
||||
|
|
|
@ -24,6 +24,7 @@ impl Dialect for MySqlDialect {
|
|||
|| ('A'..='Z').contains(&ch)
|
||||
|| ch == '_'
|
||||
|| ch == '$'
|
||||
|| ch == '@'
|
||||
|| ('\u{0080}'..='\u{ffff}').contains(&ch)
|
||||
}
|
||||
|
||||
|
|
|
@ -3751,22 +3751,12 @@ impl<'a> Parser<'a> {
|
|||
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
|
||||
let mut values = vec![];
|
||||
loop {
|
||||
let token = self.peek_token();
|
||||
let value = match (self.parse_value(), token) {
|
||||
(Ok(value), _) => SetVariableValue::Literal(value),
|
||||
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
|
||||
(Err(_), Token::Minus) => {
|
||||
let next_token = self.next_token();
|
||||
match next_token {
|
||||
Token::Word(ident) => SetVariableValue::Ident(Ident {
|
||||
quote_style: ident.quote_style,
|
||||
value: format!("-{}", ident.value),
|
||||
}),
|
||||
_ => self.expected("word", next_token)?,
|
||||
}
|
||||
}
|
||||
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
|
||||
let value = if let Ok(expr) = self.parse_expr() {
|
||||
expr
|
||||
} else {
|
||||
self.expected("variable value", self.peek_token())?
|
||||
};
|
||||
|
||||
values.push(value);
|
||||
if self.consume_token(&Token::Comma) {
|
||||
continue;
|
||||
|
|
|
@ -580,7 +580,7 @@ fn parse_select_count_wildcard() {
|
|||
|
||||
#[test]
|
||||
fn parse_select_count_distinct() {
|
||||
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
|
||||
let sql = "SELECT COUNT(DISTINCT +x) FROM customer";
|
||||
let select = verified_only_select(sql);
|
||||
assert_eq!(
|
||||
&Expr::Function(Function {
|
||||
|
@ -597,8 +597,8 @@ fn parse_select_count_distinct() {
|
|||
);
|
||||
|
||||
one_statement_parses_to(
|
||||
"SELECT COUNT(ALL + x) FROM customer",
|
||||
"SELECT COUNT(+ x) FROM customer",
|
||||
"SELECT COUNT(ALL +x) FROM customer",
|
||||
"SELECT COUNT(+x) FROM customer",
|
||||
);
|
||||
|
||||
let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
|
||||
|
@ -754,7 +754,7 @@ fn parse_compound_expr_2() {
|
|||
#[test]
|
||||
fn parse_unary_math() {
|
||||
use self::Expr::*;
|
||||
let sql = "- a + - b";
|
||||
let sql = "-a + -b";
|
||||
assert_eq!(
|
||||
BinaryOp {
|
||||
left: Box::new(UnaryOp {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
//! Test SQL syntax specific to Hive. The parser based on the generic dialect
|
||||
//! is also tested (on the inputs it can handle).
|
||||
|
||||
use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement};
|
||||
use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator};
|
||||
use sqlparser::dialect::{GenericDialect, HiveDialect};
|
||||
use sqlparser::parser::ParserError;
|
||||
use sqlparser::test_utils::*;
|
||||
|
@ -220,14 +220,17 @@ fn set_statement_with_minus() {
|
|||
Ident::new("java"),
|
||||
Ident::new("opts")
|
||||
]),
|
||||
value: vec![SetVariableValue::Ident("-Xmx4g".into())],
|
||||
value: vec![Expr::UnaryOp {
|
||||
op: UnaryOperator::Minus,
|
||||
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))
|
||||
}],
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
hive().parse_sql_statements("SET hive.tez.java.opts = -"),
|
||||
Err(ParserError::ParserError(
|
||||
"Expected word, found: EOF".to_string()
|
||||
"Expected variable value, found: EOF".to_string()
|
||||
))
|
||||
)
|
||||
}
|
||||
|
|
|
@ -251,6 +251,26 @@ fn parse_use() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_set_variables() {
|
||||
mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')");
|
||||
assert_eq!(
|
||||
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
|
||||
Statement::SetVariable {
|
||||
local: true,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec!["autocommit".into()]),
|
||||
value: vec![Expr::Value(Value::Number(
|
||||
#[cfg(not(feature = "bigdecimal"))]
|
||||
"1".to_string(),
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
bigdecimal::BigDecimal::from(1),
|
||||
false
|
||||
))],
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_create_table_auto_increment() {
|
||||
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
mod test_utils;
|
||||
use test_utils::*;
|
||||
|
||||
use sqlparser::ast::Value::Boolean;
|
||||
use sqlparser::ast::*;
|
||||
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
|
||||
use sqlparser::parser::ParserError;
|
||||
|
@ -782,7 +781,10 @@ fn parse_set() {
|
|||
local: false,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a")]),
|
||||
value: vec![SetVariableValue::Ident("b".into())],
|
||||
value: vec![Expr::Identifier(Ident {
|
||||
value: "b".into(),
|
||||
quote_style: None
|
||||
})],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -793,9 +795,7 @@ fn parse_set() {
|
|||
local: false,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a")]),
|
||||
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
|
||||
"b".into()
|
||||
))],
|
||||
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -806,7 +806,13 @@ fn parse_set() {
|
|||
local: false,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a")]),
|
||||
value: vec![SetVariableValue::Literal(number("0"))],
|
||||
value: vec![Expr::Value(Value::Number(
|
||||
#[cfg(not(feature = "bigdecimal"))]
|
||||
"0".to_string(),
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
bigdecimal::BigDecimal::from(0),
|
||||
false,
|
||||
))],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -817,7 +823,10 @@ fn parse_set() {
|
|||
local: false,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a")]),
|
||||
value: vec![SetVariableValue::Ident("DEFAULT".into())],
|
||||
value: vec![Expr::Identifier(Ident {
|
||||
value: "DEFAULT".into(),
|
||||
quote_style: None
|
||||
})],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -828,7 +837,7 @@ fn parse_set() {
|
|||
local: true,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a")]),
|
||||
value: vec![SetVariableValue::Ident("b".into())],
|
||||
value: vec![Expr::Identifier("b".into())],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -839,7 +848,10 @@ fn parse_set() {
|
|||
local: false,
|
||||
hivevar: false,
|
||||
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
|
||||
value: vec![SetVariableValue::Ident("b".into())],
|
||||
value: vec![Expr::Identifier(Ident {
|
||||
value: "b".into(),
|
||||
quote_style: None
|
||||
})],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -859,7 +871,7 @@ fn parse_set() {
|
|||
Ident::new("reducer"),
|
||||
Ident::new("parallelism")
|
||||
]),
|
||||
value: vec![SetVariableValue::Literal(Boolean(false))],
|
||||
value: vec![Expr::Value(Value::Boolean(false))],
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() {
|
|||
];
|
||||
|
||||
for (str_op, op) in pg_unary_ops {
|
||||
let select = pg().verified_only_select(&format!("SELECT {} a", &str_op));
|
||||
let select = pg().verified_only_select(&format!("SELECT {}a", &str_op));
|
||||
assert_eq!(
|
||||
SelectItem::UnnamedExpr(Expr::UnaryOp {
|
||||
op: op.clone(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue