diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 8acc0b13..3327c4c1 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -396,6 +396,7 @@ pub enum SQLStatement { name: SQLObjectName, /// Optional schema columns: Vec, + constraints: Vec, external: bool, file_format: Option, location: Option, @@ -503,21 +504,30 @@ impl ToString for SQLStatement { SQLStatement::SQLCreateTable { name, columns, + constraints, external, file_format, location, - } if *external => format!( - "CREATE EXTERNAL TABLE {} ({}) STORED AS {} LOCATION '{}'", - name.to_string(), - comma_separated_string(columns), - file_format.as_ref().unwrap().to_string(), - location.as_ref().unwrap() - ), - SQLStatement::SQLCreateTable { name, columns, .. } => format!( - "CREATE TABLE {} ({})", - name.to_string(), - comma_separated_string(columns) - ), + } => { + let mut s = format!( + "CREATE {}TABLE {} ({}", + if *external { "EXTERNAL " } else { "" }, + name.to_string(), + comma_separated_string(columns) + ); + if !constraints.is_empty() { + s += &format!(", {}", comma_separated_string(constraints)); + } + s += ")"; + if *external { + s += &format!( + " STORED AS {} LOCATION '{}'", + file_format.as_ref().unwrap().to_string(), + location.as_ref().unwrap() + ); + } + s + } SQLStatement::SQLAlterTable { name, operation } => { format!("ALTER TABLE {} {}", name.to_string(), operation.to_string()) } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 78f1012a..8e4b3fa5 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -769,7 +769,7 @@ impl Parser { pub fn parse_create_external_table(&mut self) -> Result { self.expect_keyword("TABLE")?; let table_name = self.parse_object_name()?; - let columns = self.parse_columns()?; + let (columns, constraints) = self.parse_columns()?; self.expect_keyword("STORED")?; self.expect_keyword("AS")?; let file_format = self.parse_identifier()?.parse::()?; @@ -780,6 +780,7 @@ impl Parser { Ok(SQLStatement::SQLCreateTable { name: table_name, columns, + constraints, external: true, file_format: Some(file_format), location: Some(location), @@ -845,74 +846,78 @@ impl Parser { pub fn parse_create_table(&mut self) -> Result { let table_name = self.parse_object_name()?; // parse optional column list (schema) - let columns = self.parse_columns()?; + let (columns, constraints) = self.parse_columns()?; Ok(SQLStatement::SQLCreateTable { name: table_name, columns, + constraints, external: false, file_format: None, location: None, }) } - fn parse_columns(&mut self) -> Result, ParserError> { + fn parse_columns(&mut self) -> Result<(Vec, Vec), ParserError> { let mut columns = vec![]; + let mut constraints = vec![]; if !self.consume_token(&Token::LParen) { - return Ok(columns); + return Ok((columns, constraints)); } loop { - match self.next_token() { - Some(Token::SQLWord(column_name)) => { - let data_type = self.parse_data_type()?; - let is_primary = self.parse_keywords(vec!["PRIMARY", "KEY"]); - let is_unique = self.parse_keyword("UNIQUE"); - let default = if self.parse_keyword("DEFAULT") { - let expr = self.parse_default_expr(0)?; - Some(expr) - } else { - None - }; - let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { - false - } else { - let _ = self.parse_keyword("NULL"); - true - }; - debug!("default: {:?}", default); + if let Some(constraint) = self.parse_optional_table_constraint()? { + constraints.push(constraint); + } else if let Some(Token::SQLWord(column_name)) = self.peek_token() { + self.next_token(); + let data_type = self.parse_data_type()?; + let is_primary = self.parse_keywords(vec!["PRIMARY", "KEY"]); + let is_unique = self.parse_keyword("UNIQUE"); + let default = if self.parse_keyword("DEFAULT") { + let expr = self.parse_default_expr(0)?; + Some(expr) + } else { + None + }; + let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { + false + } else { + let _ = self.parse_keyword("NULL"); + true + }; + debug!("default: {:?}", default); - columns.push(SQLColumnDef { - name: column_name.as_sql_ident(), - data_type, - allow_null, - is_primary, - is_unique, - default, - }); - match self.next_token() { - Some(Token::Comma) => {} - Some(Token::RParen) => { - break; - } - other => { - return parser_err!(format!( - "Expected ',' or ')' after column definition but found {:?}", - other - )); - } - } + columns.push(SQLColumnDef { + name: column_name.as_sql_ident(), + data_type, + allow_null, + is_primary, + is_unique, + default, + }); + } else { + return self.expected("column name or constraint definition", self.peek_token()); + } + match self.next_token() { + Some(Token::Comma) => {} + Some(Token::RParen) => { + break; } - unexpected => { - return parser_err!(format!("Expected column name, got {:?}", unexpected)); + other => { + return parser_err!(format!( + "Expected ',' or ')' after column definition but found {:?}", + other + )); } } } - Ok(columns) + Ok((columns, constraints)) } - pub fn parse_table_constraint(&mut self) -> Result { + pub fn parse_optional_table_constraint( + &mut self, + ) -> Result, ParserError> { let name = if self.parse_keyword("CONSTRAINT") { Some(self.parse_identifier()?) } else { @@ -925,11 +930,11 @@ impl Parser { self.expect_keyword("KEY")?; } let columns = self.parse_parenthesized_column_list(Mandatory)?; - Ok(TableConstraint::Unique { + Ok(Some(TableConstraint::Unique { name, columns, is_primary, - }) + })) } Some(Token::SQLWord(ref k)) if k.keyword == "FOREIGN" => { self.expect_keyword("KEY")?; @@ -937,20 +942,29 @@ impl Parser { self.expect_keyword("REFERENCES")?; let foreign_table = self.parse_object_name()?; let referred_columns = self.parse_parenthesized_column_list(Mandatory)?; - Ok(TableConstraint::ForeignKey { + Ok(Some(TableConstraint::ForeignKey { name, columns, foreign_table, referred_columns, - }) + })) } Some(Token::SQLWord(ref k)) if k.keyword == "CHECK" => { self.expect_token(&Token::LParen)?; let expr = Box::new(self.parse_expr()?); self.expect_token(&Token::RParen)?; - Ok(TableConstraint::Check { name, expr }) + Ok(Some(TableConstraint::Check { name, expr })) + } + unexpected => { + if name.is_some() { + self.expected("PRIMARY, UNIQUE, FOREIGN, or CHECK", unexpected) + } else { + if unexpected.is_some() { + self.prev_token(); + } + Ok(None) + } } - _ => self.expected("PRIMARY, UNIQUE, or FOREIGN", self.peek_token()), } } @@ -959,7 +973,11 @@ impl Parser { let _ = self.parse_keyword("ONLY"); let table_name = self.parse_object_name()?; let operation = if self.parse_keyword("ADD") { - AlterOperation::AddConstraint(self.parse_table_constraint()?) + if let Some(constraint) = self.parse_optional_table_constraint()? { + AlterOperation::AddConstraint(constraint) + } else { + return self.expected("a constraint in ALTER TABLE .. ADD", self.peek_token()); + } } else { return self.expected("ADD after ALTER TABLE", self.peek_token()); }; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index d2d05aef..5a0c8754 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -660,12 +660,14 @@ fn parse_create_table() { SQLStatement::SQLCreateTable { name, columns, + constraints, external: false, file_format: None, location: None, } => { assert_eq!("uk_cities", name.to_string()); assert_eq!(3, columns.len()); + assert!(constraints.is_empty()); let c_name = &columns[0]; assert_eq!("name", c_name.name); @@ -705,12 +707,14 @@ fn parse_create_external_table() { SQLStatement::SQLCreateTable { name, columns, + constraints, external, file_format, location, } => { assert_eq!("uk_cities", name.to_string()); assert_eq!(3, columns.len()); + assert!(constraints.is_empty()); let c_name = &columns[0]; assert_eq!("name", c_name.name); @@ -761,9 +765,29 @@ fn parse_alter_table_constraints() { } _ => unreachable!(), } + verified_stmt(&format!("CREATE TABLE foo (id int, {})", constraint_text)); } } +#[test] +fn parse_bad_constraint() { + let res = parse_sql_statements("ALTER TABLE tab ADD"); + assert_eq!( + ParserError::ParserError( + "Expected a constraint in ALTER TABLE .. ADD, found: EOF".to_string() + ), + res.unwrap_err() + ); + + let res = parse_sql_statements("CREATE TABLE tab (foo int,"); + assert_eq!( + ParserError::ParserError( + "Expected column name or constraint definition, found: EOF".to_string() + ), + res.unwrap_err() + ); +} + #[test] fn parse_scalar_function_in_projection() { let sql = "SELECT sqrt(id) FROM foo"; diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 522bd74f..4c1fe3ad 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -23,12 +23,14 @@ fn parse_create_table_with_defaults() { SQLStatement::SQLCreateTable { name, columns, + constraints, external: false, file_format: None, location: None, } => { assert_eq!("public.customer", name.to_string()); assert_eq!(10, columns.len()); + assert!(constraints.is_empty()); let c_name = &columns[0]; assert_eq!("customer_id", c_name.name); @@ -69,11 +71,13 @@ fn parse_create_table_from_pg_dump() { SQLStatement::SQLCreateTable { name, columns, + constraints, external: false, file_format: None, location: None, } => { assert_eq!("public.customer", name.to_string()); + assert!(constraints.is_empty()); let c_customer_id = &columns[0]; assert_eq!("customer_id", c_customer_id.name); @@ -130,11 +134,13 @@ fn parse_create_table_with_inherit() { SQLStatement::SQLCreateTable { name, columns, + constraints, external: false, file_format: None, location: None, } => { assert_eq!("bazaar.settings", name.to_string()); + assert!(constraints.is_empty()); let c_name = &columns[0]; assert_eq!("settings_id", c_name.name);