Implement ON CONFLICT and RETURNING (#666)

* Implement RETURNING on INSERT/UPDATE/DELETE

* Implement INSERT ... ON CONFLICT

* Fix tests

* cargo fmt

* tests: on conflict and returning

Co-authored-by: gamife <gamife9886@gmail.com>
This commit is contained in:
main() 2022-11-11 22:15:31 +01:00 committed by GitHub
parent ae1c69034e
commit 814367a6ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 250 additions and 8 deletions

View file

@ -1049,6 +1049,8 @@ pub enum Statement {
/// whether the insert has the table keyword (Hive) /// whether the insert has the table keyword (Hive)
table: bool, table: bool,
on: Option<OnInsert>, on: Option<OnInsert>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
}, },
// TODO: Support ROW FORMAT // TODO: Support ROW FORMAT
Directory { Directory {
@ -1089,6 +1091,8 @@ pub enum Statement {
from: Option<TableWithJoins>, from: Option<TableWithJoins>,
/// WHERE /// WHERE
selection: Option<Expr>, selection: Option<Expr>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
}, },
/// DELETE /// DELETE
Delete { Delete {
@ -1098,6 +1102,8 @@ pub enum Statement {
using: Option<TableFactor>, using: Option<TableFactor>,
/// WHERE /// WHERE
selection: Option<Expr>, selection: Option<Expr>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
}, },
/// CREATE VIEW /// CREATE VIEW
CreateView { CreateView {
@ -1679,6 +1685,7 @@ impl fmt::Display for Statement {
source, source,
table, table,
on, on,
returning,
} => { } => {
if let Some(action) = or { if let Some(action) = or {
write!(f, "INSERT OR {} INTO {} ", action, table_name)?; write!(f, "INSERT OR {} INTO {} ", action, table_name)?;
@ -1706,10 +1713,14 @@ impl fmt::Display for Statement {
write!(f, "{}", source)?; write!(f, "{}", source)?;
if let Some(on) = on { if let Some(on) = on {
write!(f, "{}", on) write!(f, "{}", on)?;
} else {
Ok(())
} }
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}
Ok(())
} }
Statement::Copy { Statement::Copy {
@ -1753,6 +1764,7 @@ impl fmt::Display for Statement {
assignments, assignments,
from, from,
selection, selection,
returning,
} => { } => {
write!(f, "UPDATE {}", table)?; write!(f, "UPDATE {}", table)?;
if !assignments.is_empty() { if !assignments.is_empty() {
@ -1764,12 +1776,16 @@ impl fmt::Display for Statement {
if let Some(selection) = selection { if let Some(selection) = selection {
write!(f, " WHERE {}", selection)?; write!(f, " WHERE {}", selection)?;
} }
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}
Ok(()) Ok(())
} }
Statement::Delete { Statement::Delete {
table_name, table_name,
using, using,
selection, selection,
returning,
} => { } => {
write!(f, "DELETE FROM {}", table_name)?; write!(f, "DELETE FROM {}", table_name)?;
if let Some(using) = using { if let Some(using) = using {
@ -1778,6 +1794,9 @@ impl fmt::Display for Statement {
if let Some(selection) = selection { if let Some(selection) = selection {
write!(f, " WHERE {}", selection)?; write!(f, " WHERE {}", selection)?;
} }
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}
Ok(()) Ok(())
} }
Statement::Close { cursor } => { Statement::Close { cursor } => {
@ -2610,6 +2629,21 @@ pub enum MinMaxValue {
pub enum OnInsert { pub enum OnInsert {
/// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead) /// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead)
DuplicateKeyUpdate(Vec<Assignment>), DuplicateKeyUpdate(Vec<Assignment>),
/// ON CONFLICT is a PostgreSQL and Sqlite extension
OnConflict(OnConflict),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OnConflict {
pub conflict_target: Vec<Ident>,
pub action: OnConflictAction,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OnConflictAction {
DoNothing,
DoUpdate(Vec<Assignment>),
} }
impl fmt::Display for OnInsert { impl fmt::Display for OnInsert {
@ -2620,6 +2654,24 @@ impl fmt::Display for OnInsert {
" ON DUPLICATE KEY UPDATE {}", " ON DUPLICATE KEY UPDATE {}",
display_comma_separated(expr) display_comma_separated(expr)
), ),
Self::OnConflict(o) => write!(f, " {o}"),
}
}
}
impl fmt::Display for OnConflict {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, " ON CONFLICT")?;
if !self.conflict_target.is_empty() {
write!(f, "({})", display_comma_separated(&self.conflict_target))?;
}
write!(f, " {}", self.action)
}
}
impl fmt::Display for OnConflictAction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::DoNothing => write!(f, "DO NOTHING"),
Self::DoUpdate(a) => write!(f, "DO UPDATE SET {}", display_comma_separated(a)),
} }
} }
} }

View file

@ -144,6 +144,7 @@ define_keywords!(
COMMITTED, COMMITTED,
COMPUTE, COMPUTE,
CONDITION, CONDITION,
CONFLICT,
CONNECT, CONNECT,
CONNECTION, CONNECTION,
CONSTRAINT, CONSTRAINT,
@ -200,6 +201,7 @@ define_keywords!(
DISCONNECT, DISCONNECT,
DISTINCT, DISTINCT,
DISTRIBUTE, DISTRIBUTE,
DO,
DOUBLE, DOUBLE,
DOW, DOW,
DOY, DOY,
@ -370,6 +372,7 @@ define_keywords!(
NOSCAN, NOSCAN,
NOSUPERUSER, NOSUPERUSER,
NOT, NOT,
NOTHING,
NTH_VALUE, NTH_VALUE,
NTILE, NTILE,
NULL, NULL,
@ -464,6 +467,7 @@ define_keywords!(
RESTRICT, RESTRICT,
RESULT, RESULT,
RETURN, RETURN,
RETURNING,
RETURNS, RETURNS,
REVOKE, REVOKE,
RIGHT, RIGHT,

View file

@ -4070,10 +4070,17 @@ impl<'a> Parser<'a> {
None None
}; };
let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else {
None
};
Ok(Statement::Delete { Ok(Statement::Delete {
table_name, table_name,
using, using,
selection, selection,
returning,
}) })
} }
@ -5191,12 +5198,38 @@ impl<'a> Parser<'a> {
let source = Box::new(self.parse_query()?); let source = Box::new(self.parse_query()?);
let on = if self.parse_keyword(Keyword::ON) { let on = if self.parse_keyword(Keyword::ON) {
self.expect_keyword(Keyword::DUPLICATE)?; if self.parse_keyword(Keyword::CONFLICT) {
self.expect_keyword(Keyword::KEY)?; let conflict_target =
self.expect_keyword(Keyword::UPDATE)?; self.parse_parenthesized_column_list(IsOptional::Optional)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;
Some(OnInsert::DuplicateKeyUpdate(l)) self.expect_keyword(Keyword::DO)?;
let action = if self.parse_keyword(Keyword::NOTHING) {
OnConflictAction::DoNothing
} else {
self.expect_keyword(Keyword::UPDATE)?;
self.expect_keyword(Keyword::SET)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;
OnConflictAction::DoUpdate(l)
};
Some(OnInsert::OnConflict(OnConflict {
conflict_target,
action,
}))
} else {
self.expect_keyword(Keyword::DUPLICATE)?;
self.expect_keyword(Keyword::KEY)?;
self.expect_keyword(Keyword::UPDATE)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;
Some(OnInsert::DuplicateKeyUpdate(l))
}
} else {
None
};
let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else { } else {
None None
}; };
@ -5212,6 +5245,7 @@ impl<'a> Parser<'a> {
source, source,
table, table,
on, on,
returning,
}) })
} }
} }
@ -5230,11 +5264,17 @@ impl<'a> Parser<'a> {
} else { } else {
None None
}; };
let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else {
None
};
Ok(Statement::Update { Ok(Statement::Update {
table, table,
assignments, assignments,
from, from,
selection, selection,
returning,
}) })
} }

View file

@ -195,6 +195,7 @@ fn parse_update_with_table_alias() {
assignments, assignments,
from: _from, from: _from,
selection, selection,
returning,
} => { } => {
assert_eq!( assert_eq!(
TableWithJoins { TableWithJoins {
@ -231,6 +232,7 @@ fn parse_update_with_table_alias() {
}), }),
selection selection
); );
assert_eq!(None, returning);
} }
_ => unreachable!(), _ => unreachable!(),
} }
@ -278,6 +280,7 @@ fn parse_where_delete_statement() {
table_name, table_name,
using, using,
selection, selection,
returning,
} => { } => {
assert_eq!( assert_eq!(
TableFactor::Table { TableFactor::Table {
@ -298,6 +301,7 @@ fn parse_where_delete_statement() {
}, },
selection.unwrap(), selection.unwrap(),
); );
assert_eq!(None, returning);
} }
_ => unreachable!(), _ => unreachable!(),
} }
@ -313,6 +317,7 @@ fn parse_where_delete_with_alias_statement() {
table_name, table_name,
using, using,
selection, selection,
returning,
} => { } => {
assert_eq!( assert_eq!(
TableFactor::Table { TableFactor::Table {
@ -353,6 +358,7 @@ fn parse_where_delete_with_alias_statement() {
}, },
selection.unwrap(), selection.unwrap(),
); );
assert_eq!(None, returning);
} }
_ => unreachable!(), _ => unreachable!(),
} }

View file

@ -814,6 +814,7 @@ fn parse_update_with_joins() {
assignments, assignments,
from: _from, from: _from,
selection, selection,
returning,
} => { } => {
assert_eq!( assert_eq!(
TableWithJoins { TableWithJoins {
@ -869,6 +870,7 @@ fn parse_update_with_joins() {
}), }),
selection selection
); );
assert_eq!(None, returning);
} }
_ => unreachable!(), _ => unreachable!(),
} }

View file

@ -564,6 +564,7 @@ fn parse_update_set_from() {
Ident::new("id") Ident::new("id")
])), ])),
}), }),
returning: None,
} }
); );
} }
@ -1177,6 +1178,143 @@ fn parse_prepare() {
); );
} }
#[test]
fn parse_pg_on_conflict() {
let stmt = pg_and_generic().verified_stmt(
"INSERT INTO distributors (did, dname) \
VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \
ON CONFLICT(did) \
DO UPDATE SET dname = EXCLUDED.dname",
);
match stmt {
Statement::Insert {
on:
Some(OnInsert::OnConflict(OnConflict {
conflict_target,
action,
})),
..
} => {
assert_eq!(vec![Ident::from("did")], conflict_target);
assert_eq!(
OnConflictAction::DoUpdate(vec![Assignment {
id: vec!["dname".into()],
value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()])
},]),
action
);
}
_ => unreachable!(),
};
let stmt = pg_and_generic().verified_stmt(
"INSERT INTO distributors (did, dname, area) \
VALUES (5, 'Gizmo Transglobal', 'Mars'), (6, 'Associated Computing, Inc', 'Venus') \
ON CONFLICT(did, area) \
DO UPDATE SET dname = EXCLUDED.dname, area = EXCLUDED.area",
);
match stmt {
Statement::Insert {
on:
Some(OnInsert::OnConflict(OnConflict {
conflict_target,
action,
})),
..
} => {
assert_eq!(
vec![Ident::from("did"), Ident::from("area"),],
conflict_target
);
assert_eq!(
OnConflictAction::DoUpdate(vec![
Assignment {
id: vec!["dname".into()],
value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()])
},
Assignment {
id: vec!["area".into()],
value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()])
},
]),
action
);
}
_ => unreachable!(),
};
let stmt = pg_and_generic().verified_stmt(
"INSERT INTO distributors (did, dname) \
VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \
ON CONFLICT DO NOTHING",
);
match stmt {
Statement::Insert {
on:
Some(OnInsert::OnConflict(OnConflict {
conflict_target,
action,
})),
..
} => {
assert_eq!(Vec::<Ident>::new(), conflict_target);
assert_eq!(OnConflictAction::DoNothing, action);
}
_ => unreachable!(),
};
}
#[test]
fn parse_pg_returning() {
let stmt = pg_and_generic().verified_stmt(
"INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') RETURNING did",
);
match stmt {
Statement::Insert { returning, .. } => {
assert_eq!(
Some(vec![SelectItem::UnnamedExpr(Expr::Identifier(
"did".into()
)),]),
returning
);
}
_ => unreachable!(),
};
let stmt = pg_and_generic().verified_stmt(
"UPDATE weather SET temp_lo = temp_lo + 1, temp_hi = temp_lo + 15, prcp = DEFAULT \
WHERE city = 'San Francisco' AND date = '2003-07-03' \
RETURNING temp_lo AS lo, temp_hi AS hi, prcp",
);
match stmt {
Statement::Update { returning, .. } => {
assert_eq!(
Some(vec![
SelectItem::ExprWithAlias {
expr: Expr::Identifier("temp_lo".into()),
alias: "lo".into()
},
SelectItem::ExprWithAlias {
expr: Expr::Identifier("temp_hi".into()),
alias: "hi".into()
},
SelectItem::UnnamedExpr(Expr::Identifier("prcp".into())),
]),
returning
);
}
_ => unreachable!(),
};
let stmt =
pg_and_generic().verified_stmt("DELETE FROM tasks WHERE status = 'DONE' RETURNING *");
match stmt {
Statement::Delete { returning, .. } => {
assert_eq!(Some(vec![SelectItem::Wildcard,]), returning);
}
_ => unreachable!(),
};
}
#[test] #[test]
fn parse_pg_bitwise_binary_ops() { fn parse_pg_bitwise_binary_ops() {
let bitwise_ops = &[ let bitwise_ops = &[