update list generated refcounting functions for dev and wasm

This commit is contained in:
Brendan Hansknecht 2024-04-11 21:33:27 -07:00
parent 3238ee7d0d
commit 48eb9c31a9
No known key found for this signature in database
GPG key ID: 0EA784685083E75B
9 changed files with 163 additions and 354 deletions

View file

@ -196,14 +196,7 @@ pub fn refcount_generic<'a>(
rc_return_stmt(root, ident_ids, ctx)
}
LayoutRepr::Builtin(Builtin::Str) => refcount_str(root, ident_ids, ctx),
LayoutRepr::Builtin(Builtin::List(elem_layout)) => refcount_list(
root,
ident_ids,
ctx,
layout_interner,
elem_layout,
structure,
),
LayoutRepr::Builtin(Builtin::List(_)) => refcount_list(root, ident_ids, ctx, structure),
LayoutRepr::Struct(field_layouts) => refcount_struct(
root,
ident_ids,
@ -954,357 +947,29 @@ fn refcount_list<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout_interner: &mut STLayoutInterner<'a>,
elem_layout: InLayout<'a>,
structure: Symbol,
_structure: Symbol,
) -> Stmt<'a> {
let layout_isize = root.layout_isize;
let arena = root.arena;
// A "Ptr" layout (heap pointer to a single list element)
let ptr_layout = layout_interner.insert_direct_no_semantic(LayoutRepr::Ptr(elem_layout));
//
// Check if the list is empty
//
let len = root.create_symbol(ident_ids, "len");
let len_stmt = |next| let_lowlevel(arena, layout_isize, len, ListLenUsize, &[structure], next);
// let zero = 0
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0i128.to_ne_bytes()));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, layout_isize, next);
// let is_empty = lowlevel Eq len zero
let is_empty = root.create_symbol(ident_ids, "is_empty");
let is_empty_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, is_empty, Eq, &[len, zero], next);
//
// Check for seamless slice
//
// let capacity = StructAtIndex 2 structure
let capacity = root.create_symbol(ident_ids, "capacity");
let list_field_layouts = arena.alloc([ptr_layout, layout_isize, layout_isize]);
let capacity_expr = Expr::StructAtIndex {
index: 2,
field_layouts: list_field_layouts,
structure,
};
let capacity_stmt = |next| Stmt::Let(capacity, capacity_expr, layout_isize, next);
// let is_slice = lowlevel NumLt capacity zero
let is_slice = root.create_symbol(ident_ids, "is_slice");
let is_slice_stmt =
|next| let_lowlevel(arena, LAYOUT_BOOL, is_slice, NumLt, &[capacity, zero], next);
//
// Branch on slice vs list
//
// let first_element = StructAtIndex 0 structure
let first_element = root.create_symbol(ident_ids, "first_element");
let first_element_expr = Expr::StructAtIndex {
index: 0,
field_layouts: list_field_layouts,
structure,
};
let first_element_stmt = |next| Stmt::Let(first_element, first_element_expr, ptr_layout, next);
let jp_elements = JoinPointId(root.create_symbol(ident_ids, "jp_elements"));
let data_pointer = root.create_symbol(ident_ids, "data_pointer");
let param_data_pointer = Param {
symbol: data_pointer,
layout: Layout::OPAQUE_PTR,
};
let first_element_pointer = root.create_symbol(ident_ids, "first_element_pointer");
let param_first_element_pointer = Param {
symbol: first_element_pointer,
layout: Layout::OPAQUE_PTR,
};
// one = 1
let one = root.create_symbol(ident_ids, "one");
let one_expr = Expr::Literal(Literal::Int(1i128.to_ne_bytes()));
let one_stmt = |next| Stmt::Let(one, one_expr, layout_isize, next);
let slice_data_pointer = root.create_symbol(ident_ids, "slice_data_pointer");
let slice_data_pointer_stmt = move |next| {
one_stmt(arena.alloc(
//
let_lowlevel(
arena,
layout_isize,
slice_data_pointer,
LowLevel::NumShiftLeftBy,
&[capacity, one],
arena.alloc(next),
),
))
};
let slice_branch = slice_data_pointer_stmt(
//
Stmt::Jump(
jp_elements,
arena.alloc([slice_data_pointer, first_element]),
),
);
let list_branch = arena.alloc(
//
Stmt::Jump(jp_elements, arena.alloc([first_element, first_element])),
);
let switch_slice_list = arena.alloc(first_element_stmt(arena.alloc(
//
Stmt::if_then_else(
root.arena,
is_slice,
Layout::UNIT,
slice_branch,
arena.alloc(list_branch),
),
)));
//
// modify refcount of the list and its elements
// (elements first, to avoid use-after-free for when decrementing)
//
let alignment = Ord::max(
root.target.ptr_width() as u32,
layout_interner.alignment_bytes(elem_layout),
);
let ret_stmt = arena.alloc(rc_return_stmt(root, ident_ids, ctx));
let mut modify_refcount_stmt =
|ptr| modify_refcount(root, ident_ids, ctx, ptr, alignment, ret_stmt);
let modify_list = modify_refcount_stmt(Pointer::ToData(data_pointer));
let is_relevant_op = ctx.op.is_dec() || ctx.op.is_inc();
let modify_elems_and_list = if is_relevant_op && layout_interner.is_refcounted(elem_layout) {
refcount_list_elems(
root,
ident_ids,
ctx,
layout_interner,
elem_layout,
LAYOUT_UNIT,
ptr_layout,
len,
first_element_pointer,
modify_list,
)
} else {
modify_list
let lowlevel = match ctx.op {
HelperOp::IncN | HelperOp::Inc => LowLevel::ListIncref,
HelperOp::DecRef(_) | HelperOp::Dec => LowLevel::ListDecref,
_ => unreachable!(),
};
//
// JoinPoint for slice vs list
//
let joinpoint_elems = Stmt::Join {
id: jp_elements,
parameters: arena.alloc([param_data_pointer, param_first_element_pointer]),
body: arena.alloc(modify_elems_and_list),
remainder: arena.alloc(switch_slice_list),
};
//
// Do nothing if the list is empty
//
let non_empty_branch = arena.alloc(
//
capacity_stmt(arena.alloc(
//
is_slice_stmt(arena.alloc(
//
joinpoint_elems,
)),
)),
);
let if_empty_stmt = Stmt::if_then_else(
root.arena,
is_empty,
Layout::UNIT,
rc_return_stmt(root, ident_ids, ctx),
non_empty_branch,
);
len_stmt(arena.alloc(
//
zero_stmt(arena.alloc(
//
is_empty_stmt(arena.alloc(
//
if_empty_stmt,
)),
)),
))
}
fn refcount_list_elems<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout_interner: &mut STLayoutInterner<'a>,
elem_layout: InLayout<'a>,
ret_layout: InLayout<'a>,
ptr_layout: InLayout<'a>,
length: Symbol,
elements: Symbol,
following: Stmt<'a>,
) -> Stmt<'a> {
use LowLevel::*;
let layout_isize = root.layout_isize;
let arena = root.arena;
// Cast to integer
let start = root.create_symbol(ident_ids, "start");
let start_stmt = |next| let_lowlevel(arena, layout_isize, start, PtrCast, &[elements], next);
//
// Loop initialisation
//
// let size = literal int
let elem_size = root.create_symbol(ident_ids, "elem_size");
let elem_size_expr = Expr::Literal(Literal::Int(
(layout_interner.stack_size(elem_layout) as i128).to_ne_bytes(),
));
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
// let list_size = len * size
let list_size = root.create_symbol(ident_ids, "list_size");
let list_size_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
list_size,
NumMul,
&[length, elem_size],
next,
)
};
// let end = start + list_size
let end = root.create_symbol(ident_ids, "end");
let end_stmt = |next| let_lowlevel(arena, layout_isize, end, NumAdd, &[start, list_size], next);
//
// Loop name & parameter
//
let elems_loop = JoinPointId(root.create_symbol(ident_ids, "elems_loop"));
let addr = root.create_symbol(ident_ids, "addr");
let param_addr = Param {
symbol: addr,
layout: layout_isize,
};
//
// if we haven't reached the end yet...
//
// Cast integer to pointer
let ptr_symbol = root.create_symbol(ident_ids, "ptr");
let ptr_stmt = |next| let_lowlevel(arena, ptr_layout, ptr_symbol, PtrCast, &[addr], next);
// Dereference the pointer to get the current element
let elem = root.create_symbol(ident_ids, "elem");
let elem_expr = Expr::ptr_load(arena.alloc(ptr_symbol));
let elem_stmt = |next| Stmt::Let(elem, elem_expr, elem_layout, next);
//
// Modify element refcount
//
let mod_elem_unit = root.create_symbol(ident_ids, "mod_elem_unit");
let mod_elem_args = refcount_args(root, ctx, elem);
let mod_elem_expr = root
.call_specialized_op(ident_ids, ctx, layout_interner, elem_layout, mod_elem_args)
.unwrap();
let mod_elem_stmt = |next| Stmt::Let(mod_elem_unit, mod_elem_expr, LAYOUT_UNIT, next);
//
// Next loop iteration
//
let next_addr = root.create_symbol(ident_ids, "next_addr");
let next_addr_stmt = |next| {
//
let_lowlevel(
arena,
layout_isize,
next_addr,
NumAddSaturated,
&[addr, elem_size],
next,
)
};
//
// Control flow
//
let is_end = root.create_symbol(ident_ids, "is_end");
let is_end_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, is_end, NumGte, &[addr, end], next);
let if_end_of_list = Stmt::if_then_else(
arena,
is_end,
ret_layout,
following,
arena.alloc(ptr_stmt(arena.alloc(
//
elem_stmt(arena.alloc(
//
mod_elem_stmt(arena.alloc(
//
next_addr_stmt(arena.alloc(
//
Stmt::Jump(elems_loop, arena.alloc([next_addr])),
)),
)),
)),
))),
);
let joinpoint_loop = Stmt::Join {
id: elems_loop,
parameters: arena.alloc([param_addr]),
body: arena.alloc(
//
is_end_stmt(
//
arena.alloc(if_end_of_list),
),
),
remainder: root
.arena
.alloc(Stmt::Jump(elems_loop, arena.alloc([start, end]))),
};
start_stmt(arena.alloc(
//
elem_size_stmt(arena.alloc(
//
list_size_stmt(arena.alloc(
//
end_stmt(arena.alloc(
//
joinpoint_loop,
)),
)),
)),
))
let list = Symbol::ARG_1;
let rc_list_args = refcount_args(root, ctx, list);
let rc_list_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: lowlevel,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: rc_list_args,
});
let rc_list_unit = root.create_symbol(ident_ids, "rc_list");
Stmt::Let(rc_list_unit, rc_list_expr, LAYOUT_UNIT, ret_stmt)
}
fn refcount_struct<'a>(