Implement tail recursion for union refcounting procs

This commit is contained in:
Brian Carroll 2022-01-06 22:42:39 +00:00
parent 8ebdc8ea7f
commit 5560ecb63e
3 changed files with 276 additions and 7 deletions

View file

@ -381,6 +381,48 @@ impl<'a> CodeGenHelp<'a> {
Layout::RecursivePointer => Layout::Union(ctx.recursive_union.unwrap()),
}
}
fn union_tail_recursion_fields(
&self,
union: UnionLayout<'a>,
) -> (bool, Vec<'a, Option<usize>>) {
use UnionLayout::*;
match union {
NonRecursive(_) => return (false, bumpalo::vec![in self.arena]),
Recursive(tags) => self.union_tail_recursion_fields_help(tags),
NonNullableUnwrapped(field_layouts) => {
self.union_tail_recursion_fields_help(&[field_layouts])
}
NullableWrapped {
other_tags: tags, ..
} => self.union_tail_recursion_fields_help(tags),
NullableUnwrapped { other_fields, .. } => {
self.union_tail_recursion_fields_help(&[other_fields])
}
}
}
fn union_tail_recursion_fields_help(
&self,
tags: &[&'a [Layout<'a>]],
) -> (bool, Vec<'a, Option<usize>>) {
let mut can_use_tailrec = false;
let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena);
for fields in tags.iter() {
let found_index = fields
.iter()
.position(|f| matches!(f, Layout::RecursivePointer));
tailrec_indices.push(found_index);
can_use_tailrec |= found_index.is_some();
}
(can_use_tailrec, tailrec_indices)
}
}
fn let_lowlevel<'a>(

View file

@ -779,12 +779,21 @@ fn refcount_union<'a>(
NonRecursive(tags) => refcount_union_nonrec(root, ident_ids, ctx, union, tags, structure),
Recursive(tags) => {
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(root, ident_ids, ctx, union, tags, None, tail_idx, structure)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
}
NonNullableUnwrapped(field_layouts) => {
// We don't do tail recursion on NonNullableUnwrapped.
// Its RecursionPointer is always nested inside a List, Option, or other sub-layout, since
// a direct RecursionPointer is only possible if there's at least one non-recursive variant.
// This nesting makes it harder to do tail recursion, so we just don't.
let tags = root.arena.alloc([field_layouts]);
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
NullableWrapped {
@ -792,7 +801,14 @@ fn refcount_union<'a>(
nullable_id,
} => {
let null_id = Some(nullable_id);
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
}
NullableUnwrapped {
@ -801,7 +817,14 @@ fn refcount_union<'a>(
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
}
};
@ -922,7 +945,7 @@ fn refcount_union_contents<'a>(
}
}
fn refcount_tag_union_rec<'a>(
fn refcount_union_rec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
@ -997,6 +1020,210 @@ fn refcount_tag_union_rec<'a>(
}
}
// Refcount a recursive union using tail-call elimination to limit stack growth
#[allow(clippy::too_many_arguments)]
fn refcount_union_tailrec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
tailrec_indices: Vec<'a, Option<usize>>,
initial_structure: Symbol,
) -> Stmt<'a> {
let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
let current = root.create_symbol(ident_ids, "current");
let next_ptr = root.create_symbol(ident_ids, "next_ptr");
let layout = Layout::Union(union_layout);
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure: current,
union_layout,
},
tag_id_layout,
next,
)
};
// Do refcounting on the structure itself
// In the control flow, this comes *after* refcounting the fields
// It receives a `next` parameter to pass through to the outer joinpoint
let rc_structure_stmt = {
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let next_addr = root.create_symbol(ident_ids, "next_addr");
let exit_stmt = rc_return_stmt(root, ident_ids, ctx);
let jump_to_loop = Stmt::Jump(tailrec_loop, root.arena.alloc([next_ptr]));
let loop_or_exit = Stmt::Switch {
cond_symbol: next_addr,
cond_layout: root.layout_isize,
branches: root.arena.alloc([(0, BranchInfo::None, exit_stmt)]),
default_branch: (BranchInfo::None, root.arena.alloc(jump_to_loop)),
ret_layout: LAYOUT_UNIT,
};
let loop_or_exit_based_on_next_addr = {
let_lowlevel(
root.arena,
root.layout_isize,
next_addr,
PtrCast,
&[next_ptr],
root.arena.alloc(loop_or_exit),
)
};
let alignment = layout.alignment_bytes(root.ptr_size);
let modify_structure_stmt = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
root.arena.alloc(loop_or_exit_based_on_next_addr),
);
rc_ptr_from_data_ptr(
root,
ident_ids,
current,
rc_ptr,
union_layout.stores_tag_id_in_pointer(root.ptr_size),
root.arena.alloc(modify_structure_stmt),
)
};
let rc_contents_then_structure = {
let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union"));
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena);
// If this is null, there is no refcount, no `next`, no fields. Just return.
if let Some(id) = null_id {
let ret = rc_return_stmt(root, ident_ids, ctx);
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for (field_layouts, opt_tailrec_index) in tag_layouts.iter().zip(tailrec_indices) {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
// After refcounting the fields, jump to modify the union itself.
// The loop param is a pointer to the next union. It gets passed through two jumps.
let (non_tailrec_fields, jump_to_modify_union) =
if let Some(tailrec_index) = opt_tailrec_index {
let mut filtered = Vec::with_capacity_in(field_layouts.len() - 1, root.arena);
let mut tail_stmt = None;
for (i, field) in field_layouts.iter().enumerate() {
if i != tailrec_index {
filtered.push(*field);
} else {
let field_val =
root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i));
let field_val_expr = Expr::UnionAtIndex {
union_layout,
tag_id,
index: i as u64,
structure: current,
};
let jump_params = root.arena.alloc([field_val]);
let jump = root.arena.alloc(Stmt::Jump(jp_modify_union, jump_params));
tail_stmt = Some(Stmt::Let(field_val, field_val_expr, *field, jump));
}
}
(filtered.into_bump_slice(), tail_stmt.unwrap())
} else {
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next);
let null = root.create_symbol(ident_ids, "null");
let null_stmt =
|next| let_lowlevel(root.arena, layout, null, PtrCast, &[zero], next);
let tail_stmt = zero_stmt(root.arena.alloc(
//
null_stmt(root.arena.alloc(
//
Stmt::Jump(jp_modify_union, root.arena.alloc([null])),
)),
));
(*field_layouts, tail_stmt)
};
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
non_tailrec_fields,
current,
tag_id,
jump_to_modify_union,
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
let tag_id_switch = Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
};
let jp_param = Param {
symbol: next_ptr,
borrow: true,
layout,
};
Stmt::Join {
id: jp_modify_union,
parameters: root.arena.alloc([jp_param]),
body: root.arena.alloc(rc_structure_stmt),
remainder: root.arena.alloc(tag_id_switch),
}
};
let loop_body = tag_id_stmt(root.arena.alloc(
//
rc_contents_then_structure,
));
let loop_init = Stmt::Jump(tailrec_loop, root.arena.alloc([initial_structure]));
let loop_param = Param {
symbol: current,
borrow: true,
layout: Layout::Union(union_layout),
};
Stmt::Join {
id: tailrec_loop,
parameters: root.arena.alloc([loop_param]),
body: root.arena.alloc(loop_body),
remainder: root.arena.alloc(loop_init),
}
}
#[allow(clippy::too_many_arguments)]
fn refcount_tag_fields<'a>(
root: &mut CodeGenHelp<'a>,

View file

@ -265,7 +265,7 @@ fn union_recursive_dec() {
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn refcount_different_rosetrees_inc() {
// Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)`
// Requires two different Inc procedures for `List (Rose I64)` and `List (Rose Str)`
// even though both appear in the mono Layout as `List(RecursivePointer)`
assert_refcounts!(
indoc!(
@ -305,7 +305,7 @@ fn refcount_different_rosetrees_inc() {
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn refcount_different_rosetrees_dec() {
// Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)`
// Requires two different Dec procedures for `List (Rose I64)` and `List (Rose Str)`
// even though both appear in the mono Layout as `List(RecursivePointer)`
assert_refcounts!(
indoc!(