diff --git a/Cargo.toml b/Cargo.toml index a53021c1..00121fa7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ name = "sqlparser" path = "src/lib.rs" [dependencies] +log = "0.4.5" diff --git a/src/dialect.rs b/src/dialect.rs index c263befd..9af842f8 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -434,6 +434,13 @@ impl Dialect for GenericSqlDialect { "TIME", "TIMESTAMP", "VALUES", + "DEFAULT", + "ZONE", + "REGCLASS", + "TEXT", + "BYTEA", + "TRUE", + "FALSE", ]; } diff --git a/src/lib.rs b/src/lib.rs index 2c9f9883..20cd3711 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,9 @@ //! println!("AST: {:?}", ast); //! ``` +#[macro_use] +extern crate log; + pub mod dialect; pub mod sqlast; pub mod sqlparser; diff --git a/src/sqlast.rs b/src/sqlast.rs index cb8b03ba..1e645262 100644 --- a/src/sqlast.rs +++ b/src/sqlast.rs @@ -53,6 +53,10 @@ pub enum ASTNode { SQLLiteralDouble(f64), /// Literal string SQLLiteralString(String), + /// Boolean value true or false, + SQLBoolean(bool), + /// NULL value in insert statements, + SQLNullValue, /// Scalar function call e.g. `LEFT(foo, 5)` SQLFunction { id: String, args: Vec }, /// SELECT @@ -135,15 +139,16 @@ pub struct SQLColumnDef { pub name: String, pub data_type: SQLType, pub allow_null: bool, + pub default: Option>, } /// SQL datatypes for literals in SQL statements #[derive(Debug, Clone, PartialEq)] pub enum SQLType { /// Fixed-length character type e.g. CHAR(10) - Char(usize), + Char(Option), /// Variable-length character type e.g. VARCHAR(10) - Varchar(usize), + Varchar(Option), /// Large character object e.g. CLOB(1000) Clob(usize), /// Fixed-length binary type e.g. BINARY(10) @@ -174,6 +179,16 @@ pub enum SQLType { Time, /// Timestamp Timestamp, + /// Regclass used in postgresql serial + Regclass, + /// Text + Text, + /// Bytea + Bytea, + /// Custom type such as enums + Custom(String), + /// Arrays + Array(Box), } /// SQL Operator diff --git a/src/sqlparser.rs b/src/sqlparser.rs index fa64624e..ae7e55cc 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -66,9 +66,12 @@ impl Parser { /// Parse tokens until the precedence changes pub fn parse_expr(&mut self, precedence: u8) -> Result { + debug!("parsing expr"); let mut expr = self.parse_prefix()?; + debug!("prefix: {:?}", expr); loop { let next_precedence = self.get_next_precedence()?; + debug!("next precedence: {:?}", next_precedence); if precedence >= next_precedence { break; } @@ -90,38 +93,37 @@ impl Parser { "CREATE" => Ok(self.parse_create()?), "DELETE" => Ok(self.parse_delete()?), "INSERT" => Ok(self.parse_insert()?), + "TRUE" => Ok(ASTNode::SQLBoolean(true)), + "FALSE" => Ok(ASTNode::SQLBoolean(false)), + "NULL" => Ok(ASTNode::SQLNullValue), _ => return parser_err!(format!("No prefix parser for keyword {}", k)), }, Token::Mult => Ok(ASTNode::SQLWildcard), Token::Identifier(id) => { - match self.peek_token() { - Some(Token::LParen) => { - self.next_token(); // skip lparen - match id.to_uppercase().as_ref() { - "CAST" => self.parse_cast_expression(), - _ => { - let args = self.parse_expr_list()?; - self.next_token(); // skip rparen - Ok(ASTNode::SQLFunction { id, args }) - } + if "CAST" == id.to_uppercase(){ + self.parse_cast_expression() + }else{ + match self.peek_token() { + Some(Token::LParen) => { + self.parse_function_or_pg_cast(&id) } - } - Some(Token::Period) => { - let mut id_parts: Vec = vec![id]; - while self.peek_token() == Some(Token::Period) { - self.consume_token(&Token::Period)?; - match self.next_token() { - Some(Token::Identifier(id)) => id_parts.push(id), - _ => { - return parser_err!(format!( - "Error parsing compound identifier" - )) + Some(Token::Period) => { + let mut id_parts: Vec = vec![id]; + while self.peek_token() == Some(Token::Period) { + self.consume_token(&Token::Period)?; + match self.next_token() { + Some(Token::Identifier(id)) => id_parts.push(id), + _ => { + return parser_err!(format!( + "Error parsing compound identifier" + )) + } } } + Ok(ASTNode::SQLCompoundIdentifier(id_parts)) } - Ok(ASTNode::SQLCompoundIdentifier(id_parts)) + _ => Ok(ASTNode::SQLIdentifier(id)), } - _ => Ok(ASTNode::SQLIdentifier(id)), } } Token::Number(ref n) if n.contains(".") => match n.parse::() { @@ -143,8 +145,31 @@ impl Parser { } } + pub fn parse_function_or_pg_cast(&mut self, id: &str) -> Result { + let func = self.parse_function(&id)?; + println!("func: {:?}", func); + if let Some(Token::DoubleColon) = self.peek_token(){ + self.parse_pg_cast(func) + }else{ + Ok(func) + } + } + + pub fn parse_function(&mut self, id: &str) -> Result { + self.consume_token(&Token::LParen)?; + if let Ok(true) = self.consume_token(&Token::RParen){ + Ok(ASTNode::SQLFunction { id: id.to_string(), args: vec![] }) + }else{ + let args = self.parse_expr_list()?; + self.consume_token(&Token::RParen)?; + Ok(ASTNode::SQLFunction { id: id.to_string(), args }) + } + } + /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expression(&mut self) -> Result { + println!("parsing cast"); + self.consume_token(&Token::LParen)?; let expr = self.parse_expr(0)?; self.consume_token(&Token::Keyword("AS".to_string()))?; let data_type = self.parse_data_type()?; @@ -155,12 +180,35 @@ impl Parser { }) } + + /// Parse a postgresql casting style which is in the form or expr::datatype + pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result { + let ast = self.consume_token(&Token::DoubleColon)?; + let datatype = if let Ok(data_type) = self.parse_data_type(){ + Ok(data_type) + }else if let Ok(table_name) = self.parse_tablename(){ + Ok(SQLType::Custom(table_name)) + }else{ + parser_err!("Expecting datatype or identifier") + }; + let pg_cast = ASTNode::SQLCast{ + expr: Box::new(expr), + data_type: datatype?, + }; + if let Some(Token::DoubleColon) = self.peek_token(){ + self.parse_pg_cast(pg_cast) + }else{ + Ok(pg_cast) + } + } + /// Parse an expression infix (typically an operator) pub fn parse_infix( &mut self, expr: ASTNode, precedence: u8, ) -> Result, ParserError> { + debug!("parsing infix"); match self.next_token() { Some(tok) => match tok { Token::Keyword(ref k) => if k == "IS" { @@ -193,6 +241,10 @@ impl Parser { op: self.to_sql_operator(&tok)?, right: Box::new(self.parse_expr(precedence)?), })), + | Token::DoubleColon => { + let pg_cast = self.parse_pg_cast(expr)?; + Ok(Some(pg_cast)) + }, _ => parser_err!(format!("No infix parser for token {:?}", tok)), }, None => Ok(None), @@ -230,7 +282,7 @@ impl Parser { /// Get the precedence of a token pub fn get_precedence(&self, tok: &Token) -> Result { - //println!("get_precedence() {:?}", tok); + debug!("get_precedence() {:?}", tok); match tok { &Token::Keyword(ref k) if k == "OR" => Ok(5), @@ -241,6 +293,7 @@ impl Parser { } &Token::Plus | &Token::Minus => Ok(30), &Token::Mult | &Token::Div | &Token::Mod => Ok(40), + &Token::DoubleColon => Ok(50), _ => Ok(0), } } @@ -326,64 +379,69 @@ impl Parser { /// Parse a SQL CREATE statement pub fn parse_create(&mut self) -> Result { if self.parse_keywords(vec!["TABLE"]) { - match self.next_token() { - Some(Token::Identifier(id)) => { - // parse optional column list (schema) - let mut columns = vec![]; - if self.consume_token(&Token::LParen)? { - loop { - if let Some(Token::Identifier(column_name)) = self.next_token() { - if let Ok(data_type) = self.parse_data_type() { - let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { - false - } else if self.parse_keyword("NULL") { - true - } else { - true - }; + let table_name = self.parse_tablename()?; + // parse optional column list (schema) + println!("table_name: {}", table_name); + let mut columns = vec![]; + if self.consume_token(&Token::LParen)? { + loop { + if let Some(Token::Identifier(column_name)) = self.next_token() { + println!("column name: {}", column_name); + if let Ok(data_type) = self.parse_data_type() { + let default = if self.parse_keyword("DEFAULT"){ + let expr = self.parse_expr(0)?; + println!("expr: {:?}", expr); + Some(Box::new(expr)) + }else{ + None + }; + println!("default: {:?}", default); + let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { + false + } else if self.parse_keyword("NULL") { + true + } else { + true + }; + debug!("default: {:?}", default); - match self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - }); - } - Some(Token::RParen) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - }); - break; - } - _ => { - return parser_err!( - "Expected ',' or ')' after column definition" - ); - } - } - } else { + match self.peek_token() { + Some(Token::Comma) => { + self.next_token(); + columns.push(SQLColumnDef { + name: column_name, + data_type: data_type, + allow_null, + default, + }); + } + Some(Token::RParen) => { + self.next_token(); + columns.push(SQLColumnDef { + name: column_name, + data_type: data_type, + allow_null, + default, + }); + break; + } + other => { return parser_err!( - "Error parsing data type in column definition" + format!("Expected ',' or ')' after column definition but found {:?}", other) ); } - } else { - return parser_err!("Error parsing column name"); } + } else { + return parser_err!( + format!("Error parsing data type in column definition near: {:?}", self.peek_token()) + ); } + } else { + return parser_err!("Error parsing column name"); } - - Ok(ASTNode::SQLCreateTable { name: id, columns }) } - _ => parser_err!(format!( - "Unexpected token after CREATE EXTERNAL TABLE: {:?}", - self.peek_token() - )), } + Ok(ASTNode::SQLCreateTable { name: table_name, columns }) } else { parser_err!(format!( "Unexpected token after CREATE: {:?}", @@ -417,13 +475,77 @@ impl Parser { "BOOLEAN" => Ok(SQLType::Boolean), "FLOAT" => Ok(SQLType::Float(self.parse_optional_precision()?)), "REAL" => Ok(SQLType::Real), - "DOUBLE" => Ok(SQLType::Double), + "DOUBLE" => if self.parse_keyword("PRECISION"){ + Ok(SQLType::Double) + }else{ + Ok(SQLType::Double) + } "SMALLINT" => Ok(SQLType::SmallInt), "INT" | "INTEGER" => Ok(SQLType::Int), "BIGINT" => Ok(SQLType::BigInt), - "VARCHAR" => Ok(SQLType::Varchar(self.parse_precision()?)), + "VARCHAR" => Ok(SQLType::Varchar(self.parse_optional_precision()?)), + "CHARACTER" => { + if self.parse_keyword("VARYING"){ + Ok(SQLType::Varchar(self.parse_optional_precision()?)) + }else{ + Ok(SQLType::Char(self.parse_optional_precision()?)) + } + } + "DATE" => Ok(SQLType::Date), + "TIMESTAMP" => if self.parse_keyword("WITH"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Timestamp) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else if self.parse_keyword("WITHOUT"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Timestamp) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else{ + Ok(SQLType::Timestamp) + } + "TIME" => if self.parse_keyword("WITH"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Time) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else if self.parse_keyword("WITHOUT"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Time) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else{ + Ok(SQLType::Timestamp) + } + "REGCLASS" => Ok(SQLType::Regclass), + "TEXT" => { + if let Ok(true) = self.consume_token(&Token::LBracket){ + self.consume_token(&Token::RBracket)?; + Ok(SQLType::Array(Box::new(SQLType::Text))) + }else{ + Ok(SQLType::Text) + } + } + "BYTEA" => Ok(SQLType::Bytea), + "NUMERIC" => { + let (precision, scale) = self.parse_optional_precision_scale()?; + Ok(SQLType::Decimal(precision, scale)) + } _ => parser_err!(format!("Invalid data type '{:?}'", k)), }, + Some(Token::Identifier(id)) => { + if let Ok(true) = self.consume_token(&Token::Period) { + let ids = self.parse_tablename()?; + Ok(SQLType::Custom(format!("{}.{}",id,ids))) + }else{ + Ok(SQLType::Custom(id)) + } + } other => parser_err!(format!("Invalid data type: '{:?}'", other)), } } @@ -497,6 +619,21 @@ impl Parser { } } + pub fn parse_optional_precision_scale(&mut self) -> Result<(usize, Option), ParserError> { + if self.consume_token(&Token::LParen)? { + let n = self.parse_literal_int()?; + let scale = if let Ok(true) = self.consume_token(&Token::Comma){ + Some(self.parse_literal_int()? as usize) + }else{ + None + }; + self.consume_token(&Token::RParen)?; + Ok((n as usize, scale)) + } else { + parser_err!("Expecting `(`") + } + } + pub fn parse_delete(&mut self) -> Result { let relation: Option> = if self.parse_keyword("FROM") { Some(Box::new(self.parse_expr(0)?)) @@ -1041,7 +1178,7 @@ mod tests { let c_name = &columns[0]; assert_eq!("name", c_name.name); - assert_eq!(SQLType::Varchar(100), c_name.data_type); + assert_eq!(SQLType::Varchar(Some(100)), c_name.data_type); assert_eq!(false, c_name.allow_null); let c_lat = &columns[1]; @@ -1058,6 +1195,110 @@ mod tests { } } + #[test] + fn parse_create_table_with_defaults() { + let sql = String::from( + "CREATE TABLE public.customer ( + customer_id integer DEFAULT nextval(public.customer_customer_id_seq) NOT NULL, + store_id smallint NOT NULL, + first_name character varying(45) NOT NULL, + last_name character varying(45) NOT NULL, + email character varying(50), + address_id smallint NOT NULL, + activebool boolean DEFAULT true NOT NULL, + create_date date DEFAULT now()::text NOT NULL, + last_update timestamp without time zone DEFAULT now() NOT NULL, + active integer NOT NULL)"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name); + assert_eq!(10, columns.len()); + + let c_name = &columns[0]; + assert_eq!("customer_id", c_name.name); + assert_eq!(SQLType::Int, c_name.data_type); + assert_eq!(false, c_name.allow_null); + + let c_lat = &columns[1]; + assert_eq!("store_id", c_lat.name); + assert_eq!(SQLType::SmallInt, c_lat.data_type); + assert_eq!(false, c_lat.allow_null); + + let c_lng = &columns[2]; + assert_eq!("first_name", c_lng.name); + assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type); + assert_eq!(false, c_lng.allow_null); + } + _ => assert!(false), + } + } + + #[test] + fn parse_create_table_from_pg_dump() { + let sql = String::from(" + CREATE TABLE public.customer ( + customer_id integer DEFAULT nextval('public.customer_customer_id_seq'::regclass) NOT NULL, + store_id smallint NOT NULL, + first_name character varying(45) NOT NULL, + last_name character varying(45) NOT NULL, + info text[], + address_id smallint NOT NULL, + activebool boolean DEFAULT true NOT NULL, + create_date date DEFAULT now()::date NOT NULL, + create_date1 date DEFAULT 'now'::text::date NOT NULL, + last_update timestamp without time zone DEFAULT now(), + release_year public.year, + active integer + )"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name); + + let c_name = &columns[0]; + assert_eq!("customer_id", c_name.name); + assert_eq!(SQLType::Int, c_name.data_type); + assert_eq!(false, c_name.allow_null); + + let c_lat = &columns[1]; + assert_eq!("store_id", c_lat.name); + assert_eq!(SQLType::SmallInt, c_lat.data_type); + assert_eq!(false, c_lat.allow_null); + + let c_lng = &columns[2]; + assert_eq!("first_name", c_lng.name); + assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type); + assert_eq!(false, c_lng.allow_null); + } + _ => assert!(false), + } + } + + #[test] + fn parse_create_table_with_inherit() { + let sql = String::from(" + CREATE TABLE bazaar.settings ( + user_id uuid, + value text[], + settings_id uuid DEFAULT uuid_generate_v4() NOT NULL, + use_metric boolean DEFAULT true + ) + INHERITS (system.record)"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("bazaar.settings", name); + + let c_name = &columns[0]; + assert_eq!("user_id", c_name.name); + assert_eq!(SQLType::Custom("uuid".into()), c_name.data_type); + assert_eq!(true, c_name.allow_null); + } + _ => assert!(false), + } + } + #[test] fn parse_scalar_function_in_projection() { let sql = String::from("SELECT sqrt(id) FROM foo"); @@ -1107,7 +1348,18 @@ mod tests { } } + #[test] + fn parse_function_now(){ + let sql = "now()"; + let mut parser = parser(sql); + let ast = parser.parse(); + println!("ast: {:?}", ast); + assert!(ast.is_ok()); + } + fn parse_sql(sql: &str) -> ASTNode { + debug!("sql: {}", sql); + println!("sql: {}", sql); let mut parser = parser(sql); let ast = parser.parse().unwrap(); ast @@ -1117,6 +1369,7 @@ mod tests { let dialect = GenericSqlDialect {}; let mut tokenizer = Tokenizer::new(&dialect, &sql); let tokens = tokenizer.tokenize().unwrap(); + debug!("tokens: {:#?}", tokens); Parser::new(tokens) } diff --git a/src/sqltokenizer.rs b/src/sqltokenizer.rs index da675285..28435d0f 100644 --- a/src/sqltokenizer.rs +++ b/src/sqltokenizer.rs @@ -66,6 +66,14 @@ pub enum Token { RParen, /// Period (used for compound identifiers or projections into nested types) Period, + /// Colon `:` + Colon, + /// DoubleColon `::` (used for casting in postgresql) + DoubleColon, + /// Left bracket `[` + LBracket, + /// Right bracket `]` + RBracket, } /// Tokenizer error @@ -243,6 +251,23 @@ impl<'a> Tokenizer<'a> { None => Ok(Some(Token::Gt)), } } + // colon + ':' => { + chars.next(); + match chars.peek() { + Some(&ch) => match ch { + // double colon + ':' => { + self.consume_and_return(chars, Token::DoubleColon) + } + _ => Ok(Some(Token::Colon)), + }, + None => Ok(Some(Token::Colon)), + } + } + // brakets + '[' => self.consume_and_return(chars, Token::LBracket), + ']' => self.consume_and_return(chars, Token::RBracket), _ => Err(TokenizerError(format!( "Tokenizer Error at Line: {}, Column: {}, unhandled char '{}'", self.line, self.col, ch