Adding support for parsing CREATE TRIGGER and DROP TRIGGER statements (#1352)

Co-authored-by: hulk <hulk.website@gmail.com>
Co-authored-by: Ifeanyi Ubah <ify1992@yahoo.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
This commit is contained in:
Luca Cappelletti 2024-08-14 15:11:16 +02:00 committed by GitHub
parent f5b818e74b
commit b072ce2589
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1022 additions and 25 deletions

View file

@ -3623,7 +3623,7 @@ fn parse_drop_function() {
pg().verified_stmt(sql),
Statement::DropFunction {
if_exists: true,
func_desc: vec![DropFunctionDesc {
func_desc: vec![FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_func".to_string(),
quote_style: None
@ -3639,7 +3639,7 @@ fn parse_drop_function() {
pg().verified_stmt(sql),
Statement::DropFunction {
if_exists: true,
func_desc: vec![DropFunctionDesc {
func_desc: vec![FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_func".to_string(),
quote_style: None
@ -3664,7 +3664,7 @@ fn parse_drop_function() {
Statement::DropFunction {
if_exists: true,
func_desc: vec![
DropFunctionDesc {
FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_func1".to_string(),
quote_style: None
@ -3682,7 +3682,7 @@ fn parse_drop_function() {
}
]),
},
DropFunctionDesc {
FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_func2".to_string(),
quote_style: None
@ -3713,7 +3713,7 @@ fn parse_drop_procedure() {
pg().verified_stmt(sql),
Statement::DropProcedure {
if_exists: true,
proc_desc: vec![DropFunctionDesc {
proc_desc: vec![FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_proc".to_string(),
quote_style: None
@ -3729,7 +3729,7 @@ fn parse_drop_procedure() {
pg().verified_stmt(sql),
Statement::DropProcedure {
if_exists: true,
proc_desc: vec![DropFunctionDesc {
proc_desc: vec![FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_proc".to_string(),
quote_style: None
@ -3754,7 +3754,7 @@ fn parse_drop_procedure() {
Statement::DropProcedure {
if_exists: true,
proc_desc: vec![
DropFunctionDesc {
FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_proc1".to_string(),
quote_style: None
@ -3772,7 +3772,7 @@ fn parse_drop_procedure() {
}
]),
},
DropFunctionDesc {
FunctionDesc {
name: ObjectName(vec![Ident {
value: "test_proc2".to_string(),
quote_style: None
@ -4455,6 +4455,478 @@ fn test_escaped_string_literal() {
}
}
#[test]
fn parse_create_simple_before_insert_trigger() {
let sql = "CREATE TRIGGER check_insert BEFORE INSERT ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_insert";
let expected = Statement::CreateTrigger {
or_replace: false,
is_constraint: false,
name: ObjectName(vec![Ident::new("check_insert")]),
period: TriggerPeriod::Before,
events: vec![TriggerEvent::Insert],
table_name: ObjectName(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
condition: None,
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("check_account_insert")]),
args: None,
},
},
characteristics: None,
};
assert_eq!(pg().verified_stmt(sql), expected);
}
#[test]
fn parse_create_after_update_trigger_with_condition() {
let sql = "CREATE TRIGGER check_update AFTER UPDATE ON accounts FOR EACH ROW WHEN (NEW.balance > 10000) EXECUTE FUNCTION check_account_update";
let expected = Statement::CreateTrigger {
or_replace: false,
is_constraint: false,
name: ObjectName(vec![Ident::new("check_update")]),
period: TriggerPeriod::After,
events: vec![TriggerEvent::Update(vec![])],
table_name: ObjectName(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
condition: Some(Expr::Nested(Box::new(Expr::BinaryOp {
left: Box::new(Expr::CompoundIdentifier(vec![
Ident::new("NEW"),
Ident::new("balance"),
])),
op: BinaryOperator::Gt,
right: Box::new(Expr::Value(number("10000"))),
}))),
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("check_account_update")]),
args: None,
},
},
characteristics: None,
};
assert_eq!(pg().verified_stmt(sql), expected);
}
#[test]
fn parse_create_instead_of_delete_trigger() {
let sql = "CREATE TRIGGER check_delete INSTEAD OF DELETE ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_deletes";
let expected = Statement::CreateTrigger {
or_replace: false,
is_constraint: false,
name: ObjectName(vec![Ident::new("check_delete")]),
period: TriggerPeriod::InsteadOf,
events: vec![TriggerEvent::Delete],
table_name: ObjectName(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
condition: None,
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("check_account_deletes")]),
args: None,
},
},
characteristics: None,
};
assert_eq!(pg().verified_stmt(sql), expected);
}
#[test]
fn parse_create_trigger_with_multiple_events_and_deferrable() {
let sql = "CREATE CONSTRAINT TRIGGER check_multiple_events BEFORE INSERT OR UPDATE OR DELETE ON accounts DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION check_account_changes";
let expected = Statement::CreateTrigger {
or_replace: false,
is_constraint: true,
name: ObjectName(vec![Ident::new("check_multiple_events")]),
period: TriggerPeriod::Before,
events: vec![
TriggerEvent::Insert,
TriggerEvent::Update(vec![]),
TriggerEvent::Delete,
],
table_name: ObjectName(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
condition: None,
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("check_account_changes")]),
args: None,
},
},
characteristics: Some(ConstraintCharacteristics {
deferrable: Some(true),
initially: Some(DeferrableInitial::Deferred),
enforced: None,
}),
};
assert_eq!(pg().verified_stmt(sql), expected);
}
#[test]
fn parse_create_trigger_with_referencing() {
let sql = "CREATE TRIGGER check_referencing BEFORE INSERT ON accounts REFERENCING NEW TABLE AS new_accounts OLD TABLE AS old_accounts FOR EACH ROW EXECUTE FUNCTION check_account_referencing";
let expected = Statement::CreateTrigger {
or_replace: false,
is_constraint: false,
name: ObjectName(vec![Ident::new("check_referencing")]),
period: TriggerPeriod::Before,
events: vec![TriggerEvent::Insert],
table_name: ObjectName(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![
TriggerReferencing {
refer_type: TriggerReferencingType::NewTable,
is_as: true,
transition_relation_name: ObjectName(vec![Ident::new("new_accounts")]),
},
TriggerReferencing {
refer_type: TriggerReferencingType::OldTable,
is_as: true,
transition_relation_name: ObjectName(vec![Ident::new("old_accounts")]),
},
],
trigger_object: TriggerObject::Row,
include_each: true,
condition: None,
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("check_account_referencing")]),
args: None,
},
},
characteristics: None,
};
assert_eq!(pg().verified_stmt(sql), expected);
}
#[test]
/// While in the parse_create_trigger test we test the full syntax of the CREATE TRIGGER statement,
/// here we test the invalid cases of the CREATE TRIGGER statement which should cause an appropriate
/// error to be returned.
fn parse_create_trigger_invalid_cases() {
// Test invalid cases for the CREATE TRIGGER statement
let invalid_cases = vec![
(
"CREATE TRIGGER check_update BEFORE UPDATE ON accounts FUNCTION check_account_update",
"Expected: FOR, found: FUNCTION"
),
(
"CREATE TRIGGER check_update TOMORROW UPDATE ON accounts EXECUTE FUNCTION check_account_update",
"Expected: one of BEFORE or AFTER or INSTEAD, found: TOMORROW"
),
(
"CREATE TRIGGER check_update BEFORE SAVE ON accounts EXECUTE FUNCTION check_account_update",
"Expected: one of INSERT or UPDATE or DELETE or TRUNCATE, found: SAVE"
)
];
for (sql, expected_error) in invalid_cases {
let res = pg().parse_sql_statements(sql);
assert_eq!(
format!("sql parser error: {expected_error}"),
res.unwrap_err().to_string()
);
}
}
#[test]
fn parse_drop_trigger() {
for if_exists in [true, false] {
for option in [
None,
Some(ReferentialAction::Cascade),
Some(ReferentialAction::Restrict),
] {
let sql = &format!(
"DROP TRIGGER{} check_update ON table_name{}",
if if_exists { " IF EXISTS" } else { "" },
option
.map(|o| format!(" {}", o))
.unwrap_or_else(|| "".to_string())
);
assert_eq!(
pg().verified_stmt(sql),
Statement::DropTrigger {
if_exists,
trigger_name: ObjectName(vec![Ident::new("check_update")]),
table_name: ObjectName(vec![Ident::new("table_name")]),
option
}
);
}
}
}
#[test]
fn parse_drop_trigger_invalid_cases() {
// Test invalid cases for the DROP TRIGGER statement
let invalid_cases = vec![
(
"DROP TRIGGER check_update ON table_name CASCADE RESTRICT",
"Expected: end of statement, found: RESTRICT",
),
(
"DROP TRIGGER check_update ON table_name CASCADE CASCADE",
"Expected: end of statement, found: CASCADE",
),
(
"DROP TRIGGER check_update ON table_name CASCADE CASCADE CASCADE",
"Expected: end of statement, found: CASCADE",
),
];
for (sql, expected_error) in invalid_cases {
let res = pg().parse_sql_statements(sql);
assert_eq!(
format!("sql parser error: {expected_error}"),
res.unwrap_err().to_string()
);
}
}
#[test]
fn parse_trigger_related_functions() {
// First we define all parts of the trigger definition,
// including the table creation, the function creation, the trigger creation and the trigger drop.
// The following example is taken from the PostgreSQL documentation <https://www.postgresql.org/docs/current/plpgsql-trigger.html>
let sql_table_creation = r#"
CREATE TABLE emp (
empname text,
salary integer,
last_date timestamp,
last_user text
);
"#;
let sql_create_function = r#"
CREATE FUNCTION emp_stamp() RETURNS trigger AS $emp_stamp$
BEGIN
-- Check that empname and salary are given
IF NEW.empname IS NULL THEN
RAISE EXCEPTION 'empname cannot be null';
END IF;
IF NEW.salary IS NULL THEN
RAISE EXCEPTION '% cannot have null salary', NEW.empname;
END IF;
-- Who works for us when they must pay for it?
IF NEW.salary < 0 THEN
RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;
END IF;
-- Remember who changed the payroll when
NEW.last_date := current_timestamp;
NEW.last_user := current_user;
RETURN NEW;
END;
$emp_stamp$ LANGUAGE plpgsql;
"#;
let sql_create_trigger = r#"
CREATE TRIGGER emp_stamp BEFORE INSERT OR UPDATE ON emp
FOR EACH ROW EXECUTE FUNCTION emp_stamp();
"#;
let sql_drop_trigger = r#"
DROP TRIGGER emp_stamp ON emp;
"#;
// Now we parse the statements and check if they are parsed correctly.
let mut statements = pg()
.parse_sql_statements(&format!(
"{}{}{}{}",
sql_table_creation, sql_create_function, sql_create_trigger, sql_drop_trigger
))
.unwrap();
assert_eq!(statements.len(), 4);
let drop_trigger = statements.pop().unwrap();
let create_trigger = statements.pop().unwrap();
let create_function = statements.pop().unwrap();
let create_table = statements.pop().unwrap();
// Check the first statement
let create_table = match create_table {
Statement::CreateTable(create_table) => create_table,
_ => panic!("Expected CreateTable statement"),
};
assert_eq!(
create_table,
CreateTable {
or_replace: false,
temporary: false,
external: false,
global: None,
if_not_exists: false,
transient: false,
volatile: false,
name: ObjectName(vec![Ident::new("emp")]),
columns: vec![
ColumnDef {
name: "empname".into(),
data_type: DataType::Text,
collation: None,
options: vec![],
},
ColumnDef {
name: "salary".into(),
data_type: DataType::Integer(None),
collation: None,
options: vec![],
},
ColumnDef {
name: "last_date".into(),
data_type: DataType::Timestamp(None, TimezoneInfo::None),
collation: None,
options: vec![],
},
ColumnDef {
name: "last_user".into(),
data_type: DataType::Text,
collation: None,
options: vec![],
},
],
constraints: vec![],
hive_distribution: HiveDistributionStyle::NONE,
hive_formats: Some(HiveFormat {
row_format: None,
serde_properties: None,
storage: None,
location: None
}),
table_properties: vec![],
with_options: vec![],
file_format: None,
location: None,
query: None,
without_rowid: false,
like: None,
clone: None,
engine: None,
comment: None,
auto_increment_offset: None,
default_charset: None,
collation: None,
on_commit: None,
on_cluster: None,
primary_key: None,
order_by: None,
partition_by: None,
cluster_by: None,
options: None,
strict: false,
copy_grants: false,
enable_schema_evolution: None,
change_tracking: None,
data_retention_time_in_days: None,
max_data_extension_time_in_days: None,
default_ddl_collation: None,
with_aggregation_policy: None,
with_row_access_policy: None,
with_tags: None,
}
);
// Check the second statement
assert_eq!(
create_function,
Statement::CreateFunction {
or_replace: false,
temporary: false,
if_not_exists: false,
name: ObjectName(vec![Ident::new("emp_stamp")]),
args: None,
return_type: Some(DataType::Trigger),
function_body: Some(
CreateFunctionBody::AsBeforeOptions(
Expr::Value(
Value::DollarQuotedString(
DollarQuotedString {
value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n \n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n \n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(),
tag: Some(
"emp_stamp".to_owned(),
),
},
),
),
),
),
behavior: None,
called_on_null: None,
parallel: None,
using: None,
language: Some(Ident::new("plpgsql")),
determinism_specifier: None,
options: None,
remote_connection: None
}
);
// Check the third statement
assert_eq!(
create_trigger,
Statement::CreateTrigger {
or_replace: false,
is_constraint: false,
name: ObjectName(vec![Ident::new("emp_stamp")]),
period: TriggerPeriod::Before,
events: vec![TriggerEvent::Insert, TriggerEvent::Update(vec![])],
table_name: ObjectName(vec![Ident::new("emp")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
condition: None,
exec_body: TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
func_desc: FunctionDesc {
name: ObjectName(vec![Ident::new("emp_stamp")]),
args: None,
}
},
characteristics: None
}
);
// Check the fourth statement
assert_eq!(
drop_trigger,
Statement::DropTrigger {
if_exists: false,
trigger_name: ObjectName(vec![Ident::new("emp_stamp")]),
table_name: ObjectName(vec![Ident::new("emp")]),
option: None
}
);
}
#[test]
fn test_unicode_string_literal() {
let pairs = [