Modularize SQLValue into an enum

Add capability of parsing dates
This commit is contained in:
Jovansonlee Cesar 2018-09-25 15:54:29 +08:00
parent 199ec67da7
commit 9ab5c1358d
4 changed files with 86 additions and 46 deletions

View file

@ -19,3 +19,4 @@ path = "src/lib.rs"
[dependencies] [dependencies]
log = "0.4.5" log = "0.4.5"
chrono = "0.4.6"

View file

@ -37,6 +37,7 @@
#[macro_use] #[macro_use]
extern crate log; extern crate log;
extern crate chrono;
pub mod dialect; pub mod dialect;
pub mod sqlast; pub mod sqlast;

View file

@ -13,6 +13,11 @@
// limitations under the License. // limitations under the License.
//! SQL Abstract Syntax Tree (AST) types //! SQL Abstract Syntax Tree (AST) types
//!
use chrono::{NaiveDate,
NaiveDateTime,
NaiveTime,
};
/// SQL Abstract Syntax Tree (AST) /// SQL Abstract Syntax Tree (AST)
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -47,16 +52,8 @@ pub enum ASTNode {
operator: SQLOperator, operator: SQLOperator,
rex: Box<ASTNode>, rex: Box<ASTNode>,
}, },
/// Literal signed long /// SQLValue
SQLLiteralLong(i64), SQLValue(Value),
/// Literal floating point value
SQLLiteralDouble(f64),
/// Literal string
SQLLiteralString(String),
/// Boolean value true or false,
SQLBoolean(bool),
/// NULL value in insert statements,
SQLNullValue,
/// Scalar function call e.g. `LEFT(foo, 5)` /// Scalar function call e.g. `LEFT(foo, 5)`
SQLFunction { id: String, args: Vec<ASTNode> }, SQLFunction { id: String, args: Vec<ASTNode> },
/// SELECT /// SELECT
@ -91,7 +88,7 @@ pub enum ASTNode {
/// COLUMNS /// COLUMNS
columns: Vec<String>, columns: Vec<String>,
/// VALUES a vector of values to be copied /// VALUES a vector of values to be copied
values: Vec<SQLValue>, values: Vec<Value>,
}, },
/// UPDATE /// UPDATE
SQLUpdate { SQLUpdate {
@ -123,19 +120,23 @@ pub enum ASTNode {
/// SQL values such as int, double, string timestamp /// SQL values such as int, double, string timestamp
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum SQLValue{ pub enum Value{
/// Literal signed long /// Literal signed long
SQLLiteralLong(i64), Long(i64),
/// Literal floating point value /// Literal floating point value
SQLLiteralDouble(f64), Double(f64),
/// Literal string /// Literal string
SQLLiteralString(String), String(String),
/// Boolean value true or false, /// Boolean value true or false,
SQLBoolean(bool), Boolean(bool),
/// NULL value in insert statements, /// Date value
SQLNullValue, Date(NaiveDate),
// Time
Time(NaiveTime),
/// Timestamp /// Timestamp
SQLLiteralTimestamp(String), DateTime(NaiveDateTime),
/// NULL value in insert statements,
Null,
} }
/// SQL assignment `foo = expr` as used in SQLUpdate /// SQL assignment `foo = expr` as used in SQLUpdate

View file

@ -17,6 +17,10 @@
use super::dialect::Dialect; use super::dialect::Dialect;
use super::sqlast::*; use super::sqlast::*;
use super::sqltokenizer::*; use super::sqltokenizer::*;
use chrono::{NaiveDate,
NaiveDateTime,
NaiveTime,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ParserError { pub enum ParserError {
@ -94,9 +98,9 @@ impl Parser {
"DELETE" => Ok(self.parse_delete()?), "DELETE" => Ok(self.parse_delete()?),
"INSERT" => Ok(self.parse_insert()?), "INSERT" => Ok(self.parse_insert()?),
"COPY" => Ok(self.parse_copy()?), "COPY" => Ok(self.parse_copy()?),
"TRUE" => Ok(ASTNode::SQLBoolean(true)), "TRUE" => Ok(ASTNode::SQLValue(Value::Boolean(true))),
"FALSE" => Ok(ASTNode::SQLBoolean(false)), "FALSE" => Ok(ASTNode::SQLValue(Value::Boolean(false))),
"NULL" => Ok(ASTNode::SQLNullValue), "NULL" => Ok(ASTNode::SQLValue(Value::Null)),
_ => return parser_err!(format!("No prefix parser for keyword {}", k)), _ => return parser_err!(format!("No prefix parser for keyword {}", k)),
}, },
Token::Mult => Ok(ASTNode::SQLWildcard), Token::Mult => Ok(ASTNode::SQLWildcard),
@ -128,14 +132,14 @@ impl Parser {
} }
} }
Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() { Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() {
Ok(n) => Ok(ASTNode::SQLLiteralDouble(n)), Ok(n) => Ok(ASTNode::SQLValue(Value::Double(n))),
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
}, },
Token::Number(ref n) => match n.parse::<i64>() { Token::Number(ref n) => match n.parse::<i64>() {
Ok(n) => Ok(ASTNode::SQLLiteralLong(n)), Ok(n) => Ok(ASTNode::SQLValue(Value::Long(n))),
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
}, },
Token::String(ref s) => Ok(ASTNode::SQLLiteralString(s.to_string())), Token::String(ref s) => Ok(ASTNode::SQLValue(Value::String(s.to_string()))),
_ => parser_err!(format!( _ => parser_err!(format!(
"Prefix parser expected a keyword but found {:?}", "Prefix parser expected a keyword but found {:?}",
t t
@ -470,8 +474,8 @@ impl Parser {
/// Parse a tab separated values in /// Parse a tab separated values in
/// COPY payload /// COPY payload
fn parse_tsv(&mut self) -> Result<Vec<SQLValue>, ParserError>{ fn parse_tsv(&mut self) -> Result<Vec<Value>, ParserError>{
let mut values: Vec<SQLValue> = vec![]; let mut values: Vec<Value> = vec![];
loop { loop {
if let Ok(true) = self.consume_token(&Token::Backslash){ if let Ok(true) = self.consume_token(&Token::Backslash){
if let Ok(true) = self.consume_token(&Token::Period) { if let Ok(true) = self.consume_token(&Token::Period) {
@ -487,27 +491,33 @@ impl Parser {
} }
fn parse_sql_value(&mut self) -> Result<SQLValue, ParserError> { fn parse_sql_value(&mut self) -> Result<Value, ParserError> {
match self.next_token() { match self.next_token() {
Some(t) => { Some(t) => {
match t { match t {
Token::Keyword(k) => match k.to_uppercase().as_ref() { Token::Keyword(k) => match k.to_uppercase().as_ref() {
"TRUE" => Ok(SQLValue::SQLBoolean(true)), "TRUE" => Ok(Value::Boolean(true)),
"FALSE" => Ok(SQLValue::SQLBoolean(false)), "FALSE" => Ok(Value::Boolean(false)),
"NULL" => Ok(SQLValue::SQLNullValue), "NULL" => Ok(Value::Null),
_ => return parser_err!(format!("No value parser for keyword {}", k)), _ => return parser_err!(format!("No value parser for keyword {}", k)),
}, },
//TODO: parse the timestamp here //TODO: parse the timestamp here
Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() { Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() {
Ok(n) => Ok(SQLValue::SQLLiteralDouble(n)), Ok(n) => Ok(Value::Double(n)),
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
}, },
Token::Number(ref n) => match n.parse::<i64>() { Token::Number(ref n) => match n.parse::<i64>() {
Ok(n) => Ok(SQLValue::SQLLiteralLong(n)), Ok(n) => {
if let Some(Token::Minus) = self.peek_token(){
self.parse_date_or_timestamp(n)
}else{
Ok(Value::Long(n))
}
}
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
}, },
Token::Identifier(id) => Ok(SQLValue::SQLLiteralString(id.to_string())), Token::Identifier(id) => Ok(Value::String(id.to_string())),
Token::String(ref s) => Ok(SQLValue::SQLLiteralString(s.to_string())), Token::String(ref s) => Ok(Value::String(s.to_string())),
other => parser_err!(format!("Unsupported value: {:?}", self.peek_token())), other => parser_err!(format!("Unsupported value: {:?}", self.peek_token())),
} }
} }
@ -533,6 +543,34 @@ impl Parser {
} }
} }
pub fn parse_date_or_timestamp(&mut self, year: i64) -> Result<Value, ParserError> {
if let Ok(true) = self.consume_token(&Token::Minus){
let month = self.parse_literal_int()?;
if let Ok(true) = self.consume_token(&Token::Minus){
let day = self.parse_literal_int()?;
let date = NaiveDate::from_ymd(year as i32, month as u32, day as u32);
if let Ok(time) = self.parse_time(){
Ok(Value::DateTime(NaiveDateTime::new(date, time)))
}else{
Ok(Value::Date(date))
}
}else{
parser_err!(format!("Expecting `-` for date separator, found {:?}", self.peek_token()))
}
}else{
parser_err!(format!("Expecting `-` for date separator, found {:?}", self.peek_token()))
}
}
pub fn parse_time(&mut self) -> Result<NaiveTime, ParserError> {
let hour = self.parse_literal_int()?;
self.consume_token(&Token::Colon)?;
let min = self.parse_literal_int()?;
self.consume_token(&Token::Colon)?;
let sec = self.parse_literal_int()?;
Ok(NaiveTime::from_hms(hour as u32, min as u32, sec as u32))
}
/// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example)
pub fn parse_data_type(&mut self) -> Result<SQLType, ParserError> { pub fn parse_data_type(&mut self) -> Result<SQLType, ParserError> {
match self.next_token() { match self.next_token() {
@ -886,7 +924,7 @@ impl Parser {
Ok(None) Ok(None)
} else { } else {
self.parse_literal_int() self.parse_literal_int()
.map(|n| Some(Box::new(ASTNode::SQLLiteralLong(n)))) .map(|n| Some(Box::new(ASTNode::SQLValue(Value::Long(n)))))
} }
} }
} }
@ -917,7 +955,7 @@ mod tests {
match parse_sql(&sql) { match parse_sql(&sql) {
ASTNode::SQLDelete { relation, .. } => { ASTNode::SQLDelete { relation, .. } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLLiteralString("table".to_string()))), Some(Box::new(ASTNode::SQLValue(Value::String("table".to_string())))),
relation relation
); );
} }
@ -940,7 +978,7 @@ mod tests {
.. ..
} => { } => {
assert_eq!( assert_eq!(
Some(Box::new(ASTNode::SQLLiteralString("table".to_string()))), Some(Box::new(ASTNode::SQLValue(Value::String("table".to_string())))),
relation relation
); );
@ -948,7 +986,7 @@ mod tests {
SQLBinaryExpr { SQLBinaryExpr {
left: Box::new(SQLIdentifier("name".to_string())), left: Box::new(SQLIdentifier("name".to_string())),
op: Eq, op: Eq,
right: Box::new(SQLLiteralLong(5)), right: Box::new(ASTNode::SQLValue(Value::Long(5))),
}, },
*selection.unwrap(), *selection.unwrap(),
); );
@ -967,7 +1005,7 @@ mod tests {
projection, limit, .. projection, limit, ..
} => { } => {
assert_eq!(3, projection.len()); assert_eq!(3, projection.len());
assert_eq!(Some(Box::new(ASTNode::SQLLiteralLong(5))), limit); assert_eq!(Some(Box::new(ASTNode::SQLValue(Value::Long(5)))), limit);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -983,7 +1021,7 @@ mod tests {
} => { } => {
assert_eq!(table_name, "customer"); assert_eq!(table_name, "customer");
assert!(columns.is_empty()); assert!(columns.is_empty());
assert_eq!(vec![vec![ASTNode::SQLLiteralLong(1),ASTNode::SQLLiteralLong(2),ASTNode::SQLLiteralLong(3)]], values); assert_eq!(vec![vec![ASTNode::SQLValue(Value::Long(1)),ASTNode::SQLValue(Value::Long(2)),ASTNode::SQLValue(Value::Long(3))]], values);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -999,7 +1037,7 @@ mod tests {
} => { } => {
assert_eq!(table_name, "public.customer"); assert_eq!(table_name, "public.customer");
assert!(columns.is_empty()); assert!(columns.is_empty());
assert_eq!(vec![vec![ASTNode::SQLLiteralLong(1),ASTNode::SQLLiteralLong(2),ASTNode::SQLLiteralLong(3)]], values); assert_eq!(vec![vec![ASTNode::SQLValue(Value::Long(1)),ASTNode::SQLValue(Value::Long(2)),ASTNode::SQLValue(Value::Long(3))]], values);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -1015,7 +1053,7 @@ mod tests {
} => { } => {
assert_eq!(table_name, "db.public.customer"); assert_eq!(table_name, "db.public.customer");
assert!(columns.is_empty()); assert!(columns.is_empty());
assert_eq!(vec![vec![ASTNode::SQLLiteralLong(1),ASTNode::SQLLiteralLong(2),ASTNode::SQLLiteralLong(3)]], values); assert_eq!(vec![vec![ASTNode::SQLValue(Value::Long(1)),ASTNode::SQLValue(Value::Long(2)),ASTNode::SQLValue(Value::Long(3))]], values);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -1038,7 +1076,7 @@ mod tests {
} => { } => {
assert_eq!(table_name, "public.customer"); assert_eq!(table_name, "public.customer");
assert_eq!(columns, vec!["id".to_string(), "name".to_string(), "active".to_string()]); assert_eq!(columns, vec!["id".to_string(), "name".to_string(), "active".to_string()]);
assert_eq!(vec![vec![ASTNode::SQLLiteralLong(1),ASTNode::SQLLiteralLong(2),ASTNode::SQLLiteralLong(3)]], values); assert_eq!(vec![vec![ASTNode::SQLValue(Value::Long(1)),ASTNode::SQLValue(Value::Long(2)),ASTNode::SQLValue(Value::Long(3))]], values);
} }
_ => assert!(false), _ => assert!(false),
} }
@ -1386,7 +1424,6 @@ mod tests {
let ast = parser.parse(); let ast = parser.parse();
println!("ast: {:?}", ast); println!("ast: {:?}", ast);
assert!(ast.is_ok()); assert!(ast.is_ok());
panic!();
} }
#[test] #[test]
@ -1418,7 +1455,7 @@ mod tests {
let sql = "SELECT 'one'"; let sql = "SELECT 'one'";
match parse_sql(&sql) { match parse_sql(&sql) {
ASTNode::SQLSelect { ref projection, .. } => { ASTNode::SQLSelect { ref projection, .. } => {
assert_eq!(projection[0], ASTNode::SQLLiteralString("one".to_string())); assert_eq!(projection[0], ASTNode::SQLValue(Value::String("one".to_string())));
} }
_ => panic!(), _ => panic!(),
} }