Generate refcounting helper procedures for tag unions

This commit is contained in:
Brian Carroll 2022-01-02 13:58:36 +00:00
parent 2465fe8528
commit 5d7b4018b7
3 changed files with 296 additions and 6 deletions

View file

@ -1,3 +1,4 @@
use bumpalo::collections::vec::Vec;
use roc_builtins::bitcode::IntWidth;
use roc_module::low_level::{LowLevel, LowLevel::*};
use roc_module::symbol::{IdentIds, Symbol};
@ -6,7 +7,7 @@ use crate::code_gen_help::let_lowlevel;
use crate::ir::{
BranchInfo, Call, CallType, Expr, JoinPointId, Literal, ModifyRc, Param, Stmt, UpdateModeId,
};
use crate::layout::{Builtin, Layout, UnionLayout};
use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout};
use super::{CodeGenHelp, Context, HelperOp};
@ -112,9 +113,12 @@ pub fn refcount_generic<'a>(
Layout::Struct(field_layouts) => {
refcount_struct(root, ident_ids, ctx, field_layouts, structure)
}
Layout::Union(_) => rc_todo(),
Layout::LambdaSet(_) => {
unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.")
Layout::Union(union_layout) => {
refcount_tag_union(root, ident_ids, ctx, union_layout, structure)
}
Layout::LambdaSet(lambda_set) => {
let runtime_layout = lambda_set.runtime_representation();
refcount_generic(root, ident_ids, ctx, runtime_layout, structure)
}
Layout::RecursivePointer => rc_todo(),
}
@ -124,12 +128,32 @@ pub fn refcount_generic<'a>(
// In the short term, it helps us to skip refcounting and let it leak, so we can make
// progress incrementally. Kept in sync with generate_procs using assertions.
pub fn is_rc_implemented_yet(layout: &Layout) -> bool {
use UnionLayout::*;
match layout {
Layout::Builtin(Builtin::Dict(..) | Builtin::Set(_)) => false,
Layout::Builtin(Builtin::List(elem_layout)) => is_rc_implemented_yet(elem_layout),
Layout::Builtin(_) => true,
Layout::Struct(fields) => fields.iter().all(is_rc_implemented_yet),
_ => false,
Layout::Union(union_layout) => match union_layout {
NonRecursive(tags) => tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
Recursive(tags) => tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
NonNullableUnwrapped(fields) => fields.iter().all(is_rc_implemented_yet),
NullableWrapped { other_tags, .. } => other_tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
NullableUnwrapped { other_fields, .. } => {
other_fields.iter().all(is_rc_implemented_yet)
}
},
Layout::LambdaSet(lambda_set) => {
is_rc_implemented_yet(&lambda_set.runtime_representation())
}
Layout::RecursivePointer => true,
}
}
@ -417,7 +441,7 @@ fn refcount_list<'a>(
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = layout.alignment_bytes(root.ptr_size);
let modify_elems = if elem_layout.is_refcounted() && !matches!(ctx.op, HelperOp::DecRef(_)) {
let modify_elems = if elem_layout.is_refcounted() && !ctx.op.is_decref() {
refcount_list_elems(
root,
ident_ids,
@ -667,3 +691,216 @@ fn refcount_struct<'a>(
stmt
}
fn refcount_tag_union<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
structure: Symbol,
) -> Stmt<'a> {
use UnionLayout::*;
let parent_rec_ptr_layout = ctx.recursive_union;
if !matches!(union_layout, NonRecursive(_)) {
ctx.recursive_union = Some(union_layout);
}
let body = match union_layout {
NonRecursive(tags) => {
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
}
Recursive(tags) => {
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
}
NonNullableUnwrapped(field_layouts) => {
let tags = root.arena.alloc([field_layouts]);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
}
NullableWrapped {
other_tags: tags,
nullable_id,
} => {
let null_id = Some(nullable_id);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, null_id, structure)
}
NullableUnwrapped {
other_fields,
nullable_id,
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, null_id, structure)
}
};
ctx.recursive_union = parent_rec_ptr_layout;
body
}
fn refcount_tag_union_help<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
) -> Stmt<'a> {
let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_));
/* TODO: tail recursion
let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
let structure = if is_non_recursive {
initial_structure
} else {
// current value in the tail-recursive loop
root.create_symbol(ident_ids, "structure")
};
*/
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure,
union_layout,
},
tag_id_layout,
next,
)
};
let modify_fields_stmt = if ctx.op.is_decref() {
rc_return_stmt(root, ident_ids, ctx)
} else {
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), root.arena);
let mut tag_id: TagIdIntType = 0;
for field_layouts in tag_layouts.iter() {
if let Some(id) = null_id {
if tag_id == id {
tag_id += 1;
}
}
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
field_layouts,
structure,
tag_id as TagIdIntType,
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
}
};
let rc_structure_stmt = if is_non_recursive {
modify_fields_stmt
} else {
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = Layout::Union(union_layout).alignment_bytes(root.ptr_size);
let modify_structure_stmt = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
root.arena.alloc(modify_fields_stmt),
);
let rc_ptr_stmt = rc_ptr_from_data_ptr(
root,
ident_ids,
structure,
rc_ptr,
root.arena.alloc(modify_structure_stmt),
);
if let Some(id) = null_id {
let null_branch = (
id as u64,
BranchInfo::None,
rc_return_stmt(root, ident_ids, ctx),
);
Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: root.arena.alloc([null_branch]),
default_branch: (BranchInfo::None, root.arena.alloc(rc_ptr_stmt)),
ret_layout: LAYOUT_UNIT,
}
} else {
rc_ptr_stmt
}
};
tag_id_stmt(root.arena.alloc(
//
rc_structure_stmt,
))
}
fn refcount_tag_fields<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
field_layouts: &'a [Layout<'a>],
structure: Symbol,
tag_id: TagIdIntType,
) -> Stmt<'a> {
let mut stmt = rc_return_stmt(root, ident_ids, ctx);
for (i, field_layout) in field_layouts.iter().enumerate().rev() {
if field_layout.contains_refcounted() {
let field_val = root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i));
let field_val_expr = Expr::UnionAtIndex {
union_layout,
tag_id,
index: i as u64,
structure,
};
let field_val_stmt = |next| Stmt::Let(field_val, field_val_expr, *field_layout, next);
let mod_unit = root.create_symbol(ident_ids, &format!("mod_field_{}_{}", tag_id, i));
let mod_args = refcount_args(root, ctx, field_val);
let mod_expr = root
.call_specialized_op(ident_ids, ctx, *field_layout, mod_args)
.unwrap();
let mod_stmt = |next| Stmt::Let(mod_unit, mod_expr, LAYOUT_UNIT, next);
stmt = field_val_stmt(root.arena.alloc(
//
mod_stmt(root.arena.alloc(
//
stmt,
)),
))
}
}
stmt
}