diff --git a/docs/custom_sql_parser.md b/docs/custom_sql_parser.md new file mode 100644 index 00000000..a660978c --- /dev/null +++ b/docs/custom_sql_parser.md @@ -0,0 +1,6 @@ +# Writing a Custom SQL Parser + +I have explored many different ways of building this library to make it easy to extend it for custom SQL dialects. Most of my attempts ended in failure but I have now found a workable solution. It is not without downsides but this seems to be the most pragmatic solution. + +The concept is simply to write a new parser that delegates to the ANSI parser so that as much as possible of the core functionality can be re-used. + diff --git a/src/sqlast.rs b/src/sqlast.rs index 6c40d939..5dcf5693 100644 --- a/src/sqlast.rs +++ b/src/sqlast.rs @@ -95,8 +95,8 @@ pub enum SQLType { Int, /// Big integer BigInt, - /// Floating point with precision e.g. FLOAT(8) - Float(usize), + /// Floating point with optional precision e.g. FLOAT(8) + Float(Option), /// Floating point e.g. REAL Real, /// Double e.g. DOUBLE PRECISION diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 7e11c87c..e59fee35 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -64,32 +64,23 @@ impl Parser { } /// Parse tokens until the precedence changes - fn parse_expr(&mut self, precedence: u8) -> Result { - // println!("parse_expr() precendence = {}", precedence); - + pub fn parse_expr(&mut self, precedence: u8) -> Result { let mut expr = self.parse_prefix()?; - // println!("parsed prefix: {:?}", expr); - loop { let next_precedence = self.get_next_precedence()?; if precedence >= next_precedence { - // println!("break on precedence change ({} >= {})", precedence, next_precedence); break; } if let Some(infix_expr) = self.parse_infix(expr.clone(), next_precedence)? { - // println!("parsed infix: {:?}", infix_expr); expr = infix_expr; } } - - // println!("parse_expr() returning {:?}", expr); - Ok(expr) } /// Parse an expression prefix - fn parse_prefix(&mut self) -> Result { + pub fn parse_prefix(&mut self) -> Result { match self.next_token() { Some(t) => { match t { @@ -150,7 +141,7 @@ impl Parser { } /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` - fn parse_cast_expression(&mut self) -> Result { + pub fn parse_cast_expression(&mut self) -> Result { let expr = self.parse_expr(0)?; self.consume_token(&Token::Keyword("AS".to_string()))?; let data_type = self.parse_data_type()?; @@ -162,7 +153,7 @@ impl Parser { } /// Parse an expression infix (typically an operator) - fn parse_infix( + pub fn parse_infix( &mut self, expr: ASTNode, precedence: u8, @@ -206,7 +197,7 @@ impl Parser { } /// Convert a token operator to an AST operator - fn to_sql_operator(&self, tok: &Token) -> Result { + pub fn to_sql_operator(&self, tok: &Token) -> Result { match tok { &Token::Eq => Ok(SQLOperator::Eq), &Token::Neq => Ok(SQLOperator::NotEq), @@ -226,7 +217,7 @@ impl Parser { } /// Get the precedence of the next token - fn get_next_precedence(&self) -> Result { + pub fn get_next_precedence(&self) -> Result { if self.index < self.tokens.len() { self.get_precedence(&self.tokens[self.index]) } else { @@ -235,7 +226,7 @@ impl Parser { } /// Get the precedence of a token - fn get_precedence(&self, tok: &Token) -> Result { + pub fn get_precedence(&self, tok: &Token) -> Result { //println!("get_precedence() {:?}", tok); match tok { @@ -252,7 +243,7 @@ impl Parser { } /// Peek at the next token - fn peek_token(&mut self) -> Option { + pub fn peek_token(&mut self) -> Option { if self.index < self.tokens.len() { Some(self.tokens[self.index].clone()) } else { @@ -261,7 +252,7 @@ impl Parser { } /// Get the next token and increment the token index - fn next_token(&mut self) -> Option { + pub fn next_token(&mut self) -> Option { if self.index < self.tokens.len() { self.index = self.index + 1; Some(self.tokens[self.index - 1].clone()) @@ -271,7 +262,7 @@ impl Parser { } /// Get the previous token and decrement the token index - fn prev_token(&mut self) -> Option { + pub fn prev_token(&mut self) -> Option { if self.index > 0 { Some(self.tokens[self.index - 1].clone()) } else { @@ -280,7 +271,7 @@ impl Parser { } /// Look for an expected keyword and consume it if it exists - fn parse_keyword(&mut self, expected: &'static str) -> bool { + pub fn parse_keyword(&mut self, expected: &'static str) -> bool { match self.peek_token() { Some(Token::Keyword(k)) => { if expected.eq_ignore_ascii_case(k.as_str()) { @@ -295,7 +286,7 @@ impl Parser { } /// Look for an expected sequence of keywords and consume them if they exist - fn parse_keywords(&mut self, keywords: Vec<&'static str>) -> bool { + pub fn parse_keywords(&mut self, keywords: Vec<&'static str>) -> bool { let index = self.index; for keyword in keywords { //println!("parse_keywords aborting .. expecting {}", keyword); @@ -312,7 +303,7 @@ impl Parser { //TODO: this function is inconsistent and sometimes returns bool and sometimes fails /// Consume the next token if it matches the expected token, otherwise return an error - fn consume_token(&mut self, expected: &Token) -> Result { + pub fn consume_token(&mut self, expected: &Token) -> Result { match self.peek_token() { Some(ref t) => if *t == *expected { self.next_token(); @@ -329,7 +320,7 @@ impl Parser { } /// Parse a SQL CREATE statement - fn parse_create(&mut self) -> Result { + pub fn parse_create(&mut self) -> Result { if self.parse_keywords(vec!["TABLE"]) { match self.next_token() { Some(Token::Identifier(id)) => { @@ -398,7 +389,7 @@ impl Parser { } /// Parse a literal integer/long - fn parse_literal_int(&mut self) -> Result { + pub fn parse_literal_int(&mut self) -> Result { match self.next_token() { Some(Token::Number(s)) => s.parse::().map_err(|e| { ParserError::ParserError(format!("Could not parse '{}' as i64: {}", s, e)) @@ -408,19 +399,19 @@ impl Parser { } /// Parse a literal string - // fn parse_literal_string(&mut self) -> Result { - // match self.next_token() { - // Some(Token::String(ref s)) => Ok(s.clone()), - // other => parser_err!(format!("Expected literal string, found {:?}", other)), - // } - // } + pub fn parse_literal_string(&mut self) -> Result { + match self.next_token() { + Some(Token::String(ref s)) => Ok(s.clone()), + other => parser_err!(format!("Expected literal string, found {:?}", other)), + } + } /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) - fn parse_data_type(&mut self) -> Result { + pub fn parse_data_type(&mut self) -> Result { match self.next_token() { Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() { "BOOLEAN" => Ok(SQLType::Boolean), - "FLOAT" => Ok(SQLType::Float(self.parse_precision()?)), + "FLOAT" => Ok(SQLType::Float(self.parse_optional_precision()?)), "REAL" => Ok(SQLType::Real), "DOUBLE" => Ok(SQLType::Double), "SMALLINT" => Ok(SQLType::SmallInt), @@ -433,12 +424,12 @@ impl Parser { } } - fn parse_precision(&mut self) -> Result { + pub fn parse_precision(&mut self) -> Result { //TODO: error handling Ok(self.parse_optional_precision()?.unwrap()) } - fn parse_optional_precision(&mut self) -> Result, ParserError> { + pub fn parse_optional_precision(&mut self) -> Result, ParserError> { if self.consume_token(&Token::LParen)? { let n = self.parse_literal_int()?; //TODO: check return value of reading rparen @@ -450,7 +441,7 @@ impl Parser { } /// Parse a SELECT statement - fn parse_select(&mut self) -> Result { + pub fn parse_select(&mut self) -> Result { let projection = self.parse_expr_list()?; let relation: Option> = if self.parse_keyword("FROM") { @@ -509,7 +500,7 @@ impl Parser { } /// Parse a comma-delimited list of SQL expressions - fn parse_expr_list(&mut self) -> Result, ParserError> { + pub fn parse_expr_list(&mut self) -> Result, ParserError> { let mut expr_list: Vec = vec![]; loop { expr_list.push(self.parse_expr(0)?); @@ -528,7 +519,7 @@ impl Parser { } /// Parse a comma-delimited list of SQL ORDER BY expressions - fn parse_order_by_expr_list(&mut self) -> Result, ParserError> { + pub fn parse_order_by_expr_list(&mut self) -> Result, ParserError> { let mut expr_list: Vec = vec![]; loop { let expr = self.parse_expr(0)?; @@ -575,7 +566,7 @@ impl Parser { } /// Parse a LIMIT clause - fn parse_limit(&mut self) -> Result>, ParserError> { + pub fn parse_limit(&mut self) -> Result>, ParserError> { if self.parse_keyword("ALL") { Ok(None) } else { @@ -845,6 +836,21 @@ mod tests { //TODO: assertions } + #[test] + fn parse_literal_string() { + let sql = "SELECT 'one'"; + match parse_sql(&sql) { + ASTNode::SQLSelect { ref projection, .. } => { + assert_eq!( + projection[0], + ASTNode::SQLLiteralString("one".to_string()) + ); + } + _ => panic!(), + } + + } + #[test] fn parse_select_version() { let sql = "SELECT @@version";