Split operators by arity

It is useful downstream to have two separate enums, one for unary
operators and one for binary operators, so that the compiler can check
exhaustiveness. Otherwise downstream consumers need to manually encode
which operators are unary and which operators are binary when matching
on an Operator enum.
This commit is contained in:
Nikhil Benesch 2019-06-05 12:31:13 -04:00
parent 9e33cea9b8
commit ae25dce246
No known key found for this signature in database
GPG key ID: F7386C5DEADABA7F
4 changed files with 85 additions and 71 deletions

View file

@ -27,11 +27,10 @@ pub use self::query::{
Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, Cte, Fetch, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect,
SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins, SQLSelectItem, SQLSetExpr, SQLSetOperator, SQLValues, TableAlias, TableFactor, TableWithJoins,
}; };
pub use self::sql_operator::{SQLBinaryOperator, SQLUnaryOperator};
pub use self::sqltype::SQLType; pub use self::sqltype::SQLType;
pub use self::value::{SQLDateTimeField, Value}; pub use self::value::{SQLDateTimeField, Value};
pub use self::sql_operator::SQLOperator;
/// Like `vec.join(", ")`, but for any types implementing ToString. /// Like `vec.join(", ")`, but for any types implementing ToString.
fn comma_separated_string<I>(iter: I) -> String fn comma_separated_string<I>(iter: I) -> String
where where
@ -92,12 +91,12 @@ pub enum ASTNode {
/// Binary operation e.g. `1 + 1` or `foo > bar` /// Binary operation e.g. `1 + 1` or `foo > bar`
SQLBinaryOp { SQLBinaryOp {
left: Box<ASTNode>, left: Box<ASTNode>,
op: SQLOperator, op: SQLBinaryOperator,
right: Box<ASTNode>, right: Box<ASTNode>,
}, },
/// Unary operation e.g. `NOT foo` /// Unary operation e.g. `NOT foo`
SQLUnaryOp { SQLUnaryOp {
op: SQLOperator, op: SQLUnaryOperator,
expr: Box<ASTNode>, expr: Box<ASTNode>,
}, },
/// CAST an expression to a different data type e.g. `CAST(foo AS VARCHAR(123))` /// CAST an expression to a different data type e.g. `CAST(foo AS VARCHAR(123))`

View file

@ -10,9 +10,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
/// SQL Operator /// Unary operators
#[derive(Debug, Clone, PartialEq, Hash)] #[derive(Debug, Clone, PartialEq, Hash)]
pub enum SQLOperator { pub enum SQLUnaryOperator {
Plus,
Minus,
Not,
}
impl ToString for SQLUnaryOperator {
fn to_string(&self) -> String {
match self {
SQLUnaryOperator::Plus => "+".to_string(),
SQLUnaryOperator::Minus => "-".to_string(),
SQLUnaryOperator::Not => "NOT".to_string(),
}
}
}
/// Binary operators
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum SQLBinaryOperator {
Plus, Plus,
Minus, Minus,
Multiply, Multiply,
@ -26,30 +44,28 @@ pub enum SQLOperator {
NotEq, NotEq,
And, And,
Or, Or,
Not,
Like, Like,
NotLike, NotLike,
} }
impl ToString for SQLOperator { impl ToString for SQLBinaryOperator {
fn to_string(&self) -> String { fn to_string(&self) -> String {
match self { match self {
SQLOperator::Plus => "+".to_string(), SQLBinaryOperator::Plus => "+".to_string(),
SQLOperator::Minus => "-".to_string(), SQLBinaryOperator::Minus => "-".to_string(),
SQLOperator::Multiply => "*".to_string(), SQLBinaryOperator::Multiply => "*".to_string(),
SQLOperator::Divide => "/".to_string(), SQLBinaryOperator::Divide => "/".to_string(),
SQLOperator::Modulus => "%".to_string(), SQLBinaryOperator::Modulus => "%".to_string(),
SQLOperator::Gt => ">".to_string(), SQLBinaryOperator::Gt => ">".to_string(),
SQLOperator::Lt => "<".to_string(), SQLBinaryOperator::Lt => "<".to_string(),
SQLOperator::GtEq => ">=".to_string(), SQLBinaryOperator::GtEq => ">=".to_string(),
SQLOperator::LtEq => "<=".to_string(), SQLBinaryOperator::LtEq => "<=".to_string(),
SQLOperator::Eq => "=".to_string(), SQLBinaryOperator::Eq => "=".to_string(),
SQLOperator::NotEq => "<>".to_string(), SQLBinaryOperator::NotEq => "<>".to_string(),
SQLOperator::And => "AND".to_string(), SQLBinaryOperator::And => "AND".to_string(),
SQLOperator::Or => "OR".to_string(), SQLBinaryOperator::Or => "OR".to_string(),
SQLOperator::Not => "NOT".to_string(), SQLBinaryOperator::Like => "LIKE".to_string(),
SQLOperator::Like => "LIKE".to_string(), SQLBinaryOperator::NotLike => "NOT LIKE".to_string(),
SQLOperator::NotLike => "NOT LIKE".to_string(),
} }
} }
} }

View file

@ -184,7 +184,7 @@ impl Parser {
"EXTRACT" => self.parse_extract_expression(), "EXTRACT" => self.parse_extract_expression(),
"INTERVAL" => self.parse_literal_interval(), "INTERVAL" => self.parse_literal_interval(),
"NOT" => Ok(ASTNode::SQLUnaryOp { "NOT" => Ok(ASTNode::SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
expr: Box::new(self.parse_subexpr(Self::UNARY_NOT_PREC)?), expr: Box::new(self.parse_subexpr(Self::UNARY_NOT_PREC)?),
}), }),
"TIME" => Ok(ASTNode::SQLValue(Value::Time(self.parse_literal_string()?))), "TIME" => Ok(ASTNode::SQLValue(Value::Time(self.parse_literal_string()?))),
@ -225,9 +225,9 @@ impl Parser {
Token::Mult => Ok(ASTNode::SQLWildcard), Token::Mult => Ok(ASTNode::SQLWildcard),
tok @ Token::Minus | tok @ Token::Plus => { tok @ Token::Minus | tok @ Token::Plus => {
let op = if tok == Token::Plus { let op = if tok == Token::Plus {
SQLOperator::Plus SQLUnaryOperator::Plus
} else { } else {
SQLOperator::Minus SQLUnaryOperator::Minus
}; };
Ok(ASTNode::SQLUnaryOp { Ok(ASTNode::SQLUnaryOp {
op, op,
@ -513,24 +513,24 @@ impl Parser {
let tok = self.next_token().unwrap(); // safe as EOF's precedence is the lowest let tok = self.next_token().unwrap(); // safe as EOF's precedence is the lowest
let regular_binary_operator = match tok { let regular_binary_operator = match tok {
Token::Eq => Some(SQLOperator::Eq), Token::Eq => Some(SQLBinaryOperator::Eq),
Token::Neq => Some(SQLOperator::NotEq), Token::Neq => Some(SQLBinaryOperator::NotEq),
Token::Gt => Some(SQLOperator::Gt), Token::Gt => Some(SQLBinaryOperator::Gt),
Token::GtEq => Some(SQLOperator::GtEq), Token::GtEq => Some(SQLBinaryOperator::GtEq),
Token::Lt => Some(SQLOperator::Lt), Token::Lt => Some(SQLBinaryOperator::Lt),
Token::LtEq => Some(SQLOperator::LtEq), Token::LtEq => Some(SQLBinaryOperator::LtEq),
Token::Plus => Some(SQLOperator::Plus), Token::Plus => Some(SQLBinaryOperator::Plus),
Token::Minus => Some(SQLOperator::Minus), Token::Minus => Some(SQLBinaryOperator::Minus),
Token::Mult => Some(SQLOperator::Multiply), Token::Mult => Some(SQLBinaryOperator::Multiply),
Token::Mod => Some(SQLOperator::Modulus), Token::Mod => Some(SQLBinaryOperator::Modulus),
Token::Div => Some(SQLOperator::Divide), Token::Div => Some(SQLBinaryOperator::Divide),
Token::SQLWord(ref k) => match k.keyword.as_ref() { Token::SQLWord(ref k) => match k.keyword.as_ref() {
"AND" => Some(SQLOperator::And), "AND" => Some(SQLBinaryOperator::And),
"OR" => Some(SQLOperator::Or), "OR" => Some(SQLBinaryOperator::Or),
"LIKE" => Some(SQLOperator::Like), "LIKE" => Some(SQLBinaryOperator::Like),
"NOT" => { "NOT" => {
if self.parse_keyword("LIKE") { if self.parse_keyword("LIKE") {
Some(SQLOperator::NotLike) Some(SQLBinaryOperator::NotLike)
} else { } else {
None None
} }

View file

@ -169,7 +169,7 @@ fn parse_delete_statement() {
#[test] #[test]
fn parse_where_delete_statement() { fn parse_where_delete_statement() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "DELETE FROM foo WHERE name = 5"; let sql = "DELETE FROM foo WHERE name = 5";
match verified_stmt(sql) { match verified_stmt(sql) {
@ -288,7 +288,7 @@ fn parse_column_aliases() {
ref alias, ref alias,
} = only(&select.projection) } = only(&select.projection)
{ {
assert_eq!(&SQLOperator::Plus, op); assert_eq!(&SQLBinaryOperator::Plus, op);
assert_eq!(&ASTNode::SQLValue(Value::Long(1)), right.as_ref()); assert_eq!(&ASTNode::SQLValue(Value::Long(1)), right.as_ref());
assert_eq!("newname", alias); assert_eq!("newname", alias);
} else { } else {
@ -337,7 +337,7 @@ fn parse_select_count_distinct() {
&ASTNode::SQLFunction(SQLFunction { &ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec!["COUNT".to_string()]), name: SQLObjectName(vec!["COUNT".to_string()]),
args: vec![ASTNode::SQLUnaryOp { args: vec![ASTNode::SQLUnaryOp {
op: SQLOperator::Plus, op: SQLUnaryOperator::Plus,
expr: Box::new(ASTNode::SQLIdentifier("x".to_string())) expr: Box::new(ASTNode::SQLIdentifier("x".to_string()))
}], }],
over: None, over: None,
@ -404,7 +404,7 @@ fn parse_projection_nested_type() {
#[test] #[test]
fn parse_escaped_single_quote_string_predicate() { fn parse_escaped_single_quote_string_predicate() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "SELECT id, fname, lname FROM customer \ let sql = "SELECT id, fname, lname FROM customer \
WHERE salary <> 'Jim''s salary'"; WHERE salary <> 'Jim''s salary'";
let ast = verified_only_select(sql); let ast = verified_only_select(sql);
@ -423,7 +423,7 @@ fn parse_escaped_single_quote_string_predicate() {
#[test] #[test]
fn parse_compound_expr_1() { fn parse_compound_expr_1() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "a + b * c"; let sql = "a + b * c";
assert_eq!( assert_eq!(
SQLBinaryOp { SQLBinaryOp {
@ -442,7 +442,7 @@ fn parse_compound_expr_1() {
#[test] #[test]
fn parse_compound_expr_2() { fn parse_compound_expr_2() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "a * b + c"; let sql = "a * b + c";
assert_eq!( assert_eq!(
SQLBinaryOp { SQLBinaryOp {
@ -461,17 +461,16 @@ fn parse_compound_expr_2() {
#[test] #[test]
fn parse_unary_math() { fn parse_unary_math() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*;
let sql = "- a + - b"; let sql = "- a + - b";
assert_eq!( assert_eq!(
SQLBinaryOp { SQLBinaryOp {
left: Box::new(SQLUnaryOp { left: Box::new(SQLUnaryOp {
op: Minus, op: SQLUnaryOperator::Minus,
expr: Box::new(SQLIdentifier("a".to_string())), expr: Box::new(SQLIdentifier("a".to_string())),
}), }),
op: Plus, op: SQLBinaryOperator::Plus,
right: Box::new(SQLUnaryOp { right: Box::new(SQLUnaryOp {
op: Minus, op: SQLUnaryOperator::Minus,
expr: Box::new(SQLIdentifier("b".to_string())), expr: Box::new(SQLIdentifier("b".to_string())),
}), }),
}, },
@ -505,14 +504,14 @@ fn parse_not_precedence() {
// NOT has higher precedence than OR/AND, so the following must parse as (NOT true) OR true // NOT has higher precedence than OR/AND, so the following must parse as (NOT true) OR true
let sql = "NOT true OR true"; let sql = "NOT true OR true";
assert_matches!(verified_expr(sql), SQLBinaryOp { assert_matches!(verified_expr(sql), SQLBinaryOp {
op: SQLOperator::Or, op: SQLBinaryOperator::Or,
.. ..
}); });
// But NOT has lower precedence than comparison operators, so the following parses as NOT (a IS NULL) // But NOT has lower precedence than comparison operators, so the following parses as NOT (a IS NULL)
let sql = "NOT a IS NULL"; let sql = "NOT a IS NULL";
assert_matches!(verified_expr(sql), SQLUnaryOp { assert_matches!(verified_expr(sql), SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
.. ..
}); });
@ -521,7 +520,7 @@ fn parse_not_precedence() {
assert_eq!( assert_eq!(
verified_expr(sql), verified_expr(sql),
SQLUnaryOp { SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
expr: Box::new(SQLBetween { expr: Box::new(SQLBetween {
expr: Box::new(SQLValue(Value::Long(1))), expr: Box::new(SQLValue(Value::Long(1))),
low: Box::new(SQLValue(Value::Long(1))), low: Box::new(SQLValue(Value::Long(1))),
@ -536,10 +535,10 @@ fn parse_not_precedence() {
assert_eq!( assert_eq!(
verified_expr(sql), verified_expr(sql),
SQLUnaryOp { SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
expr: Box::new(SQLBinaryOp { expr: Box::new(SQLBinaryOp {
left: Box::new(SQLValue(Value::SingleQuotedString("a".into()))), left: Box::new(SQLValue(Value::SingleQuotedString("a".into()))),
op: SQLOperator::NotLike, op: SQLBinaryOperator::NotLike,
right: Box::new(SQLValue(Value::SingleQuotedString("b".into()))), right: Box::new(SQLValue(Value::SingleQuotedString("b".into()))),
}), }),
}, },
@ -550,7 +549,7 @@ fn parse_not_precedence() {
assert_eq!( assert_eq!(
verified_expr(sql), verified_expr(sql),
SQLUnaryOp { SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
expr: Box::new(SQLInList { expr: Box::new(SQLInList {
expr: Box::new(SQLIdentifier("a".into())), expr: Box::new(SQLIdentifier("a".into())),
list: vec![SQLValue(Value::SingleQuotedString("a".into()))], list: vec![SQLValue(Value::SingleQuotedString("a".into()))],
@ -572,9 +571,9 @@ fn parse_like() {
ASTNode::SQLBinaryOp { ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLIdentifier("name".to_string())), left: Box::new(ASTNode::SQLIdentifier("name".to_string())),
op: if negated { op: if negated {
SQLOperator::NotLike SQLBinaryOperator::NotLike
} else { } else {
SQLOperator::Like SQLBinaryOperator::Like
}, },
right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"%a".to_string() "%a".to_string()
@ -594,9 +593,9 @@ fn parse_like() {
ASTNode::SQLIsNull(Box::new(ASTNode::SQLBinaryOp { ASTNode::SQLIsNull(Box::new(ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLIdentifier("name".to_string())), left: Box::new(ASTNode::SQLIdentifier("name".to_string())),
op: if negated { op: if negated {
SQLOperator::NotLike SQLBinaryOperator::NotLike
} else { } else {
SQLOperator::Like SQLBinaryOperator::Like
}, },
right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"%a".to_string() "%a".to_string()
@ -672,7 +671,7 @@ fn parse_between() {
#[test] #[test]
fn parse_between_with_expr() { fn parse_between_with_expr() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "SELECT * FROM t WHERE 1 BETWEEN 1 + 2 AND 3 + 4 IS NULL"; let sql = "SELECT * FROM t WHERE 1 BETWEEN 1 + 2 AND 3 + 4 IS NULL";
let select = verified_only_select(sql); let select = verified_only_select(sql);
assert_eq!( assert_eq!(
@ -699,14 +698,14 @@ fn parse_between_with_expr() {
ASTNode::SQLBinaryOp { ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLBinaryOp { left: Box::new(ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLValue(Value::Long(1))), left: Box::new(ASTNode::SQLValue(Value::Long(1))),
op: SQLOperator::Eq, op: SQLBinaryOperator::Eq,
right: Box::new(ASTNode::SQLValue(Value::Long(1))), right: Box::new(ASTNode::SQLValue(Value::Long(1))),
}), }),
op: SQLOperator::And, op: SQLBinaryOperator::And,
right: Box::new(ASTNode::SQLBetween { right: Box::new(ASTNode::SQLBetween {
expr: Box::new(ASTNode::SQLBinaryOp { expr: Box::new(ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLValue(Value::Long(1))), left: Box::new(ASTNode::SQLValue(Value::Long(1))),
op: SQLOperator::Plus, op: SQLBinaryOperator::Plus,
right: Box::new(ASTNode::SQLIdentifier("x".to_string())), right: Box::new(ASTNode::SQLIdentifier("x".to_string())),
}), }),
low: Box::new(ASTNode::SQLValue(Value::Long(1))), low: Box::new(ASTNode::SQLValue(Value::Long(1))),
@ -1365,7 +1364,7 @@ fn parse_delimited_identifiers() {
#[test] #[test]
fn parse_parens() { fn parse_parens() {
use self::ASTNode::*; use self::ASTNode::*;
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let sql = "(a + b) - (c + d)"; let sql = "(a + b) - (c + d)";
assert_eq!( assert_eq!(
SQLBinaryOp { SQLBinaryOp {
@ -1389,7 +1388,7 @@ fn parse_parens() {
fn parse_searched_case_expression() { fn parse_searched_case_expression() {
let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo";
use self::ASTNode::{SQLBinaryOp, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; use self::ASTNode::{SQLBinaryOp, SQLCase, SQLIdentifier, SQLIsNull, SQLValue};
use self::SQLOperator::*; use self::SQLBinaryOperator::*;
let select = verified_only_select(sql); let select = verified_only_select(sql);
assert_eq!( assert_eq!(
&SQLCase { &SQLCase {
@ -1557,7 +1556,7 @@ fn parse_joins_on() {
}, },
join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryOp { join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryOp {
left: Box::new(ASTNode::SQLIdentifier("c1".into())), left: Box::new(ASTNode::SQLIdentifier("c1".into())),
op: SQLOperator::Eq, op: SQLBinaryOperator::Eq,
right: Box::new(ASTNode::SQLIdentifier("c2".into())), right: Box::new(ASTNode::SQLIdentifier("c2".into())),
})), })),
} }
@ -1921,7 +1920,7 @@ fn parse_scalar_subqueries() {
use self::ASTNode::*; use self::ASTNode::*;
let sql = "(SELECT 1) + (SELECT 2)"; let sql = "(SELECT 1) + (SELECT 2)";
assert_matches!(verified_expr(sql), SQLBinaryOp { assert_matches!(verified_expr(sql), SQLBinaryOp {
op: SQLOperator::Plus, .. op: SQLBinaryOperator::Plus, ..
//left: box SQLSubquery { .. }, //left: box SQLSubquery { .. },
//right: box SQLSubquery { .. }, //right: box SQLSubquery { .. },
}); });
@ -1941,7 +1940,7 @@ fn parse_exists_subquery() {
let select = verified_only_select(sql); let select = verified_only_select(sql);
assert_eq!( assert_eq!(
ASTNode::SQLUnaryOp { ASTNode::SQLUnaryOp {
op: SQLOperator::Not, op: SQLUnaryOperator::Not,
expr: Box::new(ASTNode::SQLExists(Box::new(expected_inner))), expr: Box::new(ASTNode::SQLExists(Box::new(expected_inner))),
}, },
select.selection.unwrap(), select.selection.unwrap(),