From 20079959389ffd9ee2f55d72ee766b12e7050b45 Mon Sep 17 00:00:00 2001 From: Jovansonlee Cesar Date: Mon, 24 Sep 2018 03:34:40 +0800 Subject: [PATCH] Improve the create statement parser that uses create statements from pg database dump Added PostgreSQL style casting --- Cargo.toml | 1 + src/dialect.rs | 6 + src/lib.rs | 3 + src/sqlast.rs | 15 +- src/sqlparser.rs | 423 +++++++++++++++++++++++++++++++++++--------- src/sqltokenizer.rs | 25 +++ 6 files changed, 392 insertions(+), 81 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a53021c1..00121fa7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ name = "sqlparser" path = "src/lib.rs" [dependencies] +log = "0.4.5" diff --git a/src/dialect.rs b/src/dialect.rs index 1c6ce3e5..67496046 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -432,6 +432,12 @@ impl Dialect for GenericSqlDialect { "DATE", "TIME", "TIMESTAMP", + "VALUES", + "DEFAULT", + "ZONE", + "REGCLASS", + "TEXT", + "BYTEA", ]; } diff --git a/src/lib.rs b/src/lib.rs index 2c9f9883..20cd3711 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,9 @@ //! println!("AST: {:?}", ast); //! ``` +#[macro_use] +extern crate log; + pub mod dialect; pub mod sqlast; pub mod sqlparser; diff --git a/src/sqlast.rs b/src/sqlast.rs index cb8b03ba..eeb770a8 100644 --- a/src/sqlast.rs +++ b/src/sqlast.rs @@ -135,15 +135,16 @@ pub struct SQLColumnDef { pub name: String, pub data_type: SQLType, pub allow_null: bool, + pub default: Option>, } /// SQL datatypes for literals in SQL statements #[derive(Debug, Clone, PartialEq)] pub enum SQLType { /// Fixed-length character type e.g. CHAR(10) - Char(usize), + Char(Option), /// Variable-length character type e.g. VARCHAR(10) - Varchar(usize), + Varchar(Option), /// Large character object e.g. CLOB(1000) Clob(usize), /// Fixed-length binary type e.g. BINARY(10) @@ -174,6 +175,16 @@ pub enum SQLType { Time, /// Timestamp Timestamp, + /// Regclass used in postgresql serial + Regclass, + /// Text + Text, + /// Bytea + Bytea, + /// Custom type such as enums + Custom(String), + /// Arrays + Array(Box), } /// SQL Operator diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 04848e61..78c6e92a 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -66,9 +66,12 @@ impl Parser { /// Parse tokens until the precedence changes pub fn parse_expr(&mut self, precedence: u8) -> Result { + debug!("parsing expr"); let mut expr = self.parse_prefix()?; + debug!("prefix: {:?}", expr); loop { let next_precedence = self.get_next_precedence()?; + debug!("next precedence: {:?}", next_precedence); if precedence >= next_precedence { break; } @@ -93,34 +96,30 @@ impl Parser { }, Token::Mult => Ok(ASTNode::SQLWildcard), Token::Identifier(id) => { - match self.peek_token() { - Some(Token::LParen) => { - self.next_token(); // skip lparen - match id.to_uppercase().as_ref() { - "CAST" => self.parse_cast_expression(), - _ => { - let args = self.parse_expr_list()?; - self.next_token(); // skip rparen - Ok(ASTNode::SQLFunction { id, args }) - } + if "CAST" == id.to_uppercase(){ + self.parse_cast_expression() + }else{ + match self.peek_token() { + Some(Token::LParen) => { + self.parse_function_or_pg_cast(&id) } - } - Some(Token::Period) => { - let mut id_parts: Vec = vec![id]; - while self.peek_token() == Some(Token::Period) { - self.consume_token(&Token::Period)?; - match self.next_token() { - Some(Token::Identifier(id)) => id_parts.push(id), - _ => { - return parser_err!(format!( - "Error parsing compound identifier" - )) + Some(Token::Period) => { + let mut id_parts: Vec = vec![id]; + while self.peek_token() == Some(Token::Period) { + self.consume_token(&Token::Period)?; + match self.next_token() { + Some(Token::Identifier(id)) => id_parts.push(id), + _ => { + return parser_err!(format!( + "Error parsing compound identifier" + )) + } } } + Ok(ASTNode::SQLCompoundIdentifier(id_parts)) } - Ok(ASTNode::SQLCompoundIdentifier(id_parts)) + _ => Ok(ASTNode::SQLIdentifier(id)), } - _ => Ok(ASTNode::SQLIdentifier(id)), } } Token::Number(ref n) if n.contains(".") => match n.parse::() { @@ -142,8 +141,31 @@ impl Parser { } } + pub fn parse_function_or_pg_cast(&mut self, id: &str) -> Result { + let func = self.parse_function(&id)?; + println!("func: {:?}", func); + if let Some(Token::DoubleColon) = self.peek_token(){ + self.parse_pg_cast(func) + }else{ + Ok(func) + } + } + + pub fn parse_function(&mut self, id: &str) -> Result { + self.consume_token(&Token::LParen)?; + if let Ok(true) = self.consume_token(&Token::RParen){ + Ok(ASTNode::SQLFunction { id: id.to_string(), args: vec![] }) + }else{ + let args = self.parse_expr_list()?; + self.consume_token(&Token::RParen)?; + Ok(ASTNode::SQLFunction { id: id.to_string(), args }) + } + } + /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expression(&mut self) -> Result { + println!("parsing cast"); + self.consume_token(&Token::LParen)?; let expr = self.parse_expr(0)?; self.consume_token(&Token::Keyword("AS".to_string()))?; let data_type = self.parse_data_type()?; @@ -154,12 +176,35 @@ impl Parser { }) } + + /// Parse a postgresql casting style which is in the form or expr::datatype + pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result { + let ast = self.consume_token(&Token::DoubleColon)?; + let datatype = if let Ok(data_type) = self.parse_data_type(){ + Ok(data_type) + }else if let Ok(table_name) = self.parse_tablename(){ + Ok(SQLType::Custom(table_name)) + }else{ + parser_err!("Expecting datatype or identifier") + }; + let pg_cast = ASTNode::SQLCast{ + expr: Box::new(expr), + data_type: datatype?, + }; + if let Some(Token::DoubleColon) = self.peek_token(){ + self.parse_pg_cast(pg_cast) + }else{ + Ok(pg_cast) + } + } + /// Parse an expression infix (typically an operator) pub fn parse_infix( &mut self, expr: ASTNode, precedence: u8, ) -> Result, ParserError> { + debug!("parsing infix"); match self.next_token() { Some(tok) => match tok { Token::Keyword(ref k) => if k == "IS" { @@ -192,6 +237,10 @@ impl Parser { op: self.to_sql_operator(&tok)?, right: Box::new(self.parse_expr(precedence)?), })), + | Token::DoubleColon => { + let pg_cast = self.parse_pg_cast(expr)?; + Ok(Some(pg_cast)) + }, _ => parser_err!(format!("No infix parser for token {:?}", tok)), }, None => Ok(None), @@ -229,7 +278,7 @@ impl Parser { /// Get the precedence of a token pub fn get_precedence(&self, tok: &Token) -> Result { - //println!("get_precedence() {:?}", tok); + debug!("get_precedence() {:?}", tok); match tok { &Token::Keyword(ref k) if k == "OR" => Ok(5), @@ -240,6 +289,7 @@ impl Parser { } &Token::Plus | &Token::Minus => Ok(30), &Token::Mult | &Token::Div | &Token::Mod => Ok(40), + &Token::DoubleColon => Ok(50), _ => Ok(0), } } @@ -325,64 +375,70 @@ impl Parser { /// Parse a SQL CREATE statement pub fn parse_create(&mut self) -> Result { if self.parse_keywords(vec!["TABLE"]) { - match self.next_token() { - Some(Token::Identifier(id)) => { - // parse optional column list (schema) - let mut columns = vec![]; - if self.consume_token(&Token::LParen)? { - loop { - if let Some(Token::Identifier(column_name)) = self.next_token() { - if let Ok(data_type) = self.parse_data_type() { - let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { - false - } else if self.parse_keyword("NULL") { - true - } else { - true - }; + let table_name = self.parse_tablename()?; + // parse optional column list (schema) + println!("table_name: {}", table_name); + let mut columns = vec![]; + if self.consume_token(&Token::LParen)? { + loop { + if let Some(Token::Identifier(column_name)) = self.next_token() { + println!("column name: {}", column_name); + if let Ok(data_type) = self.parse_data_type() { + let default = if self.parse_keyword("DEFAULT"){ + self.consume_token(&Token::LParen); + let expr = self.parse_expr(0)?; + self.consume_token(&Token::RParen); + Some(Box::new(expr)) + }else{ + None + }; + println!("default: {:?}", default); + let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { + false + } else if self.parse_keyword("NULL") { + true + } else { + true + }; + debug!("default: {:?}", default); - match self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - }); - } - Some(Token::RParen) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - }); - break; - } - _ => { - return parser_err!( - "Expected ',' or ')' after column definition" - ); - } - } - } else { + match self.peek_token() { + Some(Token::Comma) => { + self.next_token(); + columns.push(SQLColumnDef { + name: column_name, + data_type: data_type, + allow_null, + default, + }); + } + Some(Token::RParen) => { + self.next_token(); + columns.push(SQLColumnDef { + name: column_name, + data_type: data_type, + allow_null, + default, + }); + break; + } + other => { return parser_err!( - "Error parsing data type in column definition" + format!("Expected ',' or ')' after column definition but found {:?}", other) ); } - } else { - return parser_err!("Error parsing column name"); } + } else { + return parser_err!( + format!("Error parsing data type in column definition near: {:?}", self.peek_token()) + ); } + } else { + return parser_err!("Error parsing column name"); } - - Ok(ASTNode::SQLCreateTable { name: id, columns }) } - _ => parser_err!(format!( - "Unexpected token after CREATE EXTERNAL TABLE: {:?}", - self.peek_token() - )), } + Ok(ASTNode::SQLCreateTable { name: table_name, columns }) } else { parser_err!(format!( "Unexpected token after CREATE: {:?}", @@ -420,13 +476,111 @@ impl Parser { "SMALLINT" => Ok(SQLType::SmallInt), "INT" | "INTEGER" => Ok(SQLType::Int), "BIGINT" => Ok(SQLType::BigInt), - "VARCHAR" => Ok(SQLType::Varchar(self.parse_precision()?)), + "VARCHAR" => Ok(SQLType::Varchar(self.parse_optional_precision()?)), + "CHARACTER" => { + if self.parse_keyword("VARYING"){ + Ok(SQLType::Varchar(self.parse_optional_precision()?)) + }else{ + Ok(SQLType::Char(self.parse_optional_precision()?)) + } + } + "DATE" => Ok(SQLType::Date), + "TIMESTAMP" => if self.parse_keyword("WITH"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Timestamp) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else if self.parse_keyword("WITHOUT"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Timestamp) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else{ + Ok(SQLType::Timestamp) + } + "REGCLASS" => Ok(SQLType::Regclass), + "TEXT" => { + if let Ok(true) = self.consume_token(&Token::LBracket){ + self.consume_token(&Token::RBracket)?; + Ok(SQLType::Array(Box::new(SQLType::Text))) + }else{ + Ok(SQLType::Text) + } + } + "BYTEA" => Ok(SQLType::Bytea), + "NUMERIC" => { + let (precision, scale) = self.parse_optional_precision_scale()?; + Ok(SQLType::Decimal(precision, scale)) + } _ => parser_err!(format!("Invalid data type '{:?}'", k)), }, + Some(Token::Identifier(id)) => { + if let Ok(true) = self.consume_token(&Token::Period) { + let ids = self.parse_tablename()?; + Ok(SQLType::Custom(format!("{}.{}",id,ids))) + }else{ + Ok(SQLType::Custom(id)) + } + } other => parser_err!(format!("Invalid data type: '{:?}'", other)), } } + + pub fn parse_compound_identifier(&mut self, separator: &Token) -> Result { + let mut idents = vec![]; + let mut expect_identifier = true; + loop { + let token = &self.next_token(); + match token{ + Some(token) => match token{ + Token::Identifier(s) => if expect_identifier{ + expect_identifier = false; + idents.push(s.to_string()); + }else{ + self.prev_token(); + break; + } + token if token == separator => { + if expect_identifier{ + return parser_err!(format!("Expecting identifier, found {:?}", token)); + }else{ + expect_identifier = true; + continue; + } + } + _ => { + self.prev_token(); + break; + } + } + None => { + self.prev_token(); + break; + } + } + } + Ok(ASTNode::SQLCompoundIdentifier(idents)) + } + + pub fn parse_tablename(&mut self) -> Result { + let identifier = self.parse_compound_identifier(&Token::Period)?; + match identifier{ + ASTNode::SQLCompoundIdentifier(idents) => Ok(idents.join(".")), + other => parser_err!(format!("Expecting compound identifier, found: {:?}", other)), + } + } + + pub fn parse_column_names(&mut self) -> Result, ParserError> { + let identifier = self.parse_compound_identifier(&Token::Comma)?; + match identifier{ + ASTNode::SQLCompoundIdentifier(idents) => Ok(idents), + other => parser_err!(format!("Expecting compound identifier, found: {:?}", other)), + } + } + pub fn parse_precision(&mut self) -> Result { //TODO: error handling Ok(self.parse_optional_precision()?.unwrap()) @@ -443,6 +597,21 @@ impl Parser { } } + pub fn parse_optional_precision_scale(&mut self) -> Result<(usize, Option), ParserError> { + if self.consume_token(&Token::LParen)? { + let n = self.parse_literal_int()?; + let scale = if let Ok(true) = self.consume_token(&Token::Comma){ + Some(self.parse_literal_int()? as usize) + }else{ + None + }; + self.consume_token(&Token::RParen)?; + Ok((n as usize, scale)) + } else { + parser_err!("Expecting `(`") + } + } + pub fn parse_delete(&mut self) -> Result { let relation: Option> = if self.parse_keyword("FROM") { Some(Box::new(self.parse_expr(0)?)) @@ -898,7 +1067,7 @@ mod tests { let c_name = &columns[0]; assert_eq!("name", c_name.name); - assert_eq!(SQLType::Varchar(100), c_name.data_type); + assert_eq!(SQLType::Varchar(Some(100)), c_name.data_type); assert_eq!(false, c_name.allow_null); let c_lat = &columns[1]; @@ -915,6 +1084,86 @@ mod tests { } } + #[test] + fn parse_create_table_with_defaults() { + let sql = String::from( + "CREATE TABLE public.customer ( + customer_id integer DEFAULT nextval(public.customer_customer_id_seq) NOT NULL, + store_id smallint NOT NULL, + first_name character varying(45) NOT NULL, + last_name character varying(45) NOT NULL, + email character varying(50), + address_id smallint NOT NULL, + activebool boolean DEFAULT true NOT NULL, + create_date date DEFAULT now()::text NOT NULL, + last_update timestamp without time zone DEFAULT now() NOT NULL, + active integer NOT NULL)"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name); + assert_eq!(10, columns.len()); + + let c_name = &columns[0]; + assert_eq!("customer_id", c_name.name); + assert_eq!(SQLType::Int, c_name.data_type); + assert_eq!(false, c_name.allow_null); + + let c_lat = &columns[1]; + assert_eq!("store_id", c_lat.name); + assert_eq!(SQLType::SmallInt, c_lat.data_type); + assert_eq!(false, c_lat.allow_null); + + let c_lng = &columns[2]; + assert_eq!("first_name", c_lng.name); + assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type); + assert_eq!(false, c_lng.allow_null); + } + _ => assert!(false), + } + } + + #[test] + fn parse_create_table_from_pg_dump() { + let sql = String::from(" + CREATE TABLE public.customer ( + customer_id integer DEFAULT nextval('public.customer_customer_id_seq'::regclass) NOT NULL, + store_id smallint NOT NULL, + first_name character varying(45) NOT NULL, + last_name character varying(45) NOT NULL, + info text[], + address_id smallint NOT NULL, + activebool boolean DEFAULT true NOT NULL, + create_date date DEFAULT now()::date NOT NULL, + create_date1 date DEFAULT 'now'::text::date NOT NULL, + last_update timestamp without time zone DEFAULT now(), + release_year public.year, + active integer + )"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name); + + let c_name = &columns[0]; + assert_eq!("customer_id", c_name.name); + assert_eq!(SQLType::Int, c_name.data_type); + assert_eq!(false, c_name.allow_null); + + let c_lat = &columns[1]; + assert_eq!("store_id", c_lat.name); + assert_eq!(SQLType::SmallInt, c_lat.data_type); + assert_eq!(false, c_lat.allow_null); + + let c_lng = &columns[2]; + assert_eq!("first_name", c_lng.name); + assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type); + assert_eq!(false, c_lng.allow_null); + } + _ => assert!(false), + } + } + #[test] fn parse_scalar_function_in_projection() { let sql = String::from("SELECT sqrt(id) FROM foo"); @@ -964,13 +1213,29 @@ mod tests { } } + #[test] + fn parse_function_now(){ + let sql = "now()"; + let mut parser = parser(sql); + let ast = parser.parse(); + println!("ast: {:?}", ast); + assert!(ast.is_ok()); + } + fn parse_sql(sql: &str) -> ASTNode { - let dialect = GenericSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens); + debug!("sql: {}", sql); + println!("sql: {}", sql); + let mut parser = parser(sql); let ast = parser.parse().unwrap(); ast } + fn parser(sql: &str) -> Parser { + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + debug!("tokens: {:#?}", tokens); + Parser::new(tokens) + } + } diff --git a/src/sqltokenizer.rs b/src/sqltokenizer.rs index da675285..28435d0f 100644 --- a/src/sqltokenizer.rs +++ b/src/sqltokenizer.rs @@ -66,6 +66,14 @@ pub enum Token { RParen, /// Period (used for compound identifiers or projections into nested types) Period, + /// Colon `:` + Colon, + /// DoubleColon `::` (used for casting in postgresql) + DoubleColon, + /// Left bracket `[` + LBracket, + /// Right bracket `]` + RBracket, } /// Tokenizer error @@ -243,6 +251,23 @@ impl<'a> Tokenizer<'a> { None => Ok(Some(Token::Gt)), } } + // colon + ':' => { + chars.next(); + match chars.peek() { + Some(&ch) => match ch { + // double colon + ':' => { + self.consume_and_return(chars, Token::DoubleColon) + } + _ => Ok(Some(Token::Colon)), + }, + None => Ok(Some(Token::Colon)), + } + } + // brakets + '[' => self.consume_and_return(chars, Token::LBracket), + ']' => self.consume_and_return(chars, Token::RBracket), _ => Err(TokenizerError(format!( "Tokenizer Error at Line: {}, Column: {}, unhandled char '{}'", self.line, self.col, ch