diff --git a/compiler/mono/src/code_gen_help/mod.rs b/compiler/mono/src/code_gen_help/mod.rs index b708d3324e..154984b383 100644 --- a/compiler/mono/src/code_gen_help/mod.rs +++ b/compiler/mono/src/code_gen_help/mod.rs @@ -35,7 +35,7 @@ enum HelperOp { impl HelperOp { fn is_decref(&self) -> bool { matches!(self, Self::DecRef(_)) - } + } } #[derive(Debug)] diff --git a/compiler/mono/src/code_gen_help/refcount.rs b/compiler/mono/src/code_gen_help/refcount.rs index c95fb55f10..943bb07463 100644 --- a/compiler/mono/src/code_gen_help/refcount.rs +++ b/compiler/mono/src/code_gen_help/refcount.rs @@ -189,6 +189,7 @@ pub fn rc_ptr_from_data_ptr<'a>( ident_ids: &mut IdentIds, structure: Symbol, rc_ptr_sym: Symbol, + mask_lower_bits: bool, following: &'a Stmt<'a>, ) -> Stmt<'a> { // Typecast the structure pointer to an integer @@ -203,6 +204,21 @@ pub fn rc_ptr_from_data_ptr<'a>( }); let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, root.layout_isize, next); + // Mask for lower bits (for tag union id) + let mask_sym = root.create_symbol(ident_ids, "mask"); + let mask_expr = Expr::Literal(Literal::Int(-(root.ptr_size as i128))); + let mask_stmt = |next| Stmt::Let(mask_sym, mask_expr, root.layout_isize, next); + + let masked_sym = root.create_symbol(ident_ids, "masked"); + let and_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::And, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([addr_sym, mask_sym]), + }); + let and_stmt = |next| Stmt::Let(masked_sym, and_expr, root.layout_isize, next); + // Pointer size constant let ptr_size_sym = root.create_symbol(ident_ids, "ptr_size"); let ptr_size_expr = Expr::Literal(Literal::Int(root.ptr_size as i128)); @@ -210,38 +226,67 @@ pub fn rc_ptr_from_data_ptr<'a>( // Refcount address let rc_addr_sym = root.create_symbol(ident_ids, "rc_addr"); - let rc_addr_expr = Expr::Call(Call { + let sub_expr = Expr::Call(Call { call_type: CallType::LowLevel { op: LowLevel::NumSub, update_mode: UpdateModeId::BACKEND_DUMMY, }, - arguments: root.arena.alloc([addr_sym, ptr_size_sym]), + arguments: root.arena.alloc([ + if mask_lower_bits { + masked_sym + } else { + addr_sym + }, + ptr_size_sym, + ]), }); - let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, root.layout_isize, next); + let sub_stmt = |next| Stmt::Let(rc_addr_sym, sub_expr, root.layout_isize, next); // Typecast the refcount address from integer to pointer - let rc_ptr_expr = Expr::Call(Call { + let cast_expr = Expr::Call(Call { call_type: CallType::LowLevel { op: LowLevel::PtrCast, update_mode: UpdateModeId::BACKEND_DUMMY, }, arguments: root.arena.alloc([rc_addr_sym]), }); - let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next); + let cast_stmt = |next| Stmt::Let(rc_ptr_sym, cast_expr, LAYOUT_PTR, next); - addr_stmt(root.arena.alloc( - // - ptr_size_stmt(root.arena.alloc( + if mask_lower_bits { + addr_stmt(root.arena.alloc( // - rc_addr_stmt(root.arena.alloc( + mask_stmt(root.arena.alloc( // - rc_ptr_stmt(root.arena.alloc( + and_stmt(root.arena.alloc( // - following, + ptr_size_stmt(root.arena.alloc( + // + sub_stmt(root.arena.alloc( + // + cast_stmt(root.arena.alloc( + // + following, + )), + )), + )), )), )), - )), - )) + )) + } else { + addr_stmt(root.arena.alloc( + // + ptr_size_stmt(root.arena.alloc( + // + sub_stmt(root.arena.alloc( + // + cast_stmt(root.arena.alloc( + // + following, + )), + )), + )), + )) + } } fn modify_refcount<'a>( @@ -356,6 +401,7 @@ fn refcount_str<'a>( ident_ids, elements, rc_ptr, + false, root.arena.alloc( // mod_rc_stmt, @@ -467,7 +513,14 @@ fn refcount_list<'a>( let modify_list_and_elems = elements_stmt(arena.alloc( // - rc_ptr_from_data_ptr(root, ident_ids, elements, rc_ptr, arena.alloc(modify_list)), + rc_ptr_from_data_ptr( + root, + ident_ids, + elements, + rc_ptr, + false, + arena.alloc(modify_list), + ), )); // @@ -753,15 +806,6 @@ fn refcount_tag_union_help<'a>( structure: Symbol, ) -> Stmt<'a> { let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_)); - /* TODO: tail recursion - let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop")); - let structure = if is_non_recursive { - initial_structure - } else { - // current value in the tail-recursive loop - root.create_symbol(ident_ids, "structure") - }; - */ let tag_id_layout = union_layout.tag_id_layout(); let tag_id_sym = root.create_symbol(ident_ids, "tag_id"); @@ -836,6 +880,7 @@ fn refcount_tag_union_help<'a>( ident_ids, structure, rc_ptr, + union_layout.stores_tag_id_in_pointer(root.ptr_size), root.arena.alloc(modify_structure_stmt), ); diff --git a/compiler/test_gen/src/gen_refcount.rs b/compiler/test_gen/src/gen_refcount.rs index 2c82756e8c..490c1b7cb0 100644 --- a/compiler/test_gen/src/gen_refcount.rs +++ b/compiler/test_gen/src/gen_refcount.rs @@ -206,21 +206,54 @@ fn union_recursive_inc() { r#" Expr : [ Sym Str, Add Expr Expr ] - s = Str.concat "heap_allocated_" "symbol_name" + s = Str.concat "heap_allocated" "_symbol_name" + + x : Expr + x = Sym s e : Expr - e = Add (Sym s) (Sym s) + e = Add x x [e, e] "# ), - RocStr, + // test_wrapper receives a List, doesn't matter kind of elements it points to + RocList, &[ 4, // s - 1, // Sym - 1, // Sym - 2, // Add + 4, // sym + 2, // e 1 // list ] ); } + +#[test] +#[cfg(any(feature = "gen-wasm"))] +fn union_recursive_dec() { + assert_refcounts!( + indoc!( + r#" + Expr : [ Sym Str, Add Expr Expr ] + + s = Str.concat "heap_allocated" "_symbol_name" + + x : Expr + x = Sym s + + e : Expr + e = Add x x + + when e is + Add y _ -> y + Sym _ -> e + "# + ), + &RocStr, + &[ + 1, // s + 1, // sym + 0 // e + ] + ); +}