diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 9ce0a8d4..83d15aa8 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -471,6 +471,7 @@ impl ToString for SQLStatement { } => { let mut s = format!("UPDATE {}", table_name.to_string()); if !assignments.is_empty() { + s += " SET "; s += &comma_separated_string(assignments); } if let Some(selection) = selection { @@ -560,13 +561,13 @@ impl ToString for SQLObjectName { /// SQL assignment `foo = expr` as used in SQLUpdate #[derive(Debug, Clone, PartialEq)] pub struct SQLAssignment { - id: SQLIdent, - value: ASTNode, + pub id: SQLIdent, + pub value: ASTNode, } impl ToString for SQLAssignment { fn to_string(&self) -> String { - format!("SET {} = {}", self.id, self.value.to_string()) + format!("{} = {}", self.id, self.value.to_string()) } } diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 85121eef..5937db7e 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -117,6 +117,7 @@ impl Parser { "DROP" => Ok(self.parse_drop()?), "DELETE" => Ok(self.parse_delete()?), "INSERT" => Ok(self.parse_insert()?), + "UPDATE" => Ok(self.parse_update()?), "ALTER" => Ok(self.parse_alter()?), "COPY" => Ok(self.parse_copy()?), _ => parser_err!(format!( @@ -1606,6 +1607,31 @@ impl Parser { }) } + pub fn parse_update(&mut self) -> Result { + let table_name = self.parse_object_name()?; + self.expect_keyword("SET")?; + let mut assignments = vec![]; + loop { + let id = self.parse_identifier()?; + self.expect_token(&Token::Eq)?; + let value = self.parse_expr()?; + assignments.push(SQLAssignment { id, value }); + if !self.consume_token(&Token::Comma) { + break; + } + } + let selection = if self.parse_keyword("WHERE") { + Some(self.parse_expr()?) + } else { + None + }; + Ok(SQLStatement::SQLUpdate { + table_name, + assignments, + selection, + }) + } + /// Parse a comma-delimited list of SQL expressions pub fn parse_expr_list(&mut self) -> Result, ParserError> { let mut expr_list: Vec = vec![]; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ab5ce4d2..0a911685 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -64,6 +64,56 @@ fn parse_insert_invalid() { ); } +#[test] +fn parse_update() { + let sql = "UPDATE t SET a = 1, b = 2, c = 3 WHERE d"; + match verified_stmt(sql) { + SQLStatement::SQLUpdate { + table_name, + assignments, + selection, + .. + } => { + assert_eq!(table_name.to_string(), "t".to_string()); + assert_eq!( + assignments, + vec![ + SQLAssignment { + id: "a".into(), + value: ASTNode::SQLValue(Value::Long(1)), + }, + SQLAssignment { + id: "b".into(), + value: ASTNode::SQLValue(Value::Long(2)), + }, + SQLAssignment { + id: "c".into(), + value: ASTNode::SQLValue(Value::Long(3)), + }, + ] + ); + assert_eq!(selection.unwrap(), ASTNode::SQLIdentifier("d".into())); + } + _ => unreachable!(), + } + + verified_stmt("UPDATE t SET a = 1, a = 2, a = 3"); + + let sql = "UPDATE t WHERE 1"; + let res = parse_sql_statements(sql); + assert_eq!( + ParserError::ParserError("Expected SET, found: WHERE".to_string()), + res.unwrap_err() + ); + + let sql = "UPDATE t SET a = 1 extrabadstuff"; + let res = parse_sql_statements(sql); + assert_eq!( + ParserError::ParserError("Expected end of statement, found: extrabadstuff".to_string()), + res.unwrap_err() + ); +} + #[test] fn parse_invalid_table_name() { let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);