diff --git a/compiler/mono/src/code_gen_help/refcount.rs b/compiler/mono/src/code_gen_help/refcount.rs index 755648c2ce..338c8ad445 100644 --- a/compiler/mono/src/code_gen_help/refcount.rs +++ b/compiler/mono/src/code_gen_help/refcount.rs @@ -867,6 +867,7 @@ fn refcount_union_contents<'a>( tag_id_layout: Layout<'a>, modify_union_stmt: Stmt<'a>, ) -> Stmt<'a> { + let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union")); let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena); if let Some(id) = null_id { @@ -883,6 +884,10 @@ fn refcount_union_contents<'a>( _ => {} } + // After refcounting the fields, jump to modify the union itself + // (Order is important, to avoid use-after-free for Dec) + let following = Stmt::Jump(jp_modify_union, &[]); + let fields_stmt = refcount_tag_fields( root, ident_ids, @@ -891,7 +896,7 @@ fn refcount_union_contents<'a>( field_layouts, structure, tag_id, - modify_union_stmt.clone(), // TODO: Use a jump, this is a bit bloated + following, ); tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt)); @@ -901,12 +906,19 @@ fn refcount_union_contents<'a>( let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2; - Stmt::Switch { + let tag_id_switch = Stmt::Switch { cond_symbol: tag_id_sym, cond_layout: tag_id_layout, branches: tag_branches.into_bump_slice(), default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)), ret_layout: LAYOUT_UNIT, + }; + + Stmt::Join { + id: jp_modify_union, + parameters: &[], + body: root.arena.alloc(modify_union_stmt), + remainder: root.arena.alloc(tag_id_switch), } }