Mask out union tag ID from pointer when calculating refcount address

This commit is contained in:
Brian Carroll 2022-01-03 15:08:29 +00:00
parent 94dea1df9f
commit 5e642c880c
3 changed files with 108 additions and 30 deletions

View file

@ -189,6 +189,7 @@ pub fn rc_ptr_from_data_ptr<'a>(
ident_ids: &mut IdentIds,
structure: Symbol,
rc_ptr_sym: Symbol,
mask_lower_bits: bool,
following: &'a Stmt<'a>,
) -> Stmt<'a> {
// Typecast the structure pointer to an integer
@ -203,6 +204,21 @@ pub fn rc_ptr_from_data_ptr<'a>(
});
let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, root.layout_isize, next);
// Mask for lower bits (for tag union id)
let mask_sym = root.create_symbol(ident_ids, "mask");
let mask_expr = Expr::Literal(Literal::Int(-(root.ptr_size as i128)));
let mask_stmt = |next| Stmt::Let(mask_sym, mask_expr, root.layout_isize, next);
let masked_sym = root.create_symbol(ident_ids, "masked");
let and_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::And,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([addr_sym, mask_sym]),
});
let and_stmt = |next| Stmt::Let(masked_sym, and_expr, root.layout_isize, next);
// Pointer size constant
let ptr_size_sym = root.create_symbol(ident_ids, "ptr_size");
let ptr_size_expr = Expr::Literal(Literal::Int(root.ptr_size as i128));
@ -210,32 +226,60 @@ pub fn rc_ptr_from_data_ptr<'a>(
// Refcount address
let rc_addr_sym = root.create_symbol(ident_ids, "rc_addr");
let rc_addr_expr = Expr::Call(Call {
let sub_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::NumSub,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([addr_sym, ptr_size_sym]),
arguments: root.arena.alloc([
if mask_lower_bits {
masked_sym
} else {
addr_sym
},
ptr_size_sym,
]),
});
let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, root.layout_isize, next);
let sub_stmt = |next| Stmt::Let(rc_addr_sym, sub_expr, root.layout_isize, next);
// Typecast the refcount address from integer to pointer
let rc_ptr_expr = Expr::Call(Call {
let cast_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([rc_addr_sym]),
});
let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next);
let cast_stmt = |next| Stmt::Let(rc_ptr_sym, cast_expr, LAYOUT_PTR, next);
if mask_lower_bits {
addr_stmt(root.arena.alloc(
//
mask_stmt(root.arena.alloc(
//
and_stmt(root.arena.alloc(
//
ptr_size_stmt(root.arena.alloc(
//
sub_stmt(root.arena.alloc(
//
cast_stmt(root.arena.alloc(
//
following,
)),
)),
)),
)),
)),
))
} else {
addr_stmt(root.arena.alloc(
//
ptr_size_stmt(root.arena.alloc(
//
rc_addr_stmt(root.arena.alloc(
sub_stmt(root.arena.alloc(
//
rc_ptr_stmt(root.arena.alloc(
cast_stmt(root.arena.alloc(
//
following,
)),
@ -243,6 +287,7 @@ pub fn rc_ptr_from_data_ptr<'a>(
)),
))
}
}
fn modify_refcount<'a>(
root: &CodeGenHelp<'a>,
@ -356,6 +401,7 @@ fn refcount_str<'a>(
ident_ids,
elements,
rc_ptr,
false,
root.arena.alloc(
//
mod_rc_stmt,
@ -467,7 +513,14 @@ fn refcount_list<'a>(
let modify_list_and_elems = elements_stmt(arena.alloc(
//
rc_ptr_from_data_ptr(root, ident_ids, elements, rc_ptr, arena.alloc(modify_list)),
rc_ptr_from_data_ptr(
root,
ident_ids,
elements,
rc_ptr,
false,
arena.alloc(modify_list),
),
));
//
@ -753,15 +806,6 @@ fn refcount_tag_union_help<'a>(
structure: Symbol,
) -> Stmt<'a> {
let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_));
/* TODO: tail recursion
let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
let structure = if is_non_recursive {
initial_structure
} else {
// current value in the tail-recursive loop
root.create_symbol(ident_ids, "structure")
};
*/
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
@ -836,6 +880,7 @@ fn refcount_tag_union_help<'a>(
ident_ids,
structure,
rc_ptr,
union_layout.stores_tag_id_in_pointer(root.ptr_size),
root.arena.alloc(modify_structure_stmt),
);

View file

@ -206,21 +206,54 @@ fn union_recursive_inc() {
r#"
Expr : [ Sym Str, Add Expr Expr ]
s = Str.concat "heap_allocated_" "symbol_name"
s = Str.concat "heap_allocated" "_symbol_name"
x : Expr
x = Sym s
e : Expr
e = Add (Sym s) (Sym s)
e = Add x x
[e, e]
"#
),
RocStr,
// test_wrapper receives a List, doesn't matter kind of elements it points to
RocList<usize>,
&[
4, // s
1, // Sym
1, // Sym
2, // Add
4, // sym
2, // e
1 // list
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_recursive_dec() {
assert_refcounts!(
indoc!(
r#"
Expr : [ Sym Str, Add Expr Expr ]
s = Str.concat "heap_allocated" "_symbol_name"
x : Expr
x = Sym s
e : Expr
e = Add x x
when e is
Add y _ -> y
Sym _ -> e
"#
),
&RocStr,
&[
1, // s
1, // sym
0 // e
]
);
}