Merge pull request #5312 from roc-lang/dev-refcount-seamless-slice

Dev refcount seamless slice
This commit is contained in:
Brian Carroll 2023-05-14 18:37:23 +01:00 committed by GitHub
commit cfcd2a5289
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 410 additions and 119 deletions

View file

@ -719,11 +719,17 @@ fn rc_ptr_from_data_ptr_help<'a>(
}
}
/// Tags the symbol handed to `modify_refcount` so it can choose which
/// low-level refcount operation to emit for that symbol.
enum Pointer {
/// `modify_refcount` emits `LowLevel::RefCountIncDataPtr` (for `HelperOp::Inc`) or
/// `LowLevel::RefCountDecDataPtr` (for `HelperOp::Dec` / `HelperOp::DecRef`) on the symbol.
ToData(Symbol),
/// `modify_refcount` emits `LowLevel::RefCountIncRcPtr` (for `HelperOp::Inc`) or
/// `LowLevel::RefCountDecRcPtr` (for `HelperOp::Dec` / `HelperOp::DecRef`) on the symbol.
// NOTE(review): no call site visible here constructs this variant, which is presumably
// why the unused lint is silenced — confirm against the rest of the file.
#[allow(unused)]
ToRefcount(Symbol),
}
fn modify_refcount<'a>(
root: &CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
data_ptr: Symbol,
ptr: Pointer,
alignment: u32,
following: &'a Stmt<'a>,
) -> Stmt<'a> {
@ -731,28 +737,39 @@ fn modify_refcount<'a>(
let zig_call_result = root.create_symbol(ident_ids, "zig_call_result");
match ctx.op {
HelperOp::Inc => {
let (op, ptr) = match ptr {
Pointer::ToData(s) => (LowLevel::RefCountIncDataPtr, s),
Pointer::ToRefcount(s) => (LowLevel::RefCountIncRcPtr, s),
};
let zig_call_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::RefCountIncDataPtr,
op,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([data_ptr, Symbol::ARG_2]),
arguments: root.arena.alloc([ptr, Symbol::ARG_2]),
});
Stmt::Let(zig_call_result, zig_call_expr, LAYOUT_UNIT, following)
}
HelperOp::Dec | HelperOp::DecRef(_) => {
debug_assert!(alignment >= root.target_info.ptr_width() as u32);
let (op, ptr) = match ptr {
Pointer::ToData(s) => (LowLevel::RefCountDecDataPtr, s),
Pointer::ToRefcount(s) => (LowLevel::RefCountDecRcPtr, s),
};
let alignment_sym = root.create_symbol(ident_ids, "alignment");
let alignment_expr = Expr::Literal(Literal::Int((alignment as i128).to_ne_bytes()));
let alignment_stmt = |next| Stmt::Let(alignment_sym, alignment_expr, LAYOUT_U32, next);
let zig_call_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::RefCountDecDataPtr,
op,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([data_ptr, alignment_sym]),
arguments: root.arena.alloc([ptr, alignment_sym]),
});
let zig_call_stmt = Stmt::Let(zig_call_result, zig_call_expr, LAYOUT_UNIT, following);
@ -772,11 +789,10 @@ fn refcount_str<'a>(
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
) -> Stmt<'a> {
let arena = root.arena;
let string = Symbol::ARG_1;
let layout_isize = root.layout_isize;
let field_layouts = root
.arena
.alloc([Layout::OPAQUE_PTR, layout_isize, layout_isize]);
let field_layouts = arena.alloc([Layout::OPAQUE_PTR, layout_isize, layout_isize]);
// Get the last word as a signed int
let last_word = root.create_symbol(ident_ids, "last_word");
@ -794,57 +810,154 @@ fn refcount_str<'a>(
// is_big_str = (last_word >= 0);
// Treat last word as isize so that the small string flag is the same as the sign bit
// (assuming a little-endian target, where the sign bit is in the last byte of the word)
let is_big_str = root.create_symbol(ident_ids, "is_big_str");
let is_big_str_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::NumGte,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([last_word, zero]),
});
let is_big_str_stmt = |next| Stmt::Let(is_big_str, is_big_str_expr, LAYOUT_BOOL, next);
let is_big_str_stmt = |next| {
let_lowlevel(
arena,
LAYOUT_BOOL,
is_big_str,
NumGte,
&[last_word, zero],
next,
)
};
// Get the pointer to the string elements
let elements = root.create_symbol(ident_ids, "characters");
let elements_expr = Expr::StructAtIndex {
//
// Check for seamless slice
//
// Get the length field as a signed int
let length = root.create_symbol(ident_ids, "length");
let length_expr = Expr::StructAtIndex {
index: 1,
field_layouts,
structure: string,
};
let length_stmt = |next| Stmt::Let(length, length_expr, layout_isize, next);
let alignment = root.target_info.ptr_width() as u32;
// let is_slice = lowlevel NumLt length zero
let is_slice = root.create_symbol(ident_ids, "is_slice");
let is_slice_stmt =
|next| let_lowlevel(arena, LAYOUT_BOOL, is_slice, NumLt, &[length, zero], next);
//
// Branch on seamless slice vs "real" string
//
let return_unit = arena.alloc(rc_return_stmt(root, ident_ids, ctx));
let one = root.create_symbol(ident_ids, "one");
let one_expr = Expr::Literal(Literal::Int(1i128.to_ne_bytes()));
let one_stmt = |next| Stmt::Let(one, one_expr, layout_isize, next);
let data_ptr_int = root.create_symbol(ident_ids, "data_ptr_int");
let data_ptr_int_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
data_ptr_int,
PtrCast,
&[last_word],
next,
)
};
let data_ptr = root.create_symbol(ident_ids, "data_ptr");
let data_ptr_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
data_ptr,
NumShiftLeftBy,
&[data_ptr_int, one],
next,
)
};
// when the string is a slice, the capacity field is a pointer to the refcount
let slice_branch = one_stmt(arena.alloc(
//
data_ptr_int_stmt(arena.alloc(
//
data_ptr_stmt(arena.alloc(
//
modify_refcount(
root,
ident_ids,
ctx,
Pointer::ToData(data_ptr),
alignment,
return_unit,
),
)),
)),
));
// Characters pointer for a real string
let string_chars = root.create_symbol(ident_ids, "string_chars");
let string_chars_expr = Expr::StructAtIndex {
index: 0,
field_layouts,
structure: string,
};
let elements_stmt = |next| Stmt::Let(elements, elements_expr, layout_isize, next);
let string_chars_stmt = |next| Stmt::Let(string_chars, string_chars_expr, layout_isize, next);
// A pointer to the refcount value itself
let alignment = root.target_info.ptr_width() as u32;
let ret_unit_stmt = rc_return_stmt(root, ident_ids, ctx);
let mod_rc_stmt = modify_refcount(
let modify_refcount_stmt = modify_refcount(
root,
ident_ids,
ctx,
elements,
Pointer::ToData(string_chars),
alignment,
root.arena.alloc(ret_unit_stmt),
return_unit,
);
// Generate an `if` to skip small strings but modify big strings
let then_branch = elements_stmt(root.arena.alloc(mod_rc_stmt));
let string_branch = arena.alloc(
//
string_chars_stmt(arena.alloc(
//
modify_refcount_stmt,
)),
);
let if_stmt = Stmt::if_then_else(
let if_slice = Stmt::if_then_else(
root.arena,
is_slice,
Layout::UNIT,
slice_branch,
string_branch,
);
//
// Read the length field, then branch on slice vs string
//
let modify_stmt = length_stmt(arena.alloc(
//
is_slice_stmt(arena.alloc(
//
if_slice,
)),
));
let if_big_stmt = Stmt::if_then_else(
root.arena,
is_big_str,
Layout::UNIT,
then_branch,
modify_stmt,
root.arena.alloc(rc_return_stmt(root, ident_ids, ctx)),
);
// Combine the statements in sequence
last_word_stmt(root.arena.alloc(
last_word_stmt(arena.alloc(
//
zero_stmt(root.arena.alloc(
zero_stmt(arena.alloc(
//
is_big_str_stmt(root.arena.alloc(
is_big_str_stmt(arena.alloc(
//
if_stmt,
if_big_stmt,
)),
)),
))
@ -871,34 +984,109 @@ fn refcount_list<'a>(
let len = root.create_symbol(ident_ids, "len");
let len_stmt = |next| let_lowlevel(arena, layout_isize, len, ListLen, &[structure], next);
// Zero
// let zero = 0
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0i128.to_ne_bytes()));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, layout_isize, next);
// let is_empty = lowlevel Eq len zero
let is_empty = root.create_symbol(ident_ids, "is_empty");
let is_empty_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::Eq,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([len, zero]),
});
let is_empty_stmt = |next| Stmt::Let(is_empty, is_empty_expr, LAYOUT_BOOL, next);
let is_empty_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, is_empty, Eq, &[len, zero], next);
// get elements pointer
let elements = root.create_symbol(ident_ids, "elements");
let elements_expr = Expr::StructAtIndex {
index: 0,
field_layouts: arena.alloc([box_layout, layout_isize, layout_isize]),
//
// Check for seamless slice
//
// let capacity = StructAtIndex 2 structure
let capacity = root.create_symbol(ident_ids, "capacity");
let list_field_layouts = arena.alloc([box_layout, layout_isize, layout_isize]);
let capacity_expr = Expr::StructAtIndex {
index: 2,
field_layouts: list_field_layouts,
structure,
};
let elements_stmt = |next| Stmt::Let(elements, elements_expr, box_layout, next);
let capacity_stmt = |next| Stmt::Let(capacity, capacity_expr, layout_isize, next);
// let is_slice = lowlevel NumLt capacity zero
let is_slice = root.create_symbol(ident_ids, "is_slice");
let is_slice_stmt =
|next| let_lowlevel(arena, LAYOUT_BOOL, is_slice, NumLt, &[capacity, zero], next);
//
// Branch on slice vs list
//
// let first_element = StructAtIndex 0 structure
let first_element = root.create_symbol(ident_ids, "first_element");
let first_element_expr = Expr::StructAtIndex {
index: 0,
field_layouts: list_field_layouts,
structure,
};
let first_element_stmt = |next| Stmt::Let(first_element, first_element_expr, box_layout, next);
let jp_elements = JoinPointId(root.create_symbol(ident_ids, "jp_elements"));
let data_pointer = root.create_symbol(ident_ids, "data_pointer");
let param_data_pointer = Param {
symbol: data_pointer,
ownership: Ownership::Owned,
layout: Layout::OPAQUE_PTR,
};
let first_element_pointer = root.create_symbol(ident_ids, "first_element_pointer");
let param_first_element_pointer = Param {
symbol: first_element_pointer,
ownership: Ownership::Owned,
layout: Layout::OPAQUE_PTR,
};
// one = 1
let one = root.create_symbol(ident_ids, "one");
let one_expr = Expr::Literal(Literal::Int(1i128.to_ne_bytes()));
let one_stmt = |next| Stmt::Let(one, one_expr, layout_isize, next);
let slice_data_pointer = root.create_symbol(ident_ids, "slice_data_pointer");
let slice_data_pointer_stmt = move |next| {
one_stmt(arena.alloc(
//
let_lowlevel(
arena,
layout_isize,
slice_data_pointer,
LowLevel::NumShiftLeftBy,
&[capacity, one],
arena.alloc(next),
),
))
};
let slice_branch = slice_data_pointer_stmt(
//
Stmt::Jump(
jp_elements,
arena.alloc([slice_data_pointer, first_element]),
),
);
let list_branch = arena.alloc(
//
Stmt::Jump(jp_elements, arena.alloc([first_element, first_element])),
);
let switch_slice_list = arena.alloc(first_element_stmt(arena.alloc(
//
Stmt::if_then_else(
root.arena,
is_slice,
Layout::UNIT,
slice_branch,
arena.alloc(list_branch),
),
)));
//
// modify refcount of the list and its elements
// (elements first, to avoid use-after-free for Dec)
// (elements first, to avoid use-after-free when decrementing)
//
let alignment = Ord::max(
@ -906,47 +1094,58 @@ fn refcount_list<'a>(
layout_interner.alignment_bytes(elem_layout),
);
let ret_stmt = rc_return_stmt(root, ident_ids, ctx);
let modify_list = modify_refcount(
root,
ident_ids,
ctx,
elements,
alignment,
arena.alloc(ret_stmt),
);
let ret_stmt = arena.alloc(rc_return_stmt(root, ident_ids, ctx));
let mut modify_refcount_stmt =
|ptr| modify_refcount(root, ident_ids, ctx, ptr, alignment, ret_stmt);
let relevant_op = ctx.op.is_dec() || ctx.op.is_inc();
let modify_elems_and_list = if relevant_op && layout_interner.get(elem_layout).is_refcounted() {
refcount_list_elems(
root,
ident_ids,
ctx,
layout_interner,
elem_layout,
LAYOUT_UNIT,
box_layout,
len,
elements,
modify_list,
)
} else {
modify_list
let modify_list = modify_refcount_stmt(Pointer::ToData(data_pointer));
let is_relevant_op = ctx.op.is_dec() || ctx.op.is_inc();
let modify_elems_and_list =
if is_relevant_op && layout_interner.get(elem_layout).is_refcounted() {
refcount_list_elems(
root,
ident_ids,
ctx,
layout_interner,
elem_layout,
LAYOUT_UNIT,
box_layout,
len,
first_element_pointer,
modify_list,
)
} else {
modify_list
};
//
// JoinPoint for slice vs list
//
let joinpoint_elems = Stmt::Join {
id: jp_elements,
parameters: arena.alloc([param_data_pointer, param_first_element_pointer]),
body: arena.alloc(modify_elems_and_list),
remainder: arena.alloc(switch_slice_list),
};
//
// Do nothing if the list is empty
//
let non_empty_branch = root.arena.alloc(
let non_empty_branch = arena.alloc(
//
elements_stmt(root.arena.alloc(
capacity_stmt(arena.alloc(
//
modify_elems_and_list,
is_slice_stmt(arena.alloc(
//
joinpoint_elems,
)),
)),
);
let if_stmt = Stmt::if_then_else(
let if_empty_stmt = Stmt::if_then_else(
root.arena,
is_empty,
Layout::UNIT,
@ -960,7 +1159,7 @@ fn refcount_list<'a>(
//
is_empty_stmt(arena.alloc(
//
if_stmt,
if_empty_stmt,
)),
)),
))
@ -1019,8 +1218,8 @@ fn refcount_list_elems<'a>(
//
let elems_loop = JoinPointId(root.create_symbol(ident_ids, "elems_loop"));
let addr = root.create_symbol(ident_ids, "addr");
let addr = root.create_symbol(ident_ids, "addr");
let param_addr = Param {
symbol: addr,
ownership: Ownership::Owned,
@ -1054,13 +1253,15 @@ fn refcount_list_elems<'a>(
//
// Next loop iteration
//
let next_addr = root.create_symbol(ident_ids, "next_addr");
let next_addr_stmt = |next| {
//
let_lowlevel(
arena,
layout_isize,
next_addr,
NumAdd,
NumAddSaturated,
&[addr, elem_size],
next,
)
@ -1073,28 +1274,25 @@ fn refcount_list_elems<'a>(
let is_end = root.create_symbol(ident_ids, "is_end");
let is_end_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, is_end, NumGte, &[addr, end], next);
let if_end_of_list = Stmt::Switch {
cond_symbol: is_end,
cond_layout: LAYOUT_BOOL,
let if_end_of_list = Stmt::if_then_else(
arena,
is_end,
ret_layout,
branches: root.arena.alloc([(1, BranchInfo::None, following)]),
default_branch: (
BranchInfo::None,
arena.alloc(box_stmt(arena.alloc(
following,
arena.alloc(box_stmt(arena.alloc(
//
elem_stmt(arena.alloc(
//
elem_stmt(arena.alloc(
mod_elem_stmt(arena.alloc(
//
mod_elem_stmt(arena.alloc(
next_addr_stmt(arena.alloc(
//
next_addr_stmt(arena.alloc(
//
Stmt::Jump(elems_loop, arena.alloc([next_addr])),
)),
Stmt::Jump(elems_loop, arena.alloc([next_addr])),
)),
)),
))),
),
};
)),
))),
);
let joinpoint_loop = Stmt::Join {
id: elems_loop,
@ -1108,7 +1306,7 @@ fn refcount_list_elems<'a>(
),
remainder: root
.arena
.alloc(Stmt::Jump(elems_loop, arena.alloc([start]))),
.alloc(Stmt::Jump(elems_loop, arena.alloc([start, end]))),
};
start_stmt(arena.alloc(
@ -1491,7 +1689,7 @@ fn refcount_union_rec<'a>(
root,
ident_ids,
ctx,
structure,
Pointer::ToData(structure),
alignment,
root.arena.alloc(ret_stmt),
)
@ -1589,7 +1787,7 @@ fn refcount_union_tailrec<'a>(
root,
ident_ids,
ctx,
current,
Pointer::ToData(current),
alignment,
root.arena.alloc(loop_or_exit_based_on_next_addr),
)
@ -1811,7 +2009,7 @@ fn refcount_boxed<'a>(
root,
ident_ids,
ctx,
outer,
Pointer::ToData(outer),
alignment,
arena.alloc(ret_stmt),
);