Support for joins

This commit is contained in:
Fredrik Roos 2018-11-17 15:40:24 +01:00
parent face97226b
commit 7624095738
4 changed files with 409 additions and 64 deletions

View file

@ -11,7 +11,8 @@ impl Dialect for GenericSqlDialect {
STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types
CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT,
REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC,
BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL,
CROSS, OUTER, INNER, NATURAL, ON, USING,
]; ];
} }

View file

@ -75,6 +75,8 @@ pub enum ASTNode {
projection: Vec<ASTNode>, projection: Vec<ASTNode>,
/// FROM /// FROM
relation: Option<Box<ASTNode>>, relation: Option<Box<ASTNode>>,
// JOIN
joins: Vec<Join>,
/// WHERE /// WHERE
selection: Option<Box<ASTNode>>, selection: Option<Box<ASTNode>>,
/// ORDER BY /// ORDER BY
@ -167,10 +169,16 @@ impl ToString for ASTNode {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", ") .join(", ")
), ),
ASTNode::SQLCase { conditions, results, else_result } => { ASTNode::SQLCase {
conditions,
results,
else_result,
} => {
let mut s = format!( let mut s = format!(
"CASE {}", "CASE {}",
conditions.iter().zip(results) conditions
.iter()
.zip(results)
.map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string())) .map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string()))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(" ") .join(" ")
@ -179,10 +187,11 @@ impl ToString for ASTNode {
s += &format!(" ELSE {}", else_result.to_string()) s += &format!(" ELSE {}", else_result.to_string())
} }
s + " END" s + " END"
}, }
ASTNode::SQLSelect { ASTNode::SQLSelect {
projection, projection,
relation, relation,
joins,
selection, selection,
order_by, order_by,
group_by, group_by,
@ -200,6 +209,9 @@ impl ToString for ASTNode {
if let Some(relation) = relation { if let Some(relation) = relation {
s += &format!(" FROM {}", relation.as_ref().to_string()); s += &format!(" FROM {}", relation.as_ref().to_string());
} }
for join in joins {
s += &join.to_string();
}
if let Some(selection) = selection { if let Some(selection) = selection {
s += &format!(" WHERE {}", selection.as_ref().to_string()); s += &format!(" WHERE {}", selection.as_ref().to_string());
} }
@ -402,3 +414,72 @@ impl ToString for SQLColumnDef {
s s
} }
} }
/// A single join appended to the first relation of a `SELECT ... FROM` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct Join {
    /// The relation being joined in (the right-hand side of the join).
    pub relation: ASTNode,
    /// The kind of join and, where the kind carries one, its constraint.
    pub join_operator: JoinOperator,
}
impl ToString for Join {
    /// Renders this join as the SQL fragment that directly follows the
    /// preceding relation of a `FROM` clause.
    ///
    /// The fragment carries its own separator so that it can be appended
    /// verbatim: keyword joins start with a single space (for example
    /// ` LEFT OUTER JOIN t2 ON(a = b)`), while an implicit join starts with
    /// a comma (`, t2`).
    fn to_string(&self) -> String {
        // Text placed in front of the join keyword: only NATURAL joins have one.
        fn prefix(constraint: &JoinConstraint) -> &'static str {
            match constraint {
                JoinConstraint::Natural => "NATURAL ",
                JoinConstraint::On(_) | JoinConstraint::Using(_) => "",
            }
        }
        // Text placed after the joined relation: the ON / USING clause, if any.
        fn suffix(constraint: &JoinConstraint) -> String {
            match constraint {
                JoinConstraint::On(expr) => format!(" ON({})", expr.to_string()),
                JoinConstraint::Using(attrs) => format!(" USING({})", attrs.join(", ")),
                JoinConstraint::Natural => String::new(),
            }
        }
        match &self.join_operator {
            JoinOperator::Inner(constraint) => format!(
                " {}INNER JOIN {}{}",
                prefix(constraint),
                self.relation.to_string(),
                // The constraint clause goes after the relation; emitting the
                // prefix here would drop ON/USING and misplace NATURAL.
                suffix(constraint)
            ),
            JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()),
            JoinOperator::Implicit => format!(", {}", self.relation.to_string()),
            JoinOperator::LeftOuter(constraint) => format!(
                " {}LEFT OUTER JOIN {}{}",
                prefix(constraint),
                self.relation.to_string(),
                suffix(constraint)
            ),
            JoinOperator::RightOuter(constraint) => format!(
                " {}RIGHT OUTER JOIN {}{}",
                prefix(constraint),
                self.relation.to_string(),
                suffix(constraint)
            ),
            JoinOperator::FullOuter(constraint) => format!(
                " {}FULL OUTER JOIN {}{}",
                prefix(constraint),
                self.relation.to_string(),
                suffix(constraint)
            ),
        }
    }
}
/// The kind of join applied between the preceding relation and a [`Join`]'s relation.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinOperator {
    /// `[NATURAL] [INNER] JOIN` (a bare `JOIN` is parsed as an inner join).
    Inner(JoinConstraint),
    /// `[NATURAL] LEFT [OUTER] JOIN`.
    LeftOuter(JoinConstraint),
    /// `[NATURAL] RIGHT [OUTER] JOIN`.
    RightOuter(JoinConstraint),
    /// `[NATURAL] FULL [OUTER] JOIN`.
    FullOuter(JoinConstraint),
    /// A comma-separated relation in the `FROM` clause, e.g. `FROM t1, t2`.
    Implicit,
    /// `CROSS JOIN`; carries no constraint.
    Cross,
}
/// How the rows of the two sides of a join are matched.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinConstraint {
    /// `ON <expr>`: an arbitrary expression.
    On(ASTNode),
    /// `USING (<ident>, ...)`: the identifiers listed between the parentheses.
    Using(Vec<String>),
    /// The join was introduced by the `NATURAL` keyword; no explicit constraint follows.
    Natural,
}

View file

@ -17,10 +17,7 @@
use super::dialect::Dialect; use super::dialect::Dialect;
use super::sqlast::*; use super::sqlast::*;
use super::sqltokenizer::*; use super::sqltokenizer::*;
use chrono::{ use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime};
offset::{FixedOffset},
DateTime, NaiveDate, NaiveDateTime, NaiveTime,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ParserError { pub enum ParserError {
@ -109,10 +106,8 @@ impl Parser {
"NULL" => { "NULL" => {
self.prev_token(); self.prev_token();
self.parse_sql_value() self.parse_sql_value()
},
"CASE" => {
self.parse_case_expression()
} }
"CASE" => self.parse_case_expression(),
_ => return parser_err!(format!("No prefix parser for keyword {}", k)), _ => return parser_err!(format!("No prefix parser for keyword {}", k)),
}, },
Token::Mult => Ok(ASTNode::SQLWildcard), Token::Mult => Ok(ASTNode::SQLWildcard),
@ -156,14 +151,14 @@ impl Parser {
Token::DoubleQuotedString(_) => { Token::DoubleQuotedString(_) => {
self.prev_token(); self.prev_token();
self.parse_sql_value() self.parse_sql_value()
}, }
Token::LParen => { Token::LParen => {
let expr = self.parse(); let expr = self.parse();
if !self.consume_token(&Token::RParen)? { if !self.consume_token(&Token::RParen)? {
return parser_err!(format!("expected token RParen")); return parser_err!(format!("expected token RParen"));
} }
expr expr
}, }
_ => parser_err!(format!( _ => parser_err!(format!(
"Prefix parser expected a keyword but found {:?}", "Prefix parser expected a keyword but found {:?}",
t t
@ -211,20 +206,20 @@ impl Parser {
if self.parse_keywords(vec!["ELSE"]) { if self.parse_keywords(vec!["ELSE"]) {
else_result = Some(Box::new(self.parse_expr(0)?)); else_result = Some(Box::new(self.parse_expr(0)?));
if self.parse_keywords(vec!["END"]) { if self.parse_keywords(vec!["END"]) {
break break;
} else { } else {
return parser_err!("Expecting END after a CASE..ELSE"); return parser_err!("Expecting END after a CASE..ELSE");
} }
} }
if self.parse_keywords(vec!["END"]) { if self.parse_keywords(vec!["END"]) {
break break;
} }
self.consume_token(&Token::Keyword("WHEN".to_string()))?; self.consume_token(&Token::Keyword("WHEN".to_string()))?;
} }
Ok(ASTNode::SQLCase { Ok(ASTNode::SQLCase {
conditions, conditions,
results, results,
else_result else_result,
}) })
} else { } else {
// TODO: implement "simple" case // TODO: implement "simple" case
@ -489,6 +484,18 @@ impl Parser {
true true
} }
/// Consume the next token if it is the keyword `expected`; otherwise return
/// a parser error that names the expected keyword and the token found.
pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> {
    if !self.parse_keyword(expected) {
        // Report whatever is sitting at the current position instead.
        let found = self.peek_token();
        return parser_err!(format!(
            "Expected keyword {}, found {:?}",
            expected, found
        ));
    }
    Ok(())
}
//TODO: this function is inconsistent and sometimes returns bool and sometimes fails //TODO: this function is inconsistent and sometimes returns bool and sometimes fails
/// Consume the next token if it matches the expected token, otherwise return an error /// Consume the next token if it matches the expected token, otherwise return an error
@ -751,12 +758,12 @@ impl Parser {
}, },
Token::Number(ref n) => match n.parse::<i64>() { Token::Number(ref n) => match n.parse::<i64>() {
Ok(n) => { Ok(n) => {
// if let Some(Token::Minus) = self.peek_token() { // if let Some(Token::Minus) = self.peek_token() {
// self.prev_token(); // self.prev_token();
// self.parse_timestamp_value() // self.parse_timestamp_value()
// } else { // } else {
Ok(Value::Long(n)) Ok(Value::Long(n))
// } // }
} }
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
}, },
@ -1108,11 +1115,12 @@ impl Parser {
pub fn parse_select(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_select(&mut self) -> Result<ASTNode, ParserError> {
let projection = self.parse_expr_list()?; let projection = self.parse_expr_list()?;
let relation: Option<Box<ASTNode>> = if self.parse_keyword("FROM") { let (relation, joins): (Option<Box<ASTNode>>, Vec<Join>) = if self.parse_keyword("FROM") {
//TODO: add support for JOIN let relation = Some(Box::new(self.parse_expr(0)?));
Some(Box::new(self.parse_expr(0)?)) let joins = self.parse_joins()?;
(relation, joins)
} else { } else {
None (None, vec![])
}; };
let selection = if self.parse_keyword("WHERE") { let selection = if self.parse_keyword("WHERE") {
@ -1158,6 +1166,7 @@ impl Parser {
projection, projection,
selection, selection,
relation, relation,
joins,
limit, limit,
order_by, order_by,
group_by, group_by,
@ -1166,6 +1175,131 @@ impl Parser {
} }
} }
fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
} else if self.parse_keyword("ON") {
let constraint = self.parse_expr(0)?;
Ok(JoinConstraint::On(constraint))
} else if self.parse_keyword("USING") {
if self.consume_token(&Token::LParen)? {
let attributes = self
.parse_expr_list()?
.into_iter()
.map(|ast_node| match ast_node {
ASTNode::SQLIdentifier(ident) => Ok(ident),
unexpected => {
parser_err!(format!("Expected identifier, found {:?}", unexpected))
}
})
.collect::<Result<Vec<String>, ParserError>>()?;
if self.consume_token(&Token::RParen)? {
Ok(JoinConstraint::Using(attributes))
} else {
parser_err!(format!("Expected token ')', found {:?}", self.peek_token()))
}
} else {
parser_err!(format!("Expected token '(', found {:?}", self.peek_token()))
}
} else {
parser_err!(format!(
"Unexpected token after JOIN: {:?}",
self.peek_token()
))
}
}
fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
let mut joins = vec![];
loop {
let natural = match &self.peek_token() {
Some(Token::Comma) => {
self.next_token();
let relation = self.parse_expr(0)?;
let join = Join {
relation,
join_operator: JoinOperator::Implicit,
};
joins.push(join);
continue;
}
Some(Token::Keyword(kw)) if kw == "CROSS" => {
self.next_token();
self.expect_keyword("JOIN")?;
let relation = self.parse_expr(0)?;
let join = Join {
relation,
join_operator: JoinOperator::Cross,
};
joins.push(join);
continue;
}
Some(Token::Keyword(kw)) if kw == "NATURAL" => {
self.next_token();
true
}
Some(_) => false,
None => return Ok(joins),
};
let join = match &self.peek_token() {
Some(Token::Keyword(kw)) if kw == "INNER" => {
self.next_token();
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?),
}
}
Some(Token::Keyword(kw)) if kw == "JOIN" => {
self.next_token();
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?),
}
}
Some(Token::Keyword(kw)) if kw == "LEFT" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::LeftOuter(
self.parse_join_constraint(natural)?,
),
}
}
Some(Token::Keyword(kw)) if kw == "RIGHT" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::RightOuter(
self.parse_join_constraint(natural)?,
),
}
}
Some(Token::Keyword(kw)) if kw == "FULL" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::FullOuter(
self.parse_join_constraint(natural)?,
),
}
}
_ => break,
};
joins.push(join);
}
Ok(joins)
}
/// Parse an INSERT statement /// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> { pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> {
self.parse_keyword("INTO"); self.parse_keyword("INTO");
@ -1215,19 +1349,17 @@ impl Parser {
// look for optional ASC / DESC specifier // look for optional ASC / DESC specifier
let asc = match self.peek_token() { let asc = match self.peek_token() {
Some(Token::Keyword(k)) => { Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() {
match k.to_uppercase().as_ref() { "ASC" => {
"ASC" => { self.next_token();
self.next_token(); true
true
},
"DESC" => {
self.next_token();
false
},
_ => true
} }
} "DESC" => {
self.next_token();
false
}
_ => true,
},
Some(Token::Comma) => true, Some(Token::Comma) => true,
_ => true, _ => true,
}; };

View file

@ -13,7 +13,9 @@ fn parse_delete_statement() {
match parse_sql(&sql) { match parse_sql(&sql) {
ASTNode::SQLDelete { relation, .. } => { ASTNode::SQLDelete { relation, .. } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation relation
); );
} }
@ -36,7 +38,9 @@ fn parse_where_delete_statement() {
.. ..
} => { } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation relation
); );
@ -207,7 +211,9 @@ fn parse_select_order_by_limit() {
); );
let ast = parse_sql(&sql); let ast = parse_sql(&sql);
match ast { match ast {
ASTNode::SQLSelect { order_by, limit, .. } => { ASTNode::SQLSelect {
order_by, limit, ..
} => {
assert_eq!( assert_eq!(
Some(vec![ Some(vec![
SQLOrderByExpr { SQLOrderByExpr {
@ -341,7 +347,10 @@ fn parse_literal_string() {
let sql = "SELECT 'one'"; let sql = "SELECT 'one'";
match parse_sql(&sql) { match parse_sql(&sql) {
ASTNode::SQLSelect { ref projection, .. } => { ASTNode::SQLSelect { ref projection, .. } => {
assert_eq!(projection[0], ASTNode::SQLValue(Value::SingleQuotedString("one".to_string()))); assert_eq!(
projection[0],
ASTNode::SQLValue(Value::SingleQuotedString("one".to_string()))
);
} }
_ => panic!(), _ => panic!(),
} }
@ -380,20 +389,21 @@ fn parse_parens() {
let sql = "(a + b) - (c + d)"; let sql = "(a + b) - (c + d)";
let ast = parse_sql(&sql); let ast = parse_sql(&sql);
assert_eq!( assert_eq!(
SQLBinaryExpr { SQLBinaryExpr {
left: Box::new(SQLBinaryExpr { left: Box::new(SQLBinaryExpr {
left: Box::new(SQLIdentifier("a".to_string())), left: Box::new(SQLIdentifier("a".to_string())),
op: Plus, op: Plus,
right: Box::new(SQLIdentifier("b".to_string())) right: Box::new(SQLIdentifier("b".to_string()))
}), }),
op: Minus, op: Minus,
right: Box::new(SQLBinaryExpr { right: Box::new(SQLBinaryExpr {
left: Box::new(SQLIdentifier("c".to_string())), left: Box::new(SQLIdentifier("c".to_string())),
op: Plus, op: Plus,
right: Box::new(SQLIdentifier("d".to_string())) right: Box::new(SQLIdentifier("d".to_string()))
}) })
} },
, ast); ast
);
} }
#[test] #[test]
@ -410,17 +420,28 @@ fn parse_case_expression() {
SQLCase { SQLCase {
conditions: vec![ conditions: vec![
SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))),
SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), SQLBinaryExpr {
op: Eq, right: Box::new(SQLValue(Value::Long(0))) }, left: Box::new(SQLIdentifier("bar".to_string())),
SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())), op: Eq,
op: GtEq, right: Box::new(SQLValue(Value::Long(0))) } right: Box::new(SQLValue(Value::Long(0)))
},
SQLBinaryExpr {
left: Box::new(SQLIdentifier("bar".to_string())),
op: GtEq,
right: Box::new(SQLValue(Value::Long(0)))
}
], ],
results: vec![SQLValue(Value::SingleQuotedString("null".to_string())), results: vec![
SQLValue(Value::SingleQuotedString("=0".to_string())), SQLValue(Value::SingleQuotedString("null".to_string())),
SQLValue(Value::SingleQuotedString(">=0".to_string()))], SQLValue(Value::SingleQuotedString("=0".to_string())),
else_result: Some(Box::new(SQLValue(Value::SingleQuotedString("<0".to_string())))) SQLValue(Value::SingleQuotedString(">=0".to_string()))
],
else_result: Some(Box::new(SQLValue(Value::SingleQuotedString(
"<0".to_string()
))))
}, },
projection[0]); projection[0]
);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -445,7 +466,9 @@ fn parse_delete_with_semi_colon() {
match parse_sql(&sql) { match parse_sql(&sql) {
ASTNode::SQLDelete { relation, .. } => { ASTNode::SQLDelete { relation, .. } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))), Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation relation
); );
} }
@ -453,12 +476,120 @@ fn parse_delete_with_semi_colon() {
} }
} }
/// A comma-separated relation after `FROM` is recorded as one implicit join.
#[test]
fn parse_implicit_join() {
    // `joins_from` panics with a clear message if the statement is not a
    // SELECT, and comparing the whole vector also pins its length.
    assert_eq!(
        joins_from("SELECT * FROM t1,t2"),
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: JoinOperator::Implicit,
        }]
    );
}
/// `CROSS JOIN` is recorded as one join with the `Cross` operator.
#[test]
fn parse_cross_join() {
    // `joins_from` panics with a clear message if the statement is not a
    // SELECT, and comparing the whole vector also pins its length.
    assert_eq!(
        joins_from("SELECT * FROM t1 CROSS JOIN t2"),
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: JoinOperator::Cross,
        }]
    );
}
/// Each join flavour followed by `ON c1 = c2` yields the matching operator
/// wrapping that condition.
#[test]
fn parse_joins_on() {
    // Every query below joins `t2` on `c1 = c2`; only the operator differs.
    let expected = |make_operator: fn(JoinConstraint) -> JoinOperator| {
        let condition = ASTNode::SQLBinaryExpr {
            left: Box::new(ASTNode::SQLIdentifier("c1".to_string())),
            op: SQLOperator::Eq,
            right: Box::new(ASTNode::SQLIdentifier("c2".to_string())),
        };
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: make_operator(JoinConstraint::On(condition)),
        }]
    };

    assert_eq!(
        joins_from("SELECT * FROM t1 JOIN t2 ON c1 = c2"),
        expected(JoinOperator::Inner)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2"),
        expected(JoinOperator::LeftOuter)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2"),
        expected(JoinOperator::RightOuter)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 ON c1 = c2"),
        expected(JoinOperator::FullOuter)
    );
}
/// Each join flavour followed by `USING(c1)` yields the matching operator
/// wrapping that single column.
#[test]
fn parse_joins_using() {
    // Every query below joins `t2` using column `c1`; only the operator differs.
    let expected = |make_operator: fn(JoinConstraint) -> JoinOperator| {
        let columns = vec!["c1".to_string()];
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: make_operator(JoinConstraint::Using(columns)),
        }]
    };

    assert_eq!(
        joins_from("SELECT * FROM t1 JOIN t2 USING(c1)"),
        expected(JoinOperator::Inner)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 LEFT JOIN t2 USING(c1)"),
        expected(JoinOperator::LeftOuter)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)"),
        expected(JoinOperator::RightOuter)
    );
    assert_eq!(
        joins_from("SELECT * FROM t1 FULL OUTER JOIN t2 USING(c1)"),
        expected(JoinOperator::FullOuter)
    );
}
/// Parse `sql` and return the joins of the resulting SELECT statement.
///
/// Panics if the statement does not parse to a SELECT.
fn joins_from(sql: &str) -> Vec<Join> {
    if let ASTNode::SQLSelect { joins, .. } = parse_sql(sql) {
        joins
    } else {
        panic!("Expected SELECT")
    }
}
fn parse_sql(sql: &str) -> ASTNode { fn parse_sql(sql: &str) -> ASTNode {
let dialect = GenericSqlDialect {}; let dialect = GenericSqlDialect {};
let mut tokenizer = Tokenizer::new(&dialect,&sql, ); let mut tokenizer = Tokenizer::new(&dialect, &sql);
let tokens = tokenizer.tokenize().unwrap(); let tokens = tokenizer.tokenize().unwrap();
let mut parser = Parser::new(tokens); let mut parser = Parser::new(tokens);
let ast = parser.parse().unwrap(); let ast = parser.parse().unwrap();
ast ast
} }