diff --git a/src/dialect.rs b/src/dialect.rs index 67496046..ea1fd3c7 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -438,6 +438,8 @@ impl Dialect for GenericSqlDialect { "REGCLASS", "TEXT", "BYTEA", + "TRUE", + "FALSE", ]; } diff --git a/src/sqlast.rs b/src/sqlast.rs index eeb770a8..3e5d46ac 100644 --- a/src/sqlast.rs +++ b/src/sqlast.rs @@ -53,6 +53,8 @@ pub enum ASTNode { SQLLiteralDouble(f64), /// Literal string SQLLiteralString(String), + /// Boolean value true or false, + SQLBoolean(bool), /// Scalar function call e.g. `LEFT(foo, 5)` SQLFunction { id: String, args: Vec }, /// SELECT diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 78c6e92a..8251a381 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -92,6 +92,8 @@ impl Parser { "SELECT" => Ok(self.parse_select()?), "CREATE" => Ok(self.parse_create()?), "DELETE" => Ok(self.parse_delete()?), + "TRUE" => Ok(ASTNode::SQLBoolean(true)), + "FALSE" => Ok(ASTNode::SQLBoolean(false)), _ => return parser_err!(format!("No prefix parser for keyword {}", k)), }, Token::Mult => Ok(ASTNode::SQLWildcard), @@ -385,9 +387,8 @@ impl Parser { println!("column name: {}", column_name); if let Ok(data_type) = self.parse_data_type() { let default = if self.parse_keyword("DEFAULT"){ - self.consume_token(&Token::LParen); let expr = self.parse_expr(0)?; - self.consume_token(&Token::RParen); + println!("expr: {:?}", expr); Some(Box::new(expr)) }else{ None @@ -472,7 +473,11 @@ impl Parser { "BOOLEAN" => Ok(SQLType::Boolean), "FLOAT" => Ok(SQLType::Float(self.parse_optional_precision()?)), "REAL" => Ok(SQLType::Real), - "DOUBLE" => Ok(SQLType::Double), + "DOUBLE" => if self.parse_keyword("PRECISION"){ + Ok(SQLType::Double) + }else{ + Ok(SQLType::Double) + } "SMALLINT" => Ok(SQLType::SmallInt), "INT" | "INTEGER" => Ok(SQLType::Int), "BIGINT" => Ok(SQLType::BigInt), @@ -500,6 +505,21 @@ impl Parser { }else{ Ok(SQLType::Timestamp) } + "TIME" => if self.parse_keyword("WITH"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Time) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else if self.parse_keyword("WITHOUT"){ + if self.parse_keywords(vec!["TIME","ZONE"]){ + Ok(SQLType::Time) + }else{ + parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token())) + } + }else{ + Ok(SQLType::Timestamp) + } "REGCLASS" => Ok(SQLType::Regclass), "TEXT" => { if let Ok(true) = self.consume_token(&Token::LBracket){ @@ -1164,6 +1184,30 @@ mod tests { } } + #[test] + fn parse_create_table_with_inherit() { + let sql = String::from(" + CREATE TABLE bazaar.settings ( + user_id uuid, + value text[], + settings_id uuid DEFAULT uuid_generate_v4() NOT NULL, + use_metric boolean DEFAULT true + ) + INHERITS (system.record)"); + let ast = parse_sql(&sql); + match ast { + ASTNode::SQLCreateTable { name, columns } => { + assert_eq!("bazaar.settings", name); + + let c_name = &columns[0]; + assert_eq!("user_id", c_name.name); + assert_eq!(SQLType::Custom("uuid".into()), c_name.data_type); + assert_eq!(true, c_name.allow_null); + } + _ => assert!(false), + } + } + #[test] fn parse_scalar_function_in_projection() { let sql = String::from("SELECT sqrt(id) FROM foo");