diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index 5f38c93c..0d358277 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -11,7 +11,7 @@ impl Dialect for GenericSqlDialect { STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, - BOOLEAN, DATE, TIME, TIMESTAMP, + BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, ]; } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index f5afd3f3..e8682819 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -14,6 +14,7 @@ impl Dialect for PostgreSqlDialect { DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, + CASE, WHEN, THEN, ELSE, END, ]; } diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index a59ac9c7..6b28f245 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -62,6 +62,13 @@ pub enum ASTNode { SQLValue(Value), /// Scalar function call e.g. `LEFT(foo, 5)` SQLFunction { id: String, args: Vec }, + /// CASE [] WHEN THEN ... [ELSE ] END + SQLCase { + // TODO: support optional operand for "simple case" + conditions: Vec, + results: Vec, + else_result: Option>, + }, /// SELECT SQLSelect { /// projection expressions @@ -160,6 +167,19 @@ impl ToString for ASTNode { .collect::>() .join(", ") ), + ASTNode::SQLCase { conditions, results, else_result } => { + let mut s = format!( + "CASE {}", + conditions.iter().zip(results) + .map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string())) + .collect::>() + .join(" ") + ); + if let Some(else_result) = else_result { + s += &format!(" ELSE {}", else_result.to_string()) + } + s + " END" + }, ASTNode::SQLSelect { projection, relation, diff --git a/src/sqlparser.rs b/src/sqlparser.rs index d97dc998..d694c11c 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -109,6 +109,9 @@ impl Parser { "NULL" => { self.prev_token(); self.parse_sql_value() + }, + "CASE" => { + self.parse_case_expression() } _ => return parser_err!(format!("No prefix parser for keyword {}", k)), }, @@ -196,6 +199,40 @@ impl Parser { } } + pub fn parse_case_expression(&mut self) -> Result { + if self.parse_keywords(vec!["WHEN"]) { + let mut conditions = vec![]; + let mut results = vec![]; + let mut else_result = None; + loop { + conditions.push(self.parse_expr(0)?); + self.consume_token(&Token::Keyword("THEN".to_string()))?; + results.push(self.parse_expr(0)?); + if self.parse_keywords(vec!["ELSE"]) { + else_result = Some(Box::new(self.parse_expr(0)?)); + if self.parse_keywords(vec!["END"]) { + break + } else { + return parser_err!("Expecting END after a CASE..ELSE"); + } + } + if self.parse_keywords(vec!["END"]) { + break + } + self.consume_token(&Token::Keyword("WHEN".to_string()))?; + } + Ok(ASTNode::SQLCase { + conditions, + results, + else_result + }) + } else { + // TODO: implement "simple" case + // https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-case + parser_err!("Simple case not implemented") + } + } + /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expression(&mut self) -> Result { self.consume_token(&Token::LParen)?; diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 50cd2af9..910678e2 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -396,6 +396,37 @@ fn parse_parens() { , ast); } +#[test] +fn parse_case_expression() { + let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; + let ast = parse_sql(&sql); + assert_eq!(sql, ast.to_string()); + use self::ASTNode::*; + use self::SQLOperator::*; + match ast { + ASTNode::SQLSelect { projection, .. } => { + assert_eq!(1, projection.len()); + assert_eq!( + SQLCase { + conditions: vec![ + SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), + SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), + op: Eq, right: Box::new(SQLValue(Value::Long(0))) }, + SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), + op: GtEq, right: Box::new(SQLValue(Value::Long(0))) } + ], + results: vec![SQLValue(Value::SingleQuotedString("null".to_string())), + SQLValue(Value::SingleQuotedString("=0".to_string())), + SQLValue(Value::SingleQuotedString(">=0".to_string()))], + else_result: Some(Box::new(SQLValue(Value::SingleQuotedString("<0".to_string())))) + }, + projection[0] + ); + } + _ => assert!(false), + } +} + fn parse_sql(sql: &str) -> ASTNode { let dialect = GenericSqlDialect {}; let mut tokenizer = Tokenizer::new(&dialect,&sql, );