Support for joins

This commit is contained in:
Fredrik Roos 2018-11-17 15:40:24 +01:00
parent face97226b
commit 7624095738
4 changed files with 409 additions and 64 deletions

View file

@ -11,7 +11,8 @@ impl Dialect for GenericSqlDialect {
STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types
CHAR, CHARACTER, VARYING, LARGE, OBJECT, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT,
REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC,
BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END,
BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL,
CROSS, OUTER, INNER, NATURAL, ON, USING,
];
}

View file

@ -75,6 +75,8 @@ pub enum ASTNode {
projection: Vec<ASTNode>,
/// FROM
relation: Option<Box<ASTNode>>,
// JOIN
joins: Vec<Join>,
/// WHERE
selection: Option<Box<ASTNode>>,
/// ORDER BY
@ -167,10 +169,16 @@ impl ToString for ASTNode {
.collect::<Vec<String>>()
.join(", ")
),
ASTNode::SQLCase { conditions, results, else_result } => {
ASTNode::SQLCase {
conditions,
results,
else_result,
} => {
let mut s = format!(
"CASE {}",
conditions.iter().zip(results)
conditions
.iter()
.zip(results)
.map(|(c, r)| format!("WHEN {} THEN {}", c.to_string(), r.to_string()))
.collect::<Vec<String>>()
.join(" ")
@ -179,10 +187,11 @@ impl ToString for ASTNode {
s += &format!(" ELSE {}", else_result.to_string())
}
s + " END"
},
}
ASTNode::SQLSelect {
projection,
relation,
joins,
selection,
order_by,
group_by,
@ -200,6 +209,9 @@ impl ToString for ASTNode {
if let Some(relation) = relation {
s += &format!(" FROM {}", relation.as_ref().to_string());
}
for join in joins {
s += &join.to_string();
}
if let Some(selection) = selection {
s += &format!(" WHERE {}", selection.as_ref().to_string());
}
@ -402,3 +414,72 @@ impl ToString for SQLColumnDef {
s
}
}
/// One join attached to the `FROM` relation of a `SELECT`.
///
/// A `SELECT` keeps its joins as an ordered list; each entry names the relation
/// being joined in and how it is combined with what precedes it.
#[derive(Debug, Clone, PartialEq)]
pub struct Join {
    /// The relation on the right-hand side of the join.
    pub relation: ASTNode,
    /// The kind of join and, where the kind takes one, its constraint.
    pub join_operator: JoinOperator,
}
impl ToString for Join {
    /// Renders this join as the SQL fragment that directly follows the preceding
    /// relation of a `FROM` clause.
    ///
    /// The fragment carries its own leading separator: `", "` for an implicit join
    /// and a single space for every keyword join. Callers can therefore append it to
    /// the text built so far without inserting anything in between.
    fn to_string(&self) -> String {
        // Text written before the join keywords; only NATURAL goes there.
        fn prefix(constraint: &JoinConstraint) -> &'static str {
            match constraint {
                JoinConstraint::Natural => "NATURAL ",
                JoinConstraint::On(_) | JoinConstraint::Using(_) => "",
            }
        }
        // Text written after the joined relation: the ON / USING clause, if any.
        fn suffix(constraint: &JoinConstraint) -> String {
            match constraint {
                JoinConstraint::On(expr) => format!(" ON({})", expr.to_string()),
                JoinConstraint::Using(attrs) => format!(" USING({})", attrs.join(", ")),
                JoinConstraint::Natural => String::new(),
            }
        }

        let relation = self.relation.to_string();
        // Shared layout of every join kind that takes a constraint. The constraint
        // contributes a prefix (NATURAL) or a suffix (ON / USING), never both.
        let constrained = |keywords: &str, constraint: &JoinConstraint| {
            format!(
                " {}{} {}{}",
                prefix(constraint),
                keywords,
                relation,
                suffix(constraint)
            )
        };

        match &self.join_operator {
            JoinOperator::Inner(constraint) => constrained("INNER JOIN", constraint),
            JoinOperator::LeftOuter(constraint) => constrained("LEFT OUTER JOIN", constraint),
            JoinOperator::RightOuter(constraint) => constrained("RIGHT OUTER JOIN", constraint),
            JoinOperator::FullOuter(constraint) => constrained("FULL OUTER JOIN", constraint),
            JoinOperator::Cross => format!(" CROSS JOIN {}", relation),
            JoinOperator::Implicit => format!(", {}", relation),
        }
    }
}
/// The kind of a [`Join`], carrying its constraint where the kind takes one.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinOperator {
    /// `[INNER] JOIN`: the parser produces this for both `INNER JOIN` and bare `JOIN`.
    Inner(JoinConstraint),
    /// `LEFT [OUTER] JOIN`.
    LeftOuter(JoinConstraint),
    /// `RIGHT [OUTER] JOIN`.
    RightOuter(JoinConstraint),
    /// `FULL [OUTER] JOIN`.
    FullOuter(JoinConstraint),
    /// A comma between relations in the `FROM` list; takes no constraint.
    Implicit,
    /// `CROSS JOIN`; takes no constraint.
    Cross,
}
/// How the rows of a constrained join are matched.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinConstraint {
    /// `ON <expr>`: the expression following the `ON` keyword.
    On(ASTNode),
    /// `USING (<name>, ...)`: the identifiers listed inside the parentheses.
    Using(Vec<String>),
    /// The join was introduced with `NATURAL`; no explicit constraint is written.
    Natural,
}

View file

@ -17,10 +17,7 @@
use super::dialect::Dialect;
use super::sqlast::*;
use super::sqltokenizer::*;
use chrono::{
offset::{FixedOffset},
DateTime, NaiveDate, NaiveDateTime, NaiveTime,
};
use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime};
#[derive(Debug, Clone)]
pub enum ParserError {
@ -109,10 +106,8 @@ impl Parser {
"NULL" => {
self.prev_token();
self.parse_sql_value()
},
"CASE" => {
self.parse_case_expression()
}
"CASE" => self.parse_case_expression(),
_ => return parser_err!(format!("No prefix parser for keyword {}", k)),
},
Token::Mult => Ok(ASTNode::SQLWildcard),
@ -156,14 +151,14 @@ impl Parser {
Token::DoubleQuotedString(_) => {
self.prev_token();
self.parse_sql_value()
},
}
Token::LParen => {
let expr = self.parse();
if !self.consume_token(&Token::RParen)? {
return parser_err!(format!("expected token RParen"));
}
expr
},
}
_ => parser_err!(format!(
"Prefix parser expected a keyword but found {:?}",
t
@ -211,20 +206,20 @@ impl Parser {
if self.parse_keywords(vec!["ELSE"]) {
else_result = Some(Box::new(self.parse_expr(0)?));
if self.parse_keywords(vec!["END"]) {
break
break;
} else {
return parser_err!("Expecting END after a CASE..ELSE");
}
}
if self.parse_keywords(vec!["END"]) {
break
break;
}
self.consume_token(&Token::Keyword("WHEN".to_string()))?;
}
Ok(ASTNode::SQLCase {
conditions,
results,
else_result
else_result,
})
} else {
// TODO: implement "simple" case
@ -489,6 +484,18 @@ impl Parser {
true
}
/// Consumes the next token if it is the keyword `expected`.
///
/// Returns `Ok(())` when `parse_keyword` accepts it. Otherwise returns a parser
/// error naming the expected keyword and the token reported by `peek_token`.
pub fn expect_keyword(&mut self, expected: &'static str) -> Result<(), ParserError> {
    if !self.parse_keyword(expected) {
        return parser_err!(format!(
            "Expected keyword {}, found {:?}",
            expected,
            self.peek_token()
        ));
    }
    Ok(())
}
//TODO: this function is inconsistent and sometimes returns bool and sometimes fails
/// Consume the next token if it matches the expected token, otherwise return an error
@ -751,12 +758,12 @@ impl Parser {
},
Token::Number(ref n) => match n.parse::<i64>() {
Ok(n) => {
// if let Some(Token::Minus) = self.peek_token() {
// self.prev_token();
// self.parse_timestamp_value()
// } else {
// if let Some(Token::Minus) = self.peek_token() {
// self.prev_token();
// self.parse_timestamp_value()
// } else {
Ok(Value::Long(n))
// }
// }
}
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
},
@ -1108,11 +1115,12 @@ impl Parser {
pub fn parse_select(&mut self) -> Result<ASTNode, ParserError> {
let projection = self.parse_expr_list()?;
let relation: Option<Box<ASTNode>> = if self.parse_keyword("FROM") {
//TODO: add support for JOIN
Some(Box::new(self.parse_expr(0)?))
let (relation, joins): (Option<Box<ASTNode>>, Vec<Join>) = if self.parse_keyword("FROM") {
let relation = Some(Box::new(self.parse_expr(0)?));
let joins = self.parse_joins()?;
(relation, joins)
} else {
None
(None, vec![])
};
let selection = if self.parse_keyword("WHERE") {
@ -1158,6 +1166,7 @@ impl Parser {
projection,
selection,
relation,
joins,
limit,
order_by,
group_by,
@ -1166,6 +1175,131 @@ impl Parser {
}
}
fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
} else if self.parse_keyword("ON") {
let constraint = self.parse_expr(0)?;
Ok(JoinConstraint::On(constraint))
} else if self.parse_keyword("USING") {
if self.consume_token(&Token::LParen)? {
let attributes = self
.parse_expr_list()?
.into_iter()
.map(|ast_node| match ast_node {
ASTNode::SQLIdentifier(ident) => Ok(ident),
unexpected => {
parser_err!(format!("Expected identifier, found {:?}", unexpected))
}
})
.collect::<Result<Vec<String>, ParserError>>()?;
if self.consume_token(&Token::RParen)? {
Ok(JoinConstraint::Using(attributes))
} else {
parser_err!(format!("Expected token ')', found {:?}", self.peek_token()))
}
} else {
parser_err!(format!("Expected token '(', found {:?}", self.peek_token()))
}
} else {
parser_err!(format!(
"Unexpected token after JOIN: {:?}",
self.peek_token()
))
}
}
fn parse_joins(&mut self) -> Result<Vec<Join>, ParserError> {
let mut joins = vec![];
loop {
let natural = match &self.peek_token() {
Some(Token::Comma) => {
self.next_token();
let relation = self.parse_expr(0)?;
let join = Join {
relation,
join_operator: JoinOperator::Implicit,
};
joins.push(join);
continue;
}
Some(Token::Keyword(kw)) if kw == "CROSS" => {
self.next_token();
self.expect_keyword("JOIN")?;
let relation = self.parse_expr(0)?;
let join = Join {
relation,
join_operator: JoinOperator::Cross,
};
joins.push(join);
continue;
}
Some(Token::Keyword(kw)) if kw == "NATURAL" => {
self.next_token();
true
}
Some(_) => false,
None => return Ok(joins),
};
let join = match &self.peek_token() {
Some(Token::Keyword(kw)) if kw == "INNER" => {
self.next_token();
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?),
}
}
Some(Token::Keyword(kw)) if kw == "JOIN" => {
self.next_token();
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?),
}
}
Some(Token::Keyword(kw)) if kw == "LEFT" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::LeftOuter(
self.parse_join_constraint(natural)?,
),
}
}
Some(Token::Keyword(kw)) if kw == "RIGHT" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::RightOuter(
self.parse_join_constraint(natural)?,
),
}
}
Some(Token::Keyword(kw)) if kw == "FULL" => {
self.next_token();
self.parse_keyword("OUTER");
self.expect_keyword("JOIN")?;
Join {
relation: self.parse_expr(0)?,
join_operator: JoinOperator::FullOuter(
self.parse_join_constraint(natural)?,
),
}
}
_ => break,
};
joins.push(join);
}
Ok(joins)
}
/// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<ASTNode, ParserError> {
self.parse_keyword("INTO");
@ -1215,19 +1349,17 @@ impl Parser {
// look for optional ASC / DESC specifier
let asc = match self.peek_token() {
Some(Token::Keyword(k)) => {
match k.to_uppercase().as_ref() {
Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() {
"ASC" => {
self.next_token();
true
},
}
"DESC" => {
self.next_token();
false
}
_ => true,
},
_ => true
}
}
Some(Token::Comma) => true,
_ => true,
};

View file

@ -13,7 +13,9 @@ fn parse_delete_statement() {
match parse_sql(&sql) {
ASTNode::SQLDelete { relation, .. } => {
assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))),
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation
);
}
@ -36,7 +38,9 @@ fn parse_where_delete_statement() {
..
} => {
assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))),
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation
);
@ -207,7 +211,9 @@ fn parse_select_order_by_limit() {
);
let ast = parse_sql(&sql);
match ast {
ASTNode::SQLSelect { order_by, limit, .. } => {
ASTNode::SQLSelect {
order_by, limit, ..
} => {
assert_eq!(
Some(vec![
SQLOrderByExpr {
@ -341,7 +347,10 @@ fn parse_literal_string() {
let sql = "SELECT 'one'";
match parse_sql(&sql) {
ASTNode::SQLSelect { ref projection, .. } => {
assert_eq!(projection[0], ASTNode::SQLValue(Value::SingleQuotedString("one".to_string())));
assert_eq!(
projection[0],
ASTNode::SQLValue(Value::SingleQuotedString("one".to_string()))
);
}
_ => panic!(),
}
@ -392,8 +401,9 @@ fn parse_parens() {
op: Plus,
right: Box::new(SQLIdentifier("d".to_string()))
})
}
, ast);
},
ast
);
}
#[test]
@ -410,17 +420,28 @@ fn parse_case_expression() {
SQLCase {
conditions: vec![
SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))),
SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())),
op: Eq, right: Box::new(SQLValue(Value::Long(0))) },
SQLBinaryExpr { left: Box::new(SQLIdentifier("bar".to_string())),
op: GtEq, right: Box::new(SQLValue(Value::Long(0))) }
],
results: vec![SQLValue(Value::SingleQuotedString("null".to_string())),
SQLValue(Value::SingleQuotedString("=0".to_string())),
SQLValue(Value::SingleQuotedString(">=0".to_string()))],
else_result: Some(Box::new(SQLValue(Value::SingleQuotedString("<0".to_string()))))
SQLBinaryExpr {
left: Box::new(SQLIdentifier("bar".to_string())),
op: Eq,
right: Box::new(SQLValue(Value::Long(0)))
},
projection[0]);
SQLBinaryExpr {
left: Box::new(SQLIdentifier("bar".to_string())),
op: GtEq,
right: Box::new(SQLValue(Value::Long(0)))
}
],
results: vec![
SQLValue(Value::SingleQuotedString("null".to_string())),
SQLValue(Value::SingleQuotedString("=0".to_string())),
SQLValue(Value::SingleQuotedString(">=0".to_string()))
],
else_result: Some(Box::new(SQLValue(Value::SingleQuotedString(
"<0".to_string()
))))
},
projection[0]
);
}
_ => assert!(false),
}
@ -445,7 +466,9 @@ fn parse_delete_with_semi_colon() {
match parse_sql(&sql) {
ASTNode::SQLDelete { relation, .. } => {
assert_eq!(
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString("table".to_string())))),
Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"table".to_string()
)))),
relation
);
}
@ -453,12 +476,120 @@ fn parse_delete_with_semi_colon() {
}
}
/// A comma between two relations yields exactly one implicit join on the second.
#[test]
fn parse_implicit_join() {
    // `joins_from` panics with a descriptive message if the statement is not a SELECT.
    assert_eq!(
        joins_from("SELECT * FROM t1,t2"),
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: JoinOperator::Implicit,
        }]
    );
}
/// `CROSS JOIN` yields exactly one cross join on the second relation.
#[test]
fn parse_cross_join() {
    // `joins_from` panics with a descriptive message if the statement is not a SELECT.
    assert_eq!(
        joins_from("SELECT * FROM t1 CROSS JOIN t2"),
        vec![Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: JoinOperator::Cross,
        }]
    );
}
/// Each constrained join kind parses `ON c1 = c2` into the matching operator.
#[test]
fn parse_joins_on() {
    // Builds the join on `t2` expected for `... JOIN t2 ON c1 = c2`, wrapped in the
    // operator produced by `make_operator`.
    fn expected_join<F>(make_operator: F) -> Join
    where
        F: Fn(JoinConstraint) -> JoinOperator,
    {
        let condition = ASTNode::SQLBinaryExpr {
            left: Box::new(ASTNode::SQLIdentifier("c1".to_string())),
            op: SQLOperator::Eq,
            right: Box::new(ASTNode::SQLIdentifier("c2".to_string())),
        };
        Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: make_operator(JoinConstraint::On(condition)),
        }
    }

    let cases = vec![
        (
            "SELECT * FROM t1 JOIN t2 ON c1 = c2",
            expected_join(JoinOperator::Inner),
        ),
        (
            "SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2",
            expected_join(JoinOperator::LeftOuter),
        ),
        (
            "SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2",
            expected_join(JoinOperator::RightOuter),
        ),
        (
            "SELECT * FROM t1 FULL OUTER JOIN t2 ON c1 = c2",
            expected_join(JoinOperator::FullOuter),
        ),
    ];
    for (sql, expected) in cases {
        assert_eq!(joins_from(sql), vec![expected]);
    }
}
/// Each constrained join kind parses `USING(c1)` into the matching operator.
#[test]
fn parse_joins_using() {
    // Builds the join on `t2` expected for `... JOIN t2 USING(c1)`, wrapped in the
    // operator produced by `make_operator`.
    fn expected_join<F>(make_operator: F) -> Join
    where
        F: Fn(JoinConstraint) -> JoinOperator,
    {
        let columns = vec!["c1".to_string()];
        Join {
            relation: ASTNode::SQLIdentifier("t2".to_string()),
            join_operator: make_operator(JoinConstraint::Using(columns)),
        }
    }

    let cases = vec![
        (
            "SELECT * FROM t1 JOIN t2 USING(c1)",
            expected_join(JoinOperator::Inner),
        ),
        (
            "SELECT * FROM t1 LEFT JOIN t2 USING(c1)",
            expected_join(JoinOperator::LeftOuter),
        ),
        (
            "SELECT * FROM t1 RIGHT JOIN t2 USING(c1)",
            expected_join(JoinOperator::RightOuter),
        ),
        (
            "SELECT * FROM t1 FULL OUTER JOIN t2 USING(c1)",
            expected_join(JoinOperator::FullOuter),
        ),
    ];
    for (sql, expected) in cases {
        assert_eq!(joins_from(sql), vec![expected]);
    }
}
/// Parses `sql` and returns the joins of the resulting `SELECT`.
///
/// Panics with "Expected SELECT" when the statement parses to any other node.
fn joins_from(sql: &str) -> Vec<Join> {
    if let ASTNode::SQLSelect { joins, .. } = parse_sql(sql) {
        joins
    } else {
        panic!("Expected SELECT")
    }
}
/// Tokenizes and parses `sql` with the generic dialect and returns the resulting AST.
///
/// Test helper: panics (via `unwrap`) if tokenizing or parsing fails.
fn parse_sql(sql: &str) -> ASTNode {
    let dialect = GenericSqlDialect {};
    // Build the tokenizer once; the earlier duplicate construction was redundant.
    let mut tokenizer = Tokenizer::new(&dialect, &sql);
    let tokens = tokenizer.tokenize().unwrap();
    Parser::new(tokens).parse().unwrap()
}