diff --git a/compiler/mono/src/code_gen_help/mod.rs b/compiler/mono/src/code_gen_help/mod.rs
index 65dfc811a6..61c4c32c9e 100644
--- a/compiler/mono/src/code_gen_help/mod.rs
+++ b/compiler/mono/src/code_gen_help/mod.rs
@@ -32,6 +32,12 @@ enum HelperOp {
     Eq,
 }
 
+impl HelperOp {
+    fn is_decref(&self) -> bool {
+        matches!(self, Self::DecRef(_))
+    }
+}
+
 #[derive(Debug)]
 struct Specialization<'a> {
     op: HelperOp,
@@ -174,7 +180,7 @@ impl<'a> CodeGenHelp<'a> {
     ) -> Option<Expr<'a>> {
         use HelperOp::*;
 
-        debug_assert!(self.debug_recursion_depth < 10);
+        // debug_assert!(self.debug_recursion_depth < 100);
         self.debug_recursion_depth += 1;
 
         let layout = if matches!(called_layout, Layout::RecursivePointer) {
@@ -225,6 +231,8 @@ impl<'a> CodeGenHelp<'a> {
     ) -> Symbol {
         use HelperOp::*;
 
+        let layout = self.replace_rec_ptr(ctx, layout);
+
         let found = self
             .specializations
             .iter()
@@ -323,6 +331,98 @@ impl<'a> CodeGenHelp<'a> {
         let ident_id = ident_ids.add(Ident::from(debug_name));
         Symbol::new(self.home, ident_id)
     }
+
+    // When creating or looking up Specializations, we need to replace RecursivePointer
+    // with the particular Union layout it represents at this point in the tree.
+    // For example if a program uses `RoseTree a : [ Tree a (List (RoseTree a)) ]`
+    // then it could have both `RoseTree I64` and `RoseTree Str`. In this case it
+    // needs *two* specializations for `List(RecursivePointer)`, not just one.
+    fn replace_rec_ptr(&self, ctx: &Context<'a>, layout: Layout<'a>) -> Layout<'a> {
+        match layout {
+            Layout::Builtin(Builtin::Dict(k, v)) => Layout::Builtin(Builtin::Dict(
+                self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
+                self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
+            )),
+
+            Layout::Builtin(Builtin::Set(k)) => Layout::Builtin(Builtin::Set(
+                self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
+            )),
+
+            Layout::Builtin(Builtin::List(v)) => Layout::Builtin(Builtin::List(
+                self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
+            )),
+
+            Layout::Builtin(_) => layout,
+
+            Layout::Struct(fields) => {
+                let new_fields_iter = fields.iter().map(|f| self.replace_rec_ptr(ctx, *f));
+                Layout::Struct(self.arena.alloc_slice_fill_iter(new_fields_iter))
+            }
+
+            Layout::Union(UnionLayout::NonRecursive(tags)) => {
+                let mut new_tags = Vec::with_capacity_in(tags.len(), self.arena);
+                for fields in tags {
+                    let mut new_fields = Vec::with_capacity_in(fields.len(), self.arena);
+                    for field in fields.iter() {
+                        new_fields.push(self.replace_rec_ptr(ctx, *field))
+                    }
+                    new_tags.push(new_fields.into_bump_slice());
+                }
+                Layout::Union(UnionLayout::NonRecursive(new_tags.into_bump_slice()))
+            }
+
+            Layout::Union(_) => layout,
+
+            Layout::LambdaSet(lambda_set) => {
+                self.replace_rec_ptr(ctx, lambda_set.runtime_representation())
+            }
+
+            // This line is the whole point of the function
+            Layout::RecursivePointer => Layout::Union(ctx.recursive_union.unwrap()),
+        }
+    }
+
+    fn union_tail_recursion_fields(
+        &self,
+        union: UnionLayout<'a>,
+    ) -> (bool, Vec<'a, Option<usize>>) {
+        use UnionLayout::*;
+        match union {
+            NonRecursive(_) => return (false, bumpalo::vec![in self.arena]),
+
+            Recursive(tags) => self.union_tail_recursion_fields_help(tags),
+
+            NonNullableUnwrapped(field_layouts) => {
+                self.union_tail_recursion_fields_help(&[field_layouts])
+            }
+
+            NullableWrapped {
+                other_tags: tags, ..
+            } => self.union_tail_recursion_fields_help(tags),
+
+            NullableUnwrapped { other_fields, ..
} => { + self.union_tail_recursion_fields_help(&[other_fields]) + } + } + } + + fn union_tail_recursion_fields_help( + &self, + tags: &[&'a [Layout<'a>]], + ) -> (bool, Vec<'a, Option>) { + let mut can_use_tailrec = false; + let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena); + + for fields in tags.iter() { + let found_index = fields + .iter() + .position(|f| matches!(f, Layout::RecursivePointer)); + tailrec_indices.push(found_index); + can_use_tailrec |= found_index.is_some(); + } + + (can_use_tailrec, tailrec_indices) + } } fn let_lowlevel<'a>( diff --git a/compiler/mono/src/code_gen_help/refcount.rs b/compiler/mono/src/code_gen_help/refcount.rs index 4a27376f25..2bc4aeff92 100644 --- a/compiler/mono/src/code_gen_help/refcount.rs +++ b/compiler/mono/src/code_gen_help/refcount.rs @@ -1,3 +1,4 @@ +use bumpalo::collections::vec::Vec; use roc_builtins::bitcode::IntWidth; use roc_module::low_level::{LowLevel, LowLevel::*}; use roc_module::symbol::{IdentIds, Symbol}; @@ -6,7 +7,7 @@ use crate::code_gen_help::let_lowlevel; use crate::ir::{ BranchInfo, Call, CallType, Expr, JoinPointId, Literal, ModifyRc, Param, Stmt, UpdateModeId, }; -use crate::layout::{Builtin, Layout, UnionLayout}; +use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout}; use super::{CodeGenHelp, Context, HelperOp}; @@ -67,8 +68,9 @@ pub fn refcount_stmt<'a>( refcount_stmt(root, ident_ids, ctx, layout, modify, following) } - // Struct is stack-only, so DecRef is a no-op + // Struct and non-recursive Unions are stack-only, so DecRef is a no-op Layout::Struct(_) => following, + Layout::Union(UnionLayout::NonRecursive(_)) => following, // Inline the refcounting code instead of making a function. Don't iterate fields, // and replace any return statements with jumps to the `following` statement. @@ -112,9 +114,12 @@ pub fn refcount_generic<'a>( Layout::Struct(field_layouts) => { refcount_struct(root, ident_ids, ctx, field_layouts, structure) } - Layout::Union(_) => rc_todo(), - Layout::LambdaSet(_) => { - unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.") + Layout::Union(union_layout) => { + refcount_union(root, ident_ids, ctx, union_layout, structure) + } + Layout::LambdaSet(lambda_set) => { + let runtime_layout = lambda_set.runtime_representation(); + refcount_generic(root, ident_ids, ctx, runtime_layout, structure) } Layout::RecursivePointer => rc_todo(), } @@ -124,12 +129,32 @@ pub fn refcount_generic<'a>( // In the short term, it helps us to skip refcounting and let it leak, so we can make // progress incrementally. Kept in sync with generate_procs using assertions. pub fn is_rc_implemented_yet(layout: &Layout) -> bool { + use UnionLayout::*; + match layout { Layout::Builtin(Builtin::Dict(..) | Builtin::Set(_)) => false, Layout::Builtin(Builtin::List(elem_layout)) => is_rc_implemented_yet(elem_layout), Layout::Builtin(_) => true, Layout::Struct(fields) => fields.iter().all(is_rc_implemented_yet), - _ => false, + Layout::Union(union_layout) => match union_layout { + NonRecursive(tags) => tags + .iter() + .all(|fields| fields.iter().all(is_rc_implemented_yet)), + Recursive(tags) => tags + .iter() + .all(|fields| fields.iter().all(is_rc_implemented_yet)), + NonNullableUnwrapped(fields) => fields.iter().all(is_rc_implemented_yet), + NullableWrapped { other_tags, .. } => other_tags + .iter() + .all(|fields| fields.iter().all(is_rc_implemented_yet)), + NullableUnwrapped { other_fields, .. 
} => { + other_fields.iter().all(is_rc_implemented_yet) + } + }, + Layout::LambdaSet(lambda_set) => { + is_rc_implemented_yet(&lambda_set.runtime_representation()) + } + Layout::RecursivePointer => true, } } @@ -165,6 +190,7 @@ pub fn rc_ptr_from_data_ptr<'a>( ident_ids: &mut IdentIds, structure: Symbol, rc_ptr_sym: Symbol, + mask_lower_bits: bool, following: &'a Stmt<'a>, ) -> Stmt<'a> { // Typecast the structure pointer to an integer @@ -179,6 +205,21 @@ pub fn rc_ptr_from_data_ptr<'a>( }); let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, root.layout_isize, next); + // Mask for lower bits (for tag union id) + let mask_sym = root.create_symbol(ident_ids, "mask"); + let mask_expr = Expr::Literal(Literal::Int(-(root.ptr_size as i128))); + let mask_stmt = |next| Stmt::Let(mask_sym, mask_expr, root.layout_isize, next); + + let masked_sym = root.create_symbol(ident_ids, "masked"); + let and_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::And, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([addr_sym, mask_sym]), + }); + let and_stmt = |next| Stmt::Let(masked_sym, and_expr, root.layout_isize, next); + // Pointer size constant let ptr_size_sym = root.create_symbol(ident_ids, "ptr_size"); let ptr_size_expr = Expr::Literal(Literal::Int(root.ptr_size as i128)); @@ -186,38 +227,67 @@ pub fn rc_ptr_from_data_ptr<'a>( // Refcount address let rc_addr_sym = root.create_symbol(ident_ids, "rc_addr"); - let rc_addr_expr = Expr::Call(Call { + let sub_expr = Expr::Call(Call { call_type: CallType::LowLevel { op: LowLevel::NumSub, update_mode: UpdateModeId::BACKEND_DUMMY, }, - arguments: root.arena.alloc([addr_sym, ptr_size_sym]), + arguments: root.arena.alloc([ + if mask_lower_bits { + masked_sym + } else { + addr_sym + }, + ptr_size_sym, + ]), }); - let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, root.layout_isize, next); + let sub_stmt = |next| Stmt::Let(rc_addr_sym, sub_expr, root.layout_isize, next); // Typecast the refcount address from integer to pointer - let rc_ptr_expr = Expr::Call(Call { + let cast_expr = Expr::Call(Call { call_type: CallType::LowLevel { op: LowLevel::PtrCast, update_mode: UpdateModeId::BACKEND_DUMMY, }, arguments: root.arena.alloc([rc_addr_sym]), }); - let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next); + let cast_stmt = |next| Stmt::Let(rc_ptr_sym, cast_expr, LAYOUT_PTR, next); - addr_stmt(root.arena.alloc( - // - ptr_size_stmt(root.arena.alloc( + if mask_lower_bits { + addr_stmt(root.arena.alloc( // - rc_addr_stmt(root.arena.alloc( + mask_stmt(root.arena.alloc( // - rc_ptr_stmt(root.arena.alloc( + and_stmt(root.arena.alloc( // - following, + ptr_size_stmt(root.arena.alloc( + // + sub_stmt(root.arena.alloc( + // + cast_stmt(root.arena.alloc( + // + following, + )), + )), + )), )), )), - )), - )) + )) + } else { + addr_stmt(root.arena.alloc( + // + ptr_size_stmt(root.arena.alloc( + // + sub_stmt(root.arena.alloc( + // + cast_stmt(root.arena.alloc( + // + following, + )), + )), + )), + )) + } } fn modify_refcount<'a>( @@ -332,6 +402,7 @@ fn refcount_str<'a>( ident_ids, elements, rc_ptr, + false, root.arena.alloc( // mod_rc_stmt, @@ -412,12 +483,32 @@ fn refcount_list<'a>( // // modify refcount of the list and its elements + // (elements first, to avoid use-after-free for Dec) // let rc_ptr = root.create_symbol(ident_ids, "rc_ptr"); let alignment = layout.alignment_bytes(root.ptr_size); - let modify_elems = if elem_layout.is_refcounted() && !matches!(ctx.op, HelperOp::DecRef(_)) { 
+ let ret_stmt = rc_return_stmt(root, ident_ids, ctx); + let modify_list = modify_refcount( + root, + ident_ids, + ctx, + rc_ptr, + alignment, + arena.alloc(ret_stmt), + ); + + let get_rc_and_modify_list = rc_ptr_from_data_ptr( + root, + ident_ids, + elements, + rc_ptr, + false, + arena.alloc(modify_list), + ); + + let modify_elems_and_list = if elem_layout.is_refcounted() && !ctx.op.is_decref() { refcount_list_elems( root, ident_ids, @@ -427,36 +518,31 @@ fn refcount_list<'a>( box_union_layout, len, elements, + get_rc_and_modify_list, ) } else { - rc_return_stmt(root, ident_ids, ctx) + get_rc_and_modify_list }; - let modify_list = modify_refcount( - root, - ident_ids, - ctx, - rc_ptr, - alignment, - arena.alloc(modify_elems), - ); - - let modify_list_and_elems = elements_stmt(arena.alloc( - // - rc_ptr_from_data_ptr(root, ident_ids, elements, rc_ptr, arena.alloc(modify_list)), - )); - // // Do nothing if the list is empty // + let non_empty_branch = root.arena.alloc( + // + elements_stmt(root.arena.alloc( + // + modify_elems_and_list, + )), + ); + let if_stmt = Stmt::Switch { cond_symbol: is_empty, cond_layout: LAYOUT_BOOL, branches: root .arena .alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]), - default_branch: (BranchInfo::None, root.arena.alloc(modify_list_and_elems)), + default_branch: (BranchInfo::None, non_empty_branch), ret_layout: LAYOUT_UNIT, }; @@ -482,6 +568,7 @@ fn refcount_list_elems<'a>( box_union_layout: UnionLayout<'a>, length: Symbol, elements: Symbol, + following: Stmt<'a>, ) -> Stmt<'a> { use LowLevel::*; let layout_isize = root.layout_isize; @@ -496,9 +583,9 @@ fn refcount_list_elems<'a>( // // let size = literal int - let size = root.create_symbol(ident_ids, "size"); - let size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128)); - let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next); + let elem_size = root.create_symbol(ident_ids, "elem_size"); + let elem_size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128)); + let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next); // let list_size = len * size let list_size = root.create_symbol(ident_ids, "list_size"); @@ -508,7 +595,7 @@ fn refcount_list_elems<'a>( layout_isize, list_size, NumMul, - &[length, size], + &[length, elem_size], next, ) }; @@ -564,8 +651,16 @@ fn refcount_list_elems<'a>( // Next loop iteration // let next_addr = root.create_symbol(ident_ids, "next_addr"); - let next_addr_stmt = - |next| let_lowlevel(arena, layout_isize, next_addr, NumAdd, &[addr, size], next); + let next_addr_stmt = |next| { + let_lowlevel( + arena, + layout_isize, + next_addr, + NumAdd, + &[addr, elem_size], + next, + ) + }; // // Control flow @@ -578,9 +673,7 @@ fn refcount_list_elems<'a>( cond_symbol: is_end, cond_layout: LAYOUT_BOOL, ret_layout, - branches: root - .arena - .alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]), + branches: root.arena.alloc([(1, BranchInfo::None, following)]), default_branch: ( BranchInfo::None, arena.alloc(box_stmt(arena.alloc( @@ -616,7 +709,7 @@ fn refcount_list_elems<'a>( start_stmt(arena.alloc( // - size_stmt(arena.alloc( + elem_size_stmt(arena.alloc( // list_size_stmt(arena.alloc( // @@ -667,3 +760,510 @@ fn refcount_struct<'a>( stmt } + +fn refcount_union<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union: UnionLayout<'a>, + structure: Symbol, +) -> Stmt<'a> { + use UnionLayout::*; + + let 
parent_rec_ptr_layout = ctx.recursive_union; + if !matches!(union, NonRecursive(_)) { + ctx.recursive_union = Some(union); + } + + let body = match union { + NonRecursive(tags) => refcount_union_nonrec(root, ident_ids, ctx, union, tags, structure), + + Recursive(tags) => { + let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union); + if is_tailrec && !ctx.op.is_decref() { + refcount_union_tailrec(root, ident_ids, ctx, union, tags, None, tail_idx, structure) + } else { + refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure) + } + } + + NonNullableUnwrapped(field_layouts) => { + // We don't do tail recursion on NonNullableUnwrapped. + // Its RecursionPointer is always nested inside a List, Option, or other sub-layout, since + // a direct RecursionPointer is only possible if there's at least one non-recursive variant. + // This nesting makes it harder to do tail recursion, so we just don't. + let tags = root.arena.alloc([field_layouts]); + refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure) + } + + NullableWrapped { + other_tags: tags, + nullable_id, + } => { + let null_id = Some(nullable_id); + let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union); + if is_tailrec && !ctx.op.is_decref() { + refcount_union_tailrec( + root, ident_ids, ctx, union, tags, null_id, tail_idx, structure, + ) + } else { + refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure) + } + } + + NullableUnwrapped { + other_fields, + nullable_id, + } => { + let null_id = Some(nullable_id as TagIdIntType); + let tags = root.arena.alloc([other_fields]); + let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union); + if is_tailrec && !ctx.op.is_decref() { + refcount_union_tailrec( + root, ident_ids, ctx, union, tags, null_id, tail_idx, structure, + ) + } else { + refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure) + } + } + }; + + ctx.recursive_union = parent_rec_ptr_layout; + + body +} + +fn refcount_union_nonrec<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + structure: Symbol, +) -> Stmt<'a> { + let tag_id_layout = union_layout.tag_id_layout(); + + let tag_id_sym = root.create_symbol(ident_ids, "tag_id"); + let tag_id_stmt = |next| { + Stmt::Let( + tag_id_sym, + Expr::GetTagId { + structure, + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let continuation = rc_return_stmt(root, ident_ids, ctx); + + let switch_stmt = refcount_union_contents( + root, + ident_ids, + ctx, + union_layout, + tag_layouts, + None, + structure, + tag_id_sym, + tag_id_layout, + continuation, + ); + + tag_id_stmt(root.arena.alloc( + // + switch_stmt, + )) +} + +#[allow(clippy::too_many_arguments)] +fn refcount_union_contents<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + null_id: Option, + structure: Symbol, + tag_id_sym: Symbol, + tag_id_layout: Layout<'a>, + modify_union_stmt: Stmt<'a>, +) -> Stmt<'a> { + let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union")); + let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena); + + if let Some(id) = null_id { + let ret = rc_return_stmt(root, ident_ids, ctx); + tag_branches.push((id as u64, BranchInfo::None, ret)); + } + + let mut tag_id: TagIdIntType = 0; + for field_layouts in tag_layouts.iter() { + match 
null_id { + Some(id) if id == tag_id => { + tag_id += 1; + } + _ => {} + } + + // After refcounting the fields, jump to modify the union itself + // (Order is important, to avoid use-after-free for Dec) + let following = Stmt::Jump(jp_modify_union, &[]); + + let fields_stmt = refcount_tag_fields( + root, + ident_ids, + ctx, + union_layout, + field_layouts, + structure, + tag_id, + following, + ); + + tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt)); + + tag_id += 1; + } + + let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2; + + let tag_id_switch = Stmt::Switch { + cond_symbol: tag_id_sym, + cond_layout: tag_id_layout, + branches: tag_branches.into_bump_slice(), + default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)), + ret_layout: LAYOUT_UNIT, + }; + + Stmt::Join { + id: jp_modify_union, + parameters: &[], + body: root.arena.alloc(modify_union_stmt), + remainder: root.arena.alloc(tag_id_switch), + } +} + +fn refcount_union_rec<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + null_id: Option, + structure: Symbol, +) -> Stmt<'a> { + let tag_id_layout = union_layout.tag_id_layout(); + + let tag_id_sym = root.create_symbol(ident_ids, "tag_id"); + let tag_id_stmt = |next| { + Stmt::Let( + tag_id_sym, + Expr::GetTagId { + structure, + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let rc_structure_stmt = { + let rc_ptr = root.create_symbol(ident_ids, "rc_ptr"); + + let alignment = Layout::Union(union_layout).alignment_bytes(root.ptr_size); + let ret_stmt = rc_return_stmt(root, ident_ids, ctx); + let modify_structure_stmt = modify_refcount( + root, + ident_ids, + ctx, + rc_ptr, + alignment, + root.arena.alloc(ret_stmt), + ); + + rc_ptr_from_data_ptr( + root, + ident_ids, + structure, + rc_ptr, + union_layout.stores_tag_id_in_pointer(root.ptr_size), + root.arena.alloc(modify_structure_stmt), + ) + }; + + let rc_contents_then_structure = if ctx.op.is_decref() { + rc_structure_stmt + } else { + refcount_union_contents( + root, + ident_ids, + ctx, + union_layout, + tag_layouts, + null_id, + structure, + tag_id_sym, + tag_id_layout, + rc_structure_stmt, + ) + }; + + if ctx.op.is_decref() && null_id.is_none() { + rc_contents_then_structure + } else { + tag_id_stmt(root.arena.alloc( + // + rc_contents_then_structure, + )) + } +} + +// Refcount a recursive union using tail-call elimination to limit stack growth +#[allow(clippy::too_many_arguments)] +fn refcount_union_tailrec<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + null_id: Option, + tailrec_indices: Vec<'a, Option>, + initial_structure: Symbol, +) -> Stmt<'a> { + let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop")); + let current = root.create_symbol(ident_ids, "current"); + let next_ptr = root.create_symbol(ident_ids, "next_ptr"); + let layout = Layout::Union(union_layout); + + let tag_id_layout = union_layout.tag_id_layout(); + + let tag_id_sym = root.create_symbol(ident_ids, "tag_id"); + let tag_id_stmt = |next| { + Stmt::Let( + tag_id_sym, + Expr::GetTagId { + structure: current, + union_layout, + }, + tag_id_layout, + next, + ) + }; + + // Do refcounting on the structure itself + // In the control flow, this comes *after* refcounting the fields + // It receives a `next` parameter to pass through to the outer joinpoint + let rc_structure_stmt = { 
+ let rc_ptr = root.create_symbol(ident_ids, "rc_ptr"); + let next_addr = root.create_symbol(ident_ids, "next_addr"); + + let exit_stmt = rc_return_stmt(root, ident_ids, ctx); + let jump_to_loop = Stmt::Jump(tailrec_loop, root.arena.alloc([next_ptr])); + + let loop_or_exit = Stmt::Switch { + cond_symbol: next_addr, + cond_layout: root.layout_isize, + branches: root.arena.alloc([(0, BranchInfo::None, exit_stmt)]), + default_branch: (BranchInfo::None, root.arena.alloc(jump_to_loop)), + ret_layout: LAYOUT_UNIT, + }; + let loop_or_exit_based_on_next_addr = { + let_lowlevel( + root.arena, + root.layout_isize, + next_addr, + PtrCast, + &[next_ptr], + root.arena.alloc(loop_or_exit), + ) + }; + + let alignment = layout.alignment_bytes(root.ptr_size); + let modify_structure_stmt = modify_refcount( + root, + ident_ids, + ctx, + rc_ptr, + alignment, + root.arena.alloc(loop_or_exit_based_on_next_addr), + ); + + rc_ptr_from_data_ptr( + root, + ident_ids, + current, + rc_ptr, + union_layout.stores_tag_id_in_pointer(root.ptr_size), + root.arena.alloc(modify_structure_stmt), + ) + }; + + let rc_contents_then_structure = { + let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union")); + let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena); + + // If this is null, there is no refcount, no `next`, no fields. Just return. + if let Some(id) = null_id { + let ret = rc_return_stmt(root, ident_ids, ctx); + tag_branches.push((id as u64, BranchInfo::None, ret)); + } + + let mut tag_id: TagIdIntType = 0; + for (field_layouts, opt_tailrec_index) in tag_layouts.iter().zip(tailrec_indices) { + match null_id { + Some(id) if id == tag_id => { + tag_id += 1; + } + _ => {} + } + + // After refcounting the fields, jump to modify the union itself. + // The loop param is a pointer to the next union. It gets passed through two jumps. 
+ let (non_tailrec_fields, jump_to_modify_union) = + if let Some(tailrec_index) = opt_tailrec_index { + let mut filtered = Vec::with_capacity_in(field_layouts.len() - 1, root.arena); + let mut tail_stmt = None; + for (i, field) in field_layouts.iter().enumerate() { + if i != tailrec_index { + filtered.push(*field); + } else { + let field_val = + root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i)); + let field_val_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure: current, + }; + let jump_params = root.arena.alloc([field_val]); + let jump = root.arena.alloc(Stmt::Jump(jp_modify_union, jump_params)); + tail_stmt = Some(Stmt::Let(field_val, field_val_expr, *field, jump)); + } + } + + (filtered.into_bump_slice(), tail_stmt.unwrap()) + } else { + let zero = root.create_symbol(ident_ids, "zero"); + let zero_expr = Expr::Literal(Literal::Int(0)); + let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next); + + let null = root.create_symbol(ident_ids, "null"); + let null_stmt = + |next| let_lowlevel(root.arena, layout, null, PtrCast, &[zero], next); + + let tail_stmt = zero_stmt(root.arena.alloc( + // + null_stmt(root.arena.alloc( + // + Stmt::Jump(jp_modify_union, root.arena.alloc([null])), + )), + )); + + (*field_layouts, tail_stmt) + }; + + let fields_stmt = refcount_tag_fields( + root, + ident_ids, + ctx, + union_layout, + non_tailrec_fields, + current, + tag_id, + jump_to_modify_union, + ); + + tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt)); + + tag_id += 1; + } + + let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2; + + let tag_id_switch = Stmt::Switch { + cond_symbol: tag_id_sym, + cond_layout: tag_id_layout, + branches: tag_branches.into_bump_slice(), + default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)), + ret_layout: LAYOUT_UNIT, + }; + + let jp_param = Param { + symbol: next_ptr, + borrow: true, + layout, + }; + + Stmt::Join { + id: jp_modify_union, + parameters: root.arena.alloc([jp_param]), + body: root.arena.alloc(rc_structure_stmt), + remainder: root.arena.alloc(tag_id_switch), + } + }; + + let loop_body = tag_id_stmt(root.arena.alloc( + // + rc_contents_then_structure, + )); + + let loop_init = Stmt::Jump(tailrec_loop, root.arena.alloc([initial_structure])); + let loop_param = Param { + symbol: current, + borrow: true, + layout: Layout::Union(union_layout), + }; + + Stmt::Join { + id: tailrec_loop, + parameters: root.arena.alloc([loop_param]), + body: root.arena.alloc(loop_body), + remainder: root.arena.alloc(loop_init), + } +} + +#[allow(clippy::too_many_arguments)] +fn refcount_tag_fields<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + field_layouts: &'a [Layout<'a>], + structure: Symbol, + tag_id: TagIdIntType, + following: Stmt<'a>, +) -> Stmt<'a> { + let mut stmt = following; + + for (i, field_layout) in field_layouts.iter().enumerate().rev() { + if field_layout.contains_refcounted() { + let field_val = root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i)); + let field_val_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure, + }; + let field_val_stmt = |next| Stmt::Let(field_val, field_val_expr, *field_layout, next); + + let mod_unit = root.create_symbol(ident_ids, &format!("mod_field_{}_{}", tag_id, i)); + let mod_args = refcount_args(root, ctx, field_val); + let mod_expr = root + .call_specialized_op(ident_ids, ctx, *field_layout, mod_args) + 
.unwrap(); + let mod_stmt = |next| Stmt::Let(mod_unit, mod_expr, LAYOUT_UNIT, next); + + stmt = field_val_stmt(root.arena.alloc( + // + mod_stmt(root.arena.alloc( + // + stmt, + )), + )) + } + } + + stmt +} diff --git a/compiler/test_gen/src/gen_compare.rs b/compiler/test_gen/src/gen_compare.rs index 65fe0a252e..38610b8cb2 100644 --- a/compiler/test_gen/src/gen_compare.rs +++ b/compiler/test_gen/src/gen_compare.rs @@ -498,6 +498,48 @@ fn eq_rosetree() { ); } +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn eq_different_rosetrees() { + // Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)` + // even though both appear in the mono Layout as `List(RecursivePointer)` + assert_evals_to!( + indoc!( + r#" + Rose a : [ Rose a (List (Rose a)) ] + + a1 : Rose I64 + a1 = Rose 999 [] + a2 : Rose I64 + a2 = Rose 0 [a1] + + b1 : Rose I64 + b1 = Rose 999 [] + b2 : Rose I64 + b2 = Rose 0 [b1] + + ab = a2 == b2 + + c1 : Rose Str + c1 = Rose "hello" [] + c2 : Rose Str + c2 = Rose "" [c1] + + d1 : Rose Str + d1 = Rose "hello" [] + d2 : Rose Str + d2 = Rose "" [d1] + + cd = c2 == d2 + + ab && cd + "# + ), + true, + bool + ); +} + #[test] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] #[ignore] diff --git a/compiler/test_gen/src/gen_refcount.rs b/compiler/test_gen/src/gen_refcount.rs index cb795d6603..87dbe52856 100644 --- a/compiler/test_gen/src/gen_refcount.rs +++ b/compiler/test_gen/src/gen_refcount.rs @@ -7,6 +7,10 @@ use indoc::indoc; #[allow(unused_imports)] use roc_std::{RocList, RocStr}; +// A "good enough" representation of a pointer for these tests, because +// we ignore the return value. As long as it's the right stack size, it's fine. +type Pointer = usize; + #[test] #[cfg(any(feature = "gen-wasm"))] fn str_inc() { @@ -150,3 +154,277 @@ fn struct_dealloc() { &[0] // s ); } + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_nonrecursive_inc() { + type TwoStr = (RocStr, RocStr, i64); + + assert_refcounts!( + indoc!( + r#" + TwoOrNone a: [ Two a a, None ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + two : TwoOrNone Str + two = Two s s + + four : TwoOrNone (TwoOrNone Str) + four = Two two two + + four + "# + ), + (TwoStr, TwoStr, i64), + &[4] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_nonrecursive_dec() { + assert_refcounts!( + indoc!( + r#" + TwoOrNone a: [ Two a a, None ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + two : TwoOrNone Str + two = Two s s + + when two is + Two x _ -> x + None -> "" + "# + ), + RocStr, + &[1] // s + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_recursive_inc() { + assert_refcounts!( + indoc!( + r#" + Expr : [ Sym Str, Add Expr Expr ] + + s = Str.concat "heap_allocated" "_symbol_name" + + x : Expr + x = Sym s + + e : Expr + e = Add x x + + Pair e e + "# + ), + (Pointer, Pointer), + &[ + 4, // s + 4, // sym + 2, // e + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_recursive_dec() { + assert_refcounts!( + indoc!( + r#" + Expr : [ Sym Str, Add Expr Expr ] + + s = Str.concat "heap_allocated" "_symbol_name" + + x : Expr + x = Sym s + + e : Expr + e = Add x x + + when e is + Add y _ -> y + Sym _ -> e + "# + ), + Pointer, + &[ + 1, // s + 1, // sym + 0 // e + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn refcount_different_rosetrees_inc() { + // Requires two different Inc procedures for `List (Rose I64)` and `List (Rose Str)` + // even though both appear in the mono Layout as `List(RecursivePointer)` + 
assert_refcounts!( + indoc!( + r#" + Rose a : [ Rose a (List (Rose a)) ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + i1 : Rose I64 + i1 = Rose 999 [] + + s1 : Rose Str + s1 = Rose s [] + + i2 : Rose I64 + i2 = Rose 0 [i1, i1, i1] + + s2 : Rose Str + s2 = Rose "" [s1, s1] + + Tuple i2 s2 + "# + ), + (Pointer, Pointer), + &[ + 2, // s + 3, // i1 + 2, // s1 + 1, // [i1, i1] + 1, // i2 + 1, // [s1, s1] + 1 // s2 + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn refcount_different_rosetrees_dec() { + // Requires two different Dec procedures for `List (Rose I64)` and `List (Rose Str)` + // even though both appear in the mono Layout as `List(RecursivePointer)` + assert_refcounts!( + indoc!( + r#" + Rose a : [ Rose a (List (Rose a)) ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + i1 : Rose I64 + i1 = Rose 999 [] + + s1 : Rose Str + s1 = Rose s [] + + i2 : Rose I64 + i2 = Rose 0 [i1, i1] + + s2 : Rose Str + s2 = Rose "" [s1, s1] + + when (Tuple i2 s2) is + Tuple (Rose x _) _ -> x + "# + ), + i64, + &[ + 0, // s + 0, // i1 + 0, // s1 + 0, // [i1, i1] + 0, // i2 + 0, // [s1, s1] + 0, // s2 + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_linked_list_inc() { + assert_refcounts!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + linked : LinkedList Str + linked = Cons s (Cons s (Cons s Nil)) + + Tuple linked linked + "# + ), + (Pointer, Pointer), + &[ + 6, // s + 2, // Cons + 2, // Cons + 2, // Cons + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_linked_list_dec() { + assert_refcounts!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + s = Str.concat "A long enough string " "to be heap-allocated" + + linked : LinkedList Str + linked = Cons s (Cons s (Cons s Nil)) + + when linked is + Cons x _ -> x + Nil -> "" + "# + ), + RocStr, + &[ + 1, // s + 0, // Cons + 0, // Cons + 0, // Cons + ] + ); +} + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_linked_list_long_dec() { + assert_refcounts!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + prependOnes = \n, tail -> + if n == 0 then + tail + else + prependOnes (n-1) (Cons 1 tail) + + main = + n = 1_000 + + linked : LinkedList I64 + linked = prependOnes n Nil + + when linked is + Cons x _ -> x + Nil -> -1 + "# + ), + i64, + &[0; 1_000] + ); +} diff --git a/compiler/test_gen/src/helpers/wasm.rs b/compiler/test_gen/src/helpers/wasm.rs index 47973ea05d..1cf56cbf2c 100644 --- a/compiler/test_gen/src/helpers/wasm.rs +++ b/compiler/test_gen/src/helpers/wasm.rs @@ -303,28 +303,38 @@ where let memory = instance.exports.get_memory(MEMORY_NAME).unwrap(); + let expected_len = num_refcounts as i32; let init_refcount_test = instance.exports.get_function("init_refcount_test").unwrap(); - let init_result = init_refcount_test.call(&[wasmer::Value::I32(num_refcounts as i32)]); - let refcount_array_addr = match init_result { + let init_result = init_refcount_test.call(&[wasmer::Value::I32(expected_len)]); + let refcount_vector_addr = match init_result { Err(e) => return Err(format!("{:?}", e)), Ok(result) => match result[0] { wasmer::Value::I32(a) => a, _ => panic!(), }, }; - // An array of refcount pointers - let refcount_ptr_array: WasmPtr, wasmer::Array> = - WasmPtr::new(refcount_array_addr as u32); - let refcount_ptrs: &[Cell>] = refcount_ptr_array - .deref(memory, 0, num_refcounts as u32) - 
        .unwrap();
 
+    // Run the test
     let test_wrapper = instance.exports.get_function(TEST_WRAPPER_NAME).unwrap();
     match test_wrapper.call(&[]) {
         Err(e) => return Err(format!("{:?}", e)),
         Ok(_) => {}
     }
 
+    // Check we got the right number of refcounts
+    let refcount_vector_len: WasmPtr<i32> = WasmPtr::new(refcount_vector_addr as u32);
+    let actual_len = refcount_vector_len.deref(memory).unwrap().get();
+    if actual_len != expected_len {
+        panic!("Expected {} refcounts but got {}", expected_len, actual_len);
+    }
+
+    // Read the actual refcount values
+    let refcount_ptr_array: WasmPtr<WasmPtr<i32>, wasmer::Array> =
+        WasmPtr::new(4 + refcount_vector_addr as u32);
+    let refcount_ptrs: &[Cell<WasmPtr<i32>>] = refcount_ptr_array
+        .deref(memory, 0, num_refcounts as u32)
+        .unwrap();
+
     let mut refcounts = Vec::with_capacity(num_refcounts);
     for i in 0..num_refcounts {
         let rc_ptr = refcount_ptrs[i].get();
diff --git a/compiler/test_gen/src/helpers/wasm_test_platform.c b/compiler/test_gen/src/helpers/wasm_test_platform.c
index 5ed3093c03..b01543d0de 100644
--- a/compiler/test_gen/src/helpers/wasm_test_platform.c
+++ b/compiler/test_gen/src/helpers/wasm_test_platform.c
@@ -3,35 +3,41 @@
 // Makes test runs take 50% longer, due to linking
 #define ENABLE_PRINTF 0
 
+typedef struct
+{
+    size_t length;
+    size_t *elements[]; // flexible array member
+} Vector;
+
 // Globals for refcount testing
-size_t **rc_pointers; // array of pointers to refcount values
-size_t rc_pointers_len;
-size_t rc_pointers_index;
+Vector *rc_pointers;
+size_t rc_pointers_capacity;
 
 // The rust test passes us the max number of allocations it expects to make,
 // and we tell it where we're going to write the refcount pointers.
 // It won't actually read that memory until later, when the test is done.
-size_t **init_refcount_test(size_t max_allocs)
+Vector *init_refcount_test(size_t capacity)
 {
-    rc_pointers = malloc(max_allocs * sizeof(size_t *));
-    rc_pointers_len = max_allocs;
-    rc_pointers_index = 0;
-    for (size_t i = 0; i < max_allocs; ++i)
-        rc_pointers[i] = NULL;
+    rc_pointers_capacity = capacity;
+
+    rc_pointers = malloc((1 + capacity) * sizeof(size_t *));
+    rc_pointers->length = 0;
+    for (size_t i = 0; i < capacity; ++i)
+        rc_pointers->elements[i] = NULL;
 
     return rc_pointers;
 }
 
 #if ENABLE_PRINTF
-#define ASSERT(x)                   \
-    if (!(x))                       \
-    {                               \
-        printf("FAILED: " #x "\n"); \
-        abort();                    \
+#define ASSERT(condition, format, ...)                        \
+    if (!(condition))                                         \
+    {                                                         \
+        printf("ASSERT FAILED: " #format "\n", __VA_ARGS__);  \
+        abort();                                              \
+    }
 #else
-#define ASSERT(x) \
-    if (!(x))     \
+#define ASSERT(condition, format, ...)
\ + if (!(condition)) \ abort(); #endif @@ -50,12 +56,13 @@ void *roc_alloc(size_t size, unsigned int alignment) if (rc_pointers) { - ASSERT(alignment >= sizeof(size_t)); - ASSERT(rc_pointers_index < rc_pointers_len); + ASSERT(alignment >= sizeof(size_t), "alignment %zd != %zd", alignment, sizeof(size_t)); + size_t num_alloc = rc_pointers->length + 1; + ASSERT(num_alloc <= rc_pointers_capacity, "Too many allocations %zd > %zd", num_alloc, rc_pointers_capacity); size_t *rc_ptr = alloc_ptr_to_rc_ptr(allocated, alignment); - rc_pointers[rc_pointers_index] = rc_ptr; - rc_pointers_index++; + rc_pointers->elements[rc_pointers->length] = rc_ptr; + rc_pointers->length++; } #if ENABLE_PRINTF @@ -94,16 +101,16 @@ void roc_dealloc(void *ptr, unsigned int alignment) // Then even if malloc reuses the space, everything still works size_t *rc_ptr = alloc_ptr_to_rc_ptr(ptr, alignment); int i = 0; - for (; i < rc_pointers_index; ++i) + for (; i < rc_pointers->length; ++i) { - if (rc_pointers[i] == rc_ptr) + if (rc_pointers->elements[i] == rc_ptr) { - rc_pointers[i] = NULL; + rc_pointers->elements[i] = NULL; break; } } - int was_found = i < rc_pointers_index; - ASSERT(was_found); + int was_found = i < rc_pointers->length; + ASSERT(was_found, "RC pointer not found %p", rc_ptr); } #if ENABLE_PRINTF
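
The refcount lookup added in `rc_ptr_from_data_ptr` works by plain address arithmetic: when a recursive tag union stores its tag id in the low bits of the data pointer, the generated code masks those bits off with `addr & -(ptr_size)` and then steps back one pointer-width to reach the refcount, which sits directly before the heap data. Below is a minimal Rust sketch of the same arithmetic, for reference only; the free function and the name `rc_addr` are illustrative and not part of the patch, which emits equivalent `PtrCast`/`And`/`NumSub` low-level calls instead.

    // Illustrative sketch of the generated address arithmetic, assuming the
    // refcount is stored one pointer-width before the data and the tag id
    // (if any) occupies the pointer's low bits.
    fn rc_addr(data_ptr: usize, ptr_size: usize, tag_id_in_pointer: bool) -> usize {
        let addr = if tag_id_in_pointer {
            // mask = -(ptr_size): clears the low bits that encode the tag id
            data_ptr & ptr_size.wrapping_neg()
        } else {
            data_ptr
        };
        addr - ptr_size
    }

For example, with ptr_size = 4 (wasm32), a tagged pointer 0x1003 masks to 0x1000, so its refcount is read from 0x0FFC.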