Improve the create statement parser that uses create statements from pg database dump

Added PostgreSQL style casting
This commit is contained in:
Jovansonlee Cesar 2018-09-24 03:34:40 +08:00
parent 7d27abdfb4
commit 2007995938
6 changed files with 392 additions and 81 deletions

View file

@ -18,3 +18,4 @@ name = "sqlparser"
path = "src/lib.rs"
[dependencies]
log = "0.4.5"

View file

@ -432,6 +432,12 @@ impl Dialect for GenericSqlDialect {
"DATE",
"TIME",
"TIMESTAMP",
"VALUES",
"DEFAULT",
"ZONE",
"REGCLASS",
"TEXT",
"BYTEA",
];
}

View file

@ -35,6 +35,9 @@
//! println!("AST: {:?}", ast);
//! ```
#[macro_use]
extern crate log;
pub mod dialect;
pub mod sqlast;
pub mod sqlparser;

View file

@ -135,15 +135,16 @@ pub struct SQLColumnDef {
pub name: String,
pub data_type: SQLType,
pub allow_null: bool,
pub default: Option<Box<ASTNode>>,
}
/// SQL datatypes for literals in SQL statements
#[derive(Debug, Clone, PartialEq)]
pub enum SQLType {
/// Fixed-length character type e.g. CHAR(10)
Char(usize),
Char(Option<usize>),
/// Variable-length character type e.g. VARCHAR(10)
Varchar(usize),
Varchar(Option<usize>),
/// Large character object e.g. CLOB(1000)
Clob(usize),
/// Fixed-length binary type e.g. BINARY(10)
@ -174,6 +175,16 @@ pub enum SQLType {
Time,
/// Timestamp
Timestamp,
/// Regclass used in postgresql serial
Regclass,
/// Text
Text,
/// Bytea
Bytea,
/// Custom type such as enums
Custom(String),
/// Arrays
Array(Box<SQLType>),
}
/// SQL Operator

View file

@ -66,9 +66,12 @@ impl Parser {
/// Parse tokens until the precedence changes
pub fn parse_expr(&mut self, precedence: u8) -> Result<ASTNode, ParserError> {
debug!("parsing expr");
let mut expr = self.parse_prefix()?;
debug!("prefix: {:?}", expr);
loop {
let next_precedence = self.get_next_precedence()?;
debug!("next precedence: {:?}", next_precedence);
if precedence >= next_precedence {
break;
}
@ -93,34 +96,30 @@ impl Parser {
},
Token::Mult => Ok(ASTNode::SQLWildcard),
Token::Identifier(id) => {
match self.peek_token() {
Some(Token::LParen) => {
self.next_token(); // skip lparen
match id.to_uppercase().as_ref() {
"CAST" => self.parse_cast_expression(),
_ => {
let args = self.parse_expr_list()?;
self.next_token(); // skip rparen
Ok(ASTNode::SQLFunction { id, args })
}
if "CAST" == id.to_uppercase(){
self.parse_cast_expression()
}else{
match self.peek_token() {
Some(Token::LParen) => {
self.parse_function_or_pg_cast(&id)
}
}
Some(Token::Period) => {
let mut id_parts: Vec<String> = vec![id];
while self.peek_token() == Some(Token::Period) {
self.consume_token(&Token::Period)?;
match self.next_token() {
Some(Token::Identifier(id)) => id_parts.push(id),
_ => {
return parser_err!(format!(
"Error parsing compound identifier"
))
Some(Token::Period) => {
let mut id_parts: Vec<String> = vec![id];
while self.peek_token() == Some(Token::Period) {
self.consume_token(&Token::Period)?;
match self.next_token() {
Some(Token::Identifier(id)) => id_parts.push(id),
_ => {
return parser_err!(format!(
"Error parsing compound identifier"
))
}
}
}
Ok(ASTNode::SQLCompoundIdentifier(id_parts))
}
Ok(ASTNode::SQLCompoundIdentifier(id_parts))
_ => Ok(ASTNode::SQLIdentifier(id)),
}
_ => Ok(ASTNode::SQLIdentifier(id)),
}
}
Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() {
@ -142,8 +141,31 @@ impl Parser {
}
}
pub fn parse_function_or_pg_cast(&mut self, id: &str) -> Result<ASTNode, ParserError> {
let func = self.parse_function(&id)?;
println!("func: {:?}", func);
if let Some(Token::DoubleColon) = self.peek_token(){
self.parse_pg_cast(func)
}else{
Ok(func)
}
}
pub fn parse_function(&mut self, id: &str) -> Result<ASTNode, ParserError> {
self.consume_token(&Token::LParen)?;
if let Ok(true) = self.consume_token(&Token::RParen){
Ok(ASTNode::SQLFunction { id: id.to_string(), args: vec![] })
}else{
let args = self.parse_expr_list()?;
self.consume_token(&Token::RParen)?;
Ok(ASTNode::SQLFunction { id: id.to_string(), args })
}
}
/// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)`
pub fn parse_cast_expression(&mut self) -> Result<ASTNode, ParserError> {
println!("parsing cast");
self.consume_token(&Token::LParen)?;
let expr = self.parse_expr(0)?;
self.consume_token(&Token::Keyword("AS".to_string()))?;
let data_type = self.parse_data_type()?;
@ -154,12 +176,35 @@ impl Parser {
})
}
/// Parse a PostgreSQL-style cast, which has the form `expr::datatype`.
///
/// `expr` is the already-parsed left-hand side. Chained casts such as
/// `'now'::text::date` are handled by recursing with the cast built so far
/// as the new left-hand side.
///
/// # Errors
///
/// Returns `"Expecting datatype or identifier"` when neither a data type
/// nor a (possibly qualified) name can be parsed after the `::`.
pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
    // The leading `::` is optional on purpose: `parse_infix` has already
    // taken it via `next_token`, whereas `parse_function_or_pg_cast` only
    // peeked at it. The boolean result is therefore deliberately ignored.
    self.consume_token(&Token::DoubleColon)?;

    let data_type = match self.parse_data_type() {
        Ok(data_type) => data_type,
        // Fall back to treating the target as a user-defined type name.
        Err(_) => match self.parse_tablename() {
            Ok(table_name) => SQLType::Custom(table_name),
            Err(_) => return parser_err!("Expecting datatype or identifier"),
        },
    };

    let pg_cast = ASTNode::SQLCast {
        expr: Box::new(expr),
        data_type,
    };

    if let Some(Token::DoubleColon) = self.peek_token() {
        self.parse_pg_cast(pg_cast)
    } else {
        Ok(pg_cast)
    }
}
/// Parse an expression infix (typically an operator)
pub fn parse_infix(
&mut self,
expr: ASTNode,
precedence: u8,
) -> Result<Option<ASTNode>, ParserError> {
debug!("parsing infix");
match self.next_token() {
Some(tok) => match tok {
Token::Keyword(ref k) => if k == "IS" {
@ -192,6 +237,10 @@ impl Parser {
op: self.to_sql_operator(&tok)?,
right: Box::new(self.parse_expr(precedence)?),
})),
| Token::DoubleColon => {
let pg_cast = self.parse_pg_cast(expr)?;
Ok(Some(pg_cast))
},
_ => parser_err!(format!("No infix parser for token {:?}", tok)),
},
None => Ok(None),
@ -229,7 +278,7 @@ impl Parser {
/// Get the precedence of a token
pub fn get_precedence(&self, tok: &Token) -> Result<u8, ParserError> {
//println!("get_precedence() {:?}", tok);
debug!("get_precedence() {:?}", tok);
match tok {
&Token::Keyword(ref k) if k == "OR" => Ok(5),
@ -240,6 +289,7 @@ impl Parser {
}
&Token::Plus | &Token::Minus => Ok(30),
&Token::Mult | &Token::Div | &Token::Mod => Ok(40),
&Token::DoubleColon => Ok(50),
_ => Ok(0),
}
}
@ -325,64 +375,70 @@ impl Parser {
/// Parse a SQL CREATE statement
pub fn parse_create(&mut self) -> Result<ASTNode, ParserError> {
if self.parse_keywords(vec!["TABLE"]) {
match self.next_token() {
Some(Token::Identifier(id)) => {
// parse optional column list (schema)
let mut columns = vec![];
if self.consume_token(&Token::LParen)? {
loop {
if let Some(Token::Identifier(column_name)) = self.next_token() {
if let Ok(data_type) = self.parse_data_type() {
let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) {
false
} else if self.parse_keyword("NULL") {
true
} else {
true
};
let table_name = self.parse_tablename()?;
// parse optional column list (schema)
println!("table_name: {}", table_name);
let mut columns = vec![];
if self.consume_token(&Token::LParen)? {
loop {
if let Some(Token::Identifier(column_name)) = self.next_token() {
println!("column name: {}", column_name);
if let Ok(data_type) = self.parse_data_type() {
let default = if self.parse_keyword("DEFAULT"){
self.consume_token(&Token::LParen);
let expr = self.parse_expr(0)?;
self.consume_token(&Token::RParen);
Some(Box::new(expr))
}else{
None
};
println!("default: {:?}", default);
let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) {
false
} else if self.parse_keyword("NULL") {
true
} else {
true
};
debug!("default: {:?}", default);
match self.peek_token() {
Some(Token::Comma) => {
self.next_token();
columns.push(SQLColumnDef {
name: column_name,
data_type: data_type,
allow_null,
});
}
Some(Token::RParen) => {
self.next_token();
columns.push(SQLColumnDef {
name: column_name,
data_type: data_type,
allow_null,
});
break;
}
_ => {
return parser_err!(
"Expected ',' or ')' after column definition"
);
}
}
} else {
match self.peek_token() {
Some(Token::Comma) => {
self.next_token();
columns.push(SQLColumnDef {
name: column_name,
data_type: data_type,
allow_null,
default,
});
}
Some(Token::RParen) => {
self.next_token();
columns.push(SQLColumnDef {
name: column_name,
data_type: data_type,
allow_null,
default,
});
break;
}
other => {
return parser_err!(
"Error parsing data type in column definition"
format!("Expected ',' or ')' after column definition but found {:?}", other)
);
}
} else {
return parser_err!("Error parsing column name");
}
} else {
return parser_err!(
format!("Error parsing data type in column definition near: {:?}", self.peek_token())
);
}
} else {
return parser_err!("Error parsing column name");
}
Ok(ASTNode::SQLCreateTable { name: id, columns })
}
_ => parser_err!(format!(
"Unexpected token after CREATE EXTERNAL TABLE: {:?}",
self.peek_token()
)),
}
Ok(ASTNode::SQLCreateTable { name: table_name, columns })
} else {
parser_err!(format!(
"Unexpected token after CREATE: {:?}",
@ -420,13 +476,111 @@ impl Parser {
"SMALLINT" => Ok(SQLType::SmallInt),
"INT" | "INTEGER" => Ok(SQLType::Int),
"BIGINT" => Ok(SQLType::BigInt),
"VARCHAR" => Ok(SQLType::Varchar(self.parse_precision()?)),
"VARCHAR" => Ok(SQLType::Varchar(self.parse_optional_precision()?)),
"CHARACTER" => {
if self.parse_keyword("VARYING"){
Ok(SQLType::Varchar(self.parse_optional_precision()?))
}else{
Ok(SQLType::Char(self.parse_optional_precision()?))
}
}
"DATE" => Ok(SQLType::Date),
"TIMESTAMP" => if self.parse_keyword("WITH"){
if self.parse_keywords(vec!["TIME","ZONE"]){
Ok(SQLType::Timestamp)
}else{
parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token()))
}
}else if self.parse_keyword("WITHOUT"){
if self.parse_keywords(vec!["TIME","ZONE"]){
Ok(SQLType::Timestamp)
}else{
parser_err!(format!("Expecting 'time zone', found: {:?}", self.peek_token()))
}
}else{
Ok(SQLType::Timestamp)
}
"REGCLASS" => Ok(SQLType::Regclass),
"TEXT" => {
if let Ok(true) = self.consume_token(&Token::LBracket){
self.consume_token(&Token::RBracket)?;
Ok(SQLType::Array(Box::new(SQLType::Text)))
}else{
Ok(SQLType::Text)
}
}
"BYTEA" => Ok(SQLType::Bytea),
"NUMERIC" => {
let (precision, scale) = self.parse_optional_precision_scale()?;
Ok(SQLType::Decimal(precision, scale))
}
_ => parser_err!(format!("Invalid data type '{:?}'", k)),
},
Some(Token::Identifier(id)) => {
if let Ok(true) = self.consume_token(&Token::Period) {
let ids = self.parse_tablename()?;
Ok(SQLType::Custom(format!("{}.{}",id,ids)))
}else{
Ok(SQLType::Custom(id))
}
}
other => parser_err!(format!("Invalid data type: '{:?}'", other)),
}
}
/// Parse a run of identifiers joined by `separator`, e.g. `schema.table`
/// with `Token::Period` or `a, b, c` with `Token::Comma`.
///
/// Parsing stops at the first token that is neither an expected identifier
/// nor the separator; that token is pushed back for the caller.
///
/// Returns an `ASTNode::SQLCompoundIdentifier` holding the identifiers in
/// source order.
///
/// # Errors
///
/// Returns a `ParserError` when an identifier is required but missing:
/// a leading separator, two separators in a row, a trailing separator, or
/// no identifier at all.
pub fn parse_compound_identifier(&mut self, separator: &Token) -> Result<ASTNode, ParserError> {
    let mut idents = vec![];
    let mut expect_identifier = true;
    loop {
        match self.next_token() {
            Some(Token::Identifier(s)) => {
                if expect_identifier {
                    expect_identifier = false;
                    idents.push(s);
                } else {
                    // Two identifiers in a row: the second one is not part
                    // of this compound identifier, hand it back.
                    self.prev_token();
                    break;
                }
            }
            Some(ref token) if token == separator => {
                if expect_identifier {
                    return parser_err!(format!("Expecting identifier, found {:?}", token));
                }
                expect_identifier = true;
            }
            _ => {
                // NOTE(review): this also rewinds when the input is
                // exhausted (`None`); confirm `next_token` advances the
                // cursor at end of input, otherwise this steps back too far.
                self.prev_token();
                break;
            }
        }
    }
    // `expect_identifier` is still set when nothing was collected or when
    // the last thing consumed was a separator. Both are malformed, in line
    // with the leading/double separator check above.
    if expect_identifier {
        return parser_err!(format!(
            "Expecting identifier, found {:?}",
            self.peek_token()
        ));
    }
    Ok(ASTNode::SQLCompoundIdentifier(idents))
}
/// Parse a possibly qualified name such as `schema.table` and return it as
/// a single string with the parts joined by `.`.
///
/// Errors from `parse_compound_identifier` are propagated unchanged.
pub fn parse_tablename(&mut self) -> Result<String, ParserError> {
    match self.parse_compound_identifier(&Token::Period)? {
        ASTNode::SQLCompoundIdentifier(parts) => Ok(parts.join(".")),
        unexpected => parser_err!(format!(
            "Expecting compound identifier, found: {:?}",
            unexpected
        )),
    }
}
/// Parse a comma-separated list of identifiers and return them in source
/// order.
///
/// Errors from `parse_compound_identifier` are propagated unchanged.
pub fn parse_column_names(&mut self) -> Result<Vec<String>, ParserError> {
    match self.parse_compound_identifier(&Token::Comma)? {
        ASTNode::SQLCompoundIdentifier(names) => Ok(names),
        unexpected => parser_err!(format!(
            "Expecting compound identifier, found: {:?}",
            unexpected
        )),
    }
}
pub fn parse_precision(&mut self) -> Result<usize, ParserError> {
//TODO: error handling
Ok(self.parse_optional_precision()?.unwrap())
@ -443,6 +597,21 @@ impl Parser {
}
}
/// Parse a `(precision)` or `(precision, scale)` suffix as used by
/// `NUMERIC`.
///
/// Returns `(precision, scale)`, where `scale` is `None` when only the
/// precision is given.
///
/// Despite the name, the parenthesised part is currently mandatory: the
/// return type has no way to express a missing precision.
///
/// # Errors
///
/// Returns a `ParserError` when the opening `(` or the closing `)` is
/// missing, or when precision/scale are not integer literals.
pub fn parse_optional_precision_scale(&mut self) -> Result<(usize, Option<usize>), ParserError> {
    if !self.consume_token(&Token::LParen)? {
        return parser_err!("Expecting `(`");
    }
    // NOTE(review): `as usize` would wrap a negative value; this assumes
    // `parse_literal_int` never yields one - confirm against the tokenizer.
    let precision = self.parse_literal_int()? as usize;
    let scale = if let Ok(true) = self.consume_token(&Token::Comma) {
        Some(self.parse_literal_int()? as usize)
    } else {
        None
    };
    // `consume_token` signals a mismatch via its boolean result, so an
    // unterminated `NUMERIC(10, 2` has to be rejected explicitly.
    if !self.consume_token(&Token::RParen)? {
        return parser_err!(format!(
            "Expecting `)` after precision and scale, found: {:?}",
            self.peek_token()
        ));
    }
    Ok((precision, scale))
}
pub fn parse_delete(&mut self) -> Result<ASTNode, ParserError> {
let relation: Option<Box<ASTNode>> = if self.parse_keyword("FROM") {
Some(Box::new(self.parse_expr(0)?))
@ -898,7 +1067,7 @@ mod tests {
let c_name = &columns[0];
assert_eq!("name", c_name.name);
assert_eq!(SQLType::Varchar(100), c_name.data_type);
assert_eq!(SQLType::Varchar(Some(100)), c_name.data_type);
assert_eq!(false, c_name.allow_null);
let c_lat = &columns[1];
@ -915,6 +1084,86 @@ mod tests {
}
}
/// CREATE TABLE with a schema-qualified name, `DEFAULT` expressions
/// (function call, literal, function call with `::` cast) and
/// `character varying(n)` columns.
#[test]
fn parse_create_table_with_defaults() {
    let sql = String::from(
        "CREATE TABLE public.customer (
customer_id integer DEFAULT nextval(public.customer_customer_id_seq) NOT NULL,
store_id smallint NOT NULL,
first_name character varying(45) NOT NULL,
last_name character varying(45) NOT NULL,
email character varying(50),
address_id smallint NOT NULL,
activebool boolean DEFAULT true NOT NULL,
create_date date DEFAULT now()::text NOT NULL,
last_update timestamp without time zone DEFAULT now() NOT NULL,
active integer NOT NULL)");
    let ast = parse_sql(&sql);
    match ast {
        ASTNode::SQLCreateTable { name, columns } => {
            assert_eq!("public.customer", name);
            assert_eq!(10, columns.len());

            let customer_id = &columns[0];
            assert_eq!("customer_id", customer_id.name);
            assert_eq!(SQLType::Int, customer_id.data_type);
            assert_eq!(false, customer_id.allow_null);
            // The point of this test: the DEFAULT clause must be captured.
            assert!(customer_id.default.is_some());

            let store_id = &columns[1];
            assert_eq!("store_id", store_id.name);
            assert_eq!(SQLType::SmallInt, store_id.data_type);
            assert_eq!(false, store_id.allow_null);
            assert!(store_id.default.is_none());

            let first_name = &columns[2];
            assert_eq!("first_name", first_name.name);
            assert_eq!(SQLType::Varchar(Some(45)), first_name.data_type);
            assert_eq!(false, first_name.allow_null);
        }
        other => panic!("expected SQLCreateTable, found {:?}", other),
    }
}
/// CREATE TABLE as emitted by `pg_dump`: quoted-string casts
/// (`'...'::regclass`), chained casts, array types (`text[]`) and
/// schema-qualified custom types.
#[test]
fn parse_create_table_from_pg_dump() {
    let sql = String::from("
CREATE TABLE public.customer (
customer_id integer DEFAULT nextval('public.customer_customer_id_seq'::regclass) NOT NULL,
store_id smallint NOT NULL,
first_name character varying(45) NOT NULL,
last_name character varying(45) NOT NULL,
info text[],
address_id smallint NOT NULL,
activebool boolean DEFAULT true NOT NULL,
create_date date DEFAULT now()::date NOT NULL,
create_date1 date DEFAULT 'now'::text::date NOT NULL,
last_update timestamp without time zone DEFAULT now(),
release_year public.year,
active integer
)");
    let ast = parse_sql(&sql);
    match ast {
        ASTNode::SQLCreateTable { name, columns } => {
            assert_eq!("public.customer", name);
            assert_eq!(12, columns.len());

            let customer_id = &columns[0];
            assert_eq!("customer_id", customer_id.name);
            assert_eq!(SQLType::Int, customer_id.data_type);
            assert_eq!(false, customer_id.allow_null);
            assert!(customer_id.default.is_some());

            let store_id = &columns[1];
            assert_eq!("store_id", store_id.name);
            assert_eq!(SQLType::SmallInt, store_id.data_type);
            assert_eq!(false, store_id.allow_null);
            assert!(store_id.default.is_none());

            let first_name = &columns[2];
            assert_eq!("first_name", first_name.name);
            assert_eq!(SQLType::Varchar(Some(45)), first_name.data_type);
            assert_eq!(false, first_name.allow_null);

            // `text[]` must be recognised as an array of text.
            let info = &columns[4];
            assert_eq!("info", info.name);
            assert_eq!(SQLType::Array(Box::new(SQLType::Text)), info.data_type);
        }
        other => panic!("expected SQLCreateTable, found {:?}", other),
    }
}
#[test]
fn parse_scalar_function_in_projection() {
let sql = String::from("SELECT sqrt(id) FROM foo");
@ -964,13 +1213,29 @@ mod tests {
}
}
/// A bare zero-argument function call must parse successfully.
#[test]
fn parse_function_now() {
    let mut p = parser("now()");
    let result = p.parse();
    println!("ast: {:?}", result);
    assert!(result.is_ok());
}
fn parse_sql(sql: &str) -> ASTNode {
let dialect = GenericSqlDialect {};
let mut tokenizer = Tokenizer::new(&dialect, &sql);
let tokens = tokenizer.tokenize().unwrap();
let mut parser = Parser::new(tokens);
debug!("sql: {}", sql);
println!("sql: {}", sql);
let mut parser = parser(sql);
let ast = parser.parse().unwrap();
ast
}
/// Tokenize `sql` with the generic dialect and return a `Parser` over the
/// resulting tokens. Panics if tokenizing fails, which is acceptable in
/// tests.
fn parser(sql: &str) -> Parser {
    let dialect = GenericSqlDialect {};
    let tokens = Tokenizer::new(&dialect, &sql).tokenize().unwrap();
    debug!("tokens: {:#?}", tokens);
    Parser::new(tokens)
}
}

View file

@ -66,6 +66,14 @@ pub enum Token {
RParen,
/// Period (used for compound identifiers or projections into nested types)
Period,
/// Colon `:`
Colon,
/// DoubleColon `::` (used for casting in postgresql)
DoubleColon,
/// Left bracket `[`
LBracket,
/// Right bracket `]`
RBracket,
}
/// Tokenizer error
@ -243,6 +251,23 @@ impl<'a> Tokenizer<'a> {
None => Ok(Some(Token::Gt)),
}
}
// colon
':' => {
chars.next();
match chars.peek() {
Some(&ch) => match ch {
// double colon
':' => {
self.consume_and_return(chars, Token::DoubleColon)
}
_ => Ok(Some(Token::Colon)),
},
None => Ok(Some(Token::Colon)),
}
}
// brackets
'[' => self.consume_and_return(chars, Token::LBracket),
']' => self.consume_and_return(chars, Token::RBracket),
_ => Err(TokenizerError(format!(
"Tokenizer Error at Line: {}, Column: {}, unhandled char '{}'",
self.line, self.col, ch