mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 13:29:12 +00:00
Implement tail recursion for union refcounting procs
This commit is contained in:
parent
8ebdc8ea7f
commit
5560ecb63e
3 changed files with 276 additions and 7 deletions
|
@ -381,6 +381,48 @@ impl<'a> CodeGenHelp<'a> {
|
|||
Layout::RecursivePointer => Layout::Union(ctx.recursive_union.unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
fn union_tail_recursion_fields(
|
||||
&self,
|
||||
union: UnionLayout<'a>,
|
||||
) -> (bool, Vec<'a, Option<usize>>) {
|
||||
use UnionLayout::*;
|
||||
match union {
|
||||
NonRecursive(_) => return (false, bumpalo::vec![in self.arena]),
|
||||
|
||||
Recursive(tags) => self.union_tail_recursion_fields_help(tags),
|
||||
|
||||
NonNullableUnwrapped(field_layouts) => {
|
||||
self.union_tail_recursion_fields_help(&[field_layouts])
|
||||
}
|
||||
|
||||
NullableWrapped {
|
||||
other_tags: tags, ..
|
||||
} => self.union_tail_recursion_fields_help(tags),
|
||||
|
||||
NullableUnwrapped { other_fields, .. } => {
|
||||
self.union_tail_recursion_fields_help(&[other_fields])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn union_tail_recursion_fields_help(
|
||||
&self,
|
||||
tags: &[&'a [Layout<'a>]],
|
||||
) -> (bool, Vec<'a, Option<usize>>) {
|
||||
let mut can_use_tailrec = false;
|
||||
let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena);
|
||||
|
||||
for fields in tags.iter() {
|
||||
let found_index = fields
|
||||
.iter()
|
||||
.position(|f| matches!(f, Layout::RecursivePointer));
|
||||
tailrec_indices.push(found_index);
|
||||
can_use_tailrec |= found_index.is_some();
|
||||
}
|
||||
|
||||
(can_use_tailrec, tailrec_indices)
|
||||
}
|
||||
}
|
||||
|
||||
fn let_lowlevel<'a>(
|
||||
|
|
|
@ -779,12 +779,21 @@ fn refcount_union<'a>(
|
|||
NonRecursive(tags) => refcount_union_nonrec(root, ident_ids, ctx, union, tags, structure),
|
||||
|
||||
Recursive(tags) => {
|
||||
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
|
||||
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
|
||||
if is_tailrec && !ctx.op.is_decref() {
|
||||
refcount_union_tailrec(root, ident_ids, ctx, union, tags, None, tail_idx, structure)
|
||||
} else {
|
||||
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
|
||||
}
|
||||
}
|
||||
|
||||
NonNullableUnwrapped(field_layouts) => {
|
||||
// We don't do tail recursion on NonNullableUnwrapped.
|
||||
// Its RecursionPointer is always nested inside a List, Option, or other sub-layout, since
|
||||
// a direct RecursionPointer is only possible if there's at least one non-recursive variant.
|
||||
// This nesting makes it harder to do tail recursion, so we just don't.
|
||||
let tags = root.arena.alloc([field_layouts]);
|
||||
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
|
||||
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
|
||||
}
|
||||
|
||||
NullableWrapped {
|
||||
|
@ -792,7 +801,14 @@ fn refcount_union<'a>(
|
|||
nullable_id,
|
||||
} => {
|
||||
let null_id = Some(nullable_id);
|
||||
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
|
||||
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
|
||||
if is_tailrec && !ctx.op.is_decref() {
|
||||
refcount_union_tailrec(
|
||||
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
|
||||
)
|
||||
} else {
|
||||
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
|
||||
}
|
||||
}
|
||||
|
||||
NullableUnwrapped {
|
||||
|
@ -801,7 +817,14 @@ fn refcount_union<'a>(
|
|||
} => {
|
||||
let null_id = Some(nullable_id as TagIdIntType);
|
||||
let tags = root.arena.alloc([other_fields]);
|
||||
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
|
||||
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
|
||||
if is_tailrec && !ctx.op.is_decref() {
|
||||
refcount_union_tailrec(
|
||||
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
|
||||
)
|
||||
} else {
|
||||
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -922,7 +945,7 @@ fn refcount_union_contents<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn refcount_tag_union_rec<'a>(
|
||||
fn refcount_union_rec<'a>(
|
||||
root: &mut CodeGenHelp<'a>,
|
||||
ident_ids: &mut IdentIds,
|
||||
ctx: &mut Context<'a>,
|
||||
|
@ -997,6 +1020,210 @@ fn refcount_tag_union_rec<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
// Refcount a recursive union using tail-call elimination to limit stack growth
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn refcount_union_tailrec<'a>(
|
||||
root: &mut CodeGenHelp<'a>,
|
||||
ident_ids: &mut IdentIds,
|
||||
ctx: &mut Context<'a>,
|
||||
union_layout: UnionLayout<'a>,
|
||||
tag_layouts: &'a [&'a [Layout<'a>]],
|
||||
null_id: Option<TagIdIntType>,
|
||||
tailrec_indices: Vec<'a, Option<usize>>,
|
||||
initial_structure: Symbol,
|
||||
) -> Stmt<'a> {
|
||||
let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
|
||||
let current = root.create_symbol(ident_ids, "current");
|
||||
let next_ptr = root.create_symbol(ident_ids, "next_ptr");
|
||||
let layout = Layout::Union(union_layout);
|
||||
|
||||
let tag_id_layout = union_layout.tag_id_layout();
|
||||
|
||||
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
|
||||
let tag_id_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_id_sym,
|
||||
Expr::GetTagId {
|
||||
structure: current,
|
||||
union_layout,
|
||||
},
|
||||
tag_id_layout,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
// Do refcounting on the structure itself
|
||||
// In the control flow, this comes *after* refcounting the fields
|
||||
// It receives a `next` parameter to pass through to the outer joinpoint
|
||||
let rc_structure_stmt = {
|
||||
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
|
||||
let next_addr = root.create_symbol(ident_ids, "next_addr");
|
||||
|
||||
let exit_stmt = rc_return_stmt(root, ident_ids, ctx);
|
||||
let jump_to_loop = Stmt::Jump(tailrec_loop, root.arena.alloc([next_ptr]));
|
||||
|
||||
let loop_or_exit = Stmt::Switch {
|
||||
cond_symbol: next_addr,
|
||||
cond_layout: root.layout_isize,
|
||||
branches: root.arena.alloc([(0, BranchInfo::None, exit_stmt)]),
|
||||
default_branch: (BranchInfo::None, root.arena.alloc(jump_to_loop)),
|
||||
ret_layout: LAYOUT_UNIT,
|
||||
};
|
||||
let loop_or_exit_based_on_next_addr = {
|
||||
let_lowlevel(
|
||||
root.arena,
|
||||
root.layout_isize,
|
||||
next_addr,
|
||||
PtrCast,
|
||||
&[next_ptr],
|
||||
root.arena.alloc(loop_or_exit),
|
||||
)
|
||||
};
|
||||
|
||||
let alignment = layout.alignment_bytes(root.ptr_size);
|
||||
let modify_structure_stmt = modify_refcount(
|
||||
root,
|
||||
ident_ids,
|
||||
ctx,
|
||||
rc_ptr,
|
||||
alignment,
|
||||
root.arena.alloc(loop_or_exit_based_on_next_addr),
|
||||
);
|
||||
|
||||
rc_ptr_from_data_ptr(
|
||||
root,
|
||||
ident_ids,
|
||||
current,
|
||||
rc_ptr,
|
||||
union_layout.stores_tag_id_in_pointer(root.ptr_size),
|
||||
root.arena.alloc(modify_structure_stmt),
|
||||
)
|
||||
};
|
||||
|
||||
let rc_contents_then_structure = {
|
||||
let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union"));
|
||||
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena);
|
||||
|
||||
// If this is null, there is no refcount, no `next`, no fields. Just return.
|
||||
if let Some(id) = null_id {
|
||||
let ret = rc_return_stmt(root, ident_ids, ctx);
|
||||
tag_branches.push((id as u64, BranchInfo::None, ret));
|
||||
}
|
||||
|
||||
let mut tag_id: TagIdIntType = 0;
|
||||
for (field_layouts, opt_tailrec_index) in tag_layouts.iter().zip(tailrec_indices) {
|
||||
match null_id {
|
||||
Some(id) if id == tag_id => {
|
||||
tag_id += 1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// After refcounting the fields, jump to modify the union itself.
|
||||
// The loop param is a pointer to the next union. It gets passed through two jumps.
|
||||
let (non_tailrec_fields, jump_to_modify_union) =
|
||||
if let Some(tailrec_index) = opt_tailrec_index {
|
||||
let mut filtered = Vec::with_capacity_in(field_layouts.len() - 1, root.arena);
|
||||
let mut tail_stmt = None;
|
||||
for (i, field) in field_layouts.iter().enumerate() {
|
||||
if i != tailrec_index {
|
||||
filtered.push(*field);
|
||||
} else {
|
||||
let field_val =
|
||||
root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i));
|
||||
let field_val_expr = Expr::UnionAtIndex {
|
||||
union_layout,
|
||||
tag_id,
|
||||
index: i as u64,
|
||||
structure: current,
|
||||
};
|
||||
let jump_params = root.arena.alloc([field_val]);
|
||||
let jump = root.arena.alloc(Stmt::Jump(jp_modify_union, jump_params));
|
||||
tail_stmt = Some(Stmt::Let(field_val, field_val_expr, *field, jump));
|
||||
}
|
||||
}
|
||||
|
||||
(filtered.into_bump_slice(), tail_stmt.unwrap())
|
||||
} else {
|
||||
let zero = root.create_symbol(ident_ids, "zero");
|
||||
let zero_expr = Expr::Literal(Literal::Int(0));
|
||||
let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next);
|
||||
|
||||
let null = root.create_symbol(ident_ids, "null");
|
||||
let null_stmt =
|
||||
|next| let_lowlevel(root.arena, layout, null, PtrCast, &[zero], next);
|
||||
|
||||
let tail_stmt = zero_stmt(root.arena.alloc(
|
||||
//
|
||||
null_stmt(root.arena.alloc(
|
||||
//
|
||||
Stmt::Jump(jp_modify_union, root.arena.alloc([null])),
|
||||
)),
|
||||
));
|
||||
|
||||
(*field_layouts, tail_stmt)
|
||||
};
|
||||
|
||||
let fields_stmt = refcount_tag_fields(
|
||||
root,
|
||||
ident_ids,
|
||||
ctx,
|
||||
union_layout,
|
||||
non_tailrec_fields,
|
||||
current,
|
||||
tag_id,
|
||||
jump_to_modify_union,
|
||||
);
|
||||
|
||||
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
|
||||
|
||||
tag_id += 1;
|
||||
}
|
||||
|
||||
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
|
||||
|
||||
let tag_id_switch = Stmt::Switch {
|
||||
cond_symbol: tag_id_sym,
|
||||
cond_layout: tag_id_layout,
|
||||
branches: tag_branches.into_bump_slice(),
|
||||
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
|
||||
ret_layout: LAYOUT_UNIT,
|
||||
};
|
||||
|
||||
let jp_param = Param {
|
||||
symbol: next_ptr,
|
||||
borrow: true,
|
||||
layout,
|
||||
};
|
||||
|
||||
Stmt::Join {
|
||||
id: jp_modify_union,
|
||||
parameters: root.arena.alloc([jp_param]),
|
||||
body: root.arena.alloc(rc_structure_stmt),
|
||||
remainder: root.arena.alloc(tag_id_switch),
|
||||
}
|
||||
};
|
||||
|
||||
let loop_body = tag_id_stmt(root.arena.alloc(
|
||||
//
|
||||
rc_contents_then_structure,
|
||||
));
|
||||
|
||||
let loop_init = Stmt::Jump(tailrec_loop, root.arena.alloc([initial_structure]));
|
||||
let loop_param = Param {
|
||||
symbol: current,
|
||||
borrow: true,
|
||||
layout: Layout::Union(union_layout),
|
||||
};
|
||||
|
||||
Stmt::Join {
|
||||
id: tailrec_loop,
|
||||
parameters: root.arena.alloc([loop_param]),
|
||||
body: root.arena.alloc(loop_body),
|
||||
remainder: root.arena.alloc(loop_init),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn refcount_tag_fields<'a>(
|
||||
root: &mut CodeGenHelp<'a>,
|
||||
|
|
|
@ -265,7 +265,7 @@ fn union_recursive_dec() {
|
|||
#[test]
|
||||
#[cfg(any(feature = "gen-wasm"))]
|
||||
fn refcount_different_rosetrees_inc() {
|
||||
// Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)`
|
||||
// Requires two different Inc procedures for `List (Rose I64)` and `List (Rose Str)`
|
||||
// even though both appear in the mono Layout as `List(RecursivePointer)`
|
||||
assert_refcounts!(
|
||||
indoc!(
|
||||
|
@ -305,7 +305,7 @@ fn refcount_different_rosetrees_inc() {
|
|||
#[test]
|
||||
#[cfg(any(feature = "gen-wasm"))]
|
||||
fn refcount_different_rosetrees_dec() {
|
||||
// Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)`
|
||||
// Requires two different Dec procedures for `List (Rose I64)` and `List (Rose Str)`
|
||||
// even though both appear in the mono Layout as `List(RecursivePointer)`
|
||||
assert_refcounts!(
|
||||
indoc!(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue