diff --git a/compiler/mono/src/code_gen_help.rs b/compiler/mono/src/code_gen_help.rs deleted file mode 100644 index 8dd4addbf5..0000000000 --- a/compiler/mono/src/code_gen_help.rs +++ /dev/null @@ -1,1430 +0,0 @@ -use bumpalo::collections::vec::Vec; -use bumpalo::Bump; -use roc_builtins::bitcode::IntWidth; -use roc_module::ident::Ident; -use roc_module::low_level::LowLevel; -use roc_module::symbol::{IdentIds, ModuleId, Symbol}; - -use crate::ir::{ - BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, JoinPointId, Literal, - ModifyRc, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId, -}; -use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout}; - -const LAYOUT_BOOL: Layout = Layout::Builtin(Builtin::Bool); -const LAYOUT_UNIT: Layout = Layout::Struct(&[]); -const LAYOUT_PTR: Layout = Layout::RecursivePointer; -const LAYOUT_U32: Layout = Layout::Builtin(Builtin::Int(IntWidth::U32)); - -const ARG_1: Symbol = Symbol::ARG_1; -const ARG_2: Symbol = Symbol::ARG_2; - -/// "Infinite" reference count, for static values -/// Ref counts are encoded as negative numbers where isize::MIN represents 1 -pub const REFCOUNT_MAX: usize = 0; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum HelperOp { - Inc, - Dec, - DecRef, - Eq, -} - -impl From<&ModifyRc> for HelperOp { - fn from(modify: &ModifyRc) -> Self { - match modify { - ModifyRc::Inc(..) => Self::Inc, - ModifyRc::Dec(_) => Self::Dec, - ModifyRc::DecRef(_) => Self::DecRef, - } - } -} - -#[derive(Debug)] -struct Specialization<'a> { - op: HelperOp, - layout: Layout<'a>, - symbol: Symbol, - proc: Option>, -} - -#[derive(Debug)] -struct Context<'a> { - new_linker_data: Vec<'a, (Symbol, ProcLayout<'a>)>, - recursive_union: Option>, - op: HelperOp, -} - -/// Generate specialized helper procs for code gen -/// ---------------------------------------------- -/// -/// Some low level operations need specialized helper procs to traverse data structures at runtime. -/// This includes refcounting, hashing, and equality checks. -/// -/// For example, when checking List equality, we need to visit each element and compare them. -/// Depending on the type of the list elements, we may need to recurse deeper into each element. -/// For tag unions, we may need branches for different tag IDs, etc. -/// -/// This module creates specialized helper procs for all such operations and types used in the program. -/// -/// The backend drives the process, in two steps: -/// 1) When it sees the relevant node, it calls CodeGenHelp to get the replacement IR. -/// CodeGenHelp returns IR for a call to the helper proc, and remembers the specialization. -/// 2) After the backend has generated code for all user procs, it takes the IR for all of the -/// specialized helpers procs, and generates target code for them too. -/// -pub struct CodeGenHelp<'a> { - arena: &'a Bump, - home: ModuleId, - ptr_size: u32, - layout_isize: Layout<'a>, - specializations: Vec<'a, Specialization<'a>>, - debug_recursion_depth: usize, -} - -impl<'a> CodeGenHelp<'a> { - pub fn new(arena: &'a Bump, intwidth_isize: IntWidth, home: ModuleId) -> Self { - CodeGenHelp { - arena, - home, - ptr_size: intwidth_isize.stack_size(), - layout_isize: Layout::Builtin(Builtin::Int(intwidth_isize)), - specializations: Vec::with_capacity_in(16, arena), - debug_recursion_depth: 0, - } - } - - pub fn take_procs(&mut self) -> Vec<'a, Proc<'a>> { - let procs_iter = self - .specializations - .drain(0..) - .map(|spec| spec.proc.unwrap()); - Vec::from_iter_in(procs_iter, self.arena) - } - - // ============================================================================ - // - // CALL GENERATED PROCS - // - // ============================================================================ - - /// Expand a `Refcounting` node to a `Let` node that calls a specialized helper proc. - /// The helper procs themselves are to be generated later with `generate_procs` - pub fn expand_refcount_stmt( - &mut self, - ident_ids: &mut IdentIds, - layout: Layout<'a>, - modify: &ModifyRc, - following: &'a Stmt<'a>, - ) -> (&'a Stmt<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) { - if !Self::is_rc_implemented_yet(&layout) { - // Just a warning, so we can decouple backend development from refcounting development. - // When we are closer to completion, we can change it to a panic. - println!( - "WARNING! MEMORY LEAK! Refcounting not yet implemented for Layout {:?}", - layout - ); - return (following, Vec::new_in(self.arena)); - } - - let arena = self.arena; - - let mut ctx = Context { - new_linker_data: Vec::new_in(self.arena), - recursive_union: None, - op: HelperOp::from(modify), - }; - - match modify { - ModifyRc::Inc(structure, amount) => { - let layout_isize = self.layout_isize; - - // Define a constant for the amount to increment - let amount_sym = self.create_symbol(ident_ids, "amount"); - let amount_expr = Expr::Literal(Literal::Int(*amount as i128)); - let amount_stmt = |next| Stmt::Let(amount_sym, amount_expr, layout_isize, next); - - // Call helper proc, passing the Roc structure and constant amount - let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); - let call_expr = self.call_specialized_op( - ident_ids, - &mut ctx, - layout, - arena.alloc([*structure, amount_sym]), - ); - let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following); - let rc_stmt = arena.alloc(amount_stmt(arena.alloc(call_stmt))); - - (rc_stmt, ctx.new_linker_data) - } - - ModifyRc::Dec(structure) => { - // Call helper proc, passing the Roc structure - let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); - let call_expr = self.call_specialized_op( - ident_ids, - &mut ctx, - layout, - arena.alloc([*structure]), - ); - - let rc_stmt = arena.alloc(Stmt::Let( - call_result_empty, - call_expr, - LAYOUT_UNIT, - following, - )); - - (rc_stmt, ctx.new_linker_data) - } - - ModifyRc::DecRef(structure) => { - // No generated procs for DecRef, just lowlevel ops - let rc_ptr_sym = self.create_symbol(ident_ids, "rc_ptr"); - - // Pass the refcount pointer to the lowlevel call (see utils.zig) - let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); - let call_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::RefCountDec, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: arena.alloc([rc_ptr_sym]), - }); - let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following); - - let rc_stmt = arena.alloc(self.rc_ptr_from_struct( - ident_ids, - *structure, - rc_ptr_sym, - arena.alloc(call_stmt), - )); - - (rc_stmt, ctx.new_linker_data) - } - } - } - - // Check if refcounting is implemented yet. In the long term, this will be deleted. - // In the short term, it helps us to skip refcounting and let it leak, so we can make - // progress incrementally. Kept in sync with generate_procs using assertions. - fn is_rc_implemented_yet(layout: &Layout) -> bool { - matches!(layout, Layout::Builtin(Builtin::Str)) - } - - /// Replace a generic `Lowlevel::Eq` call with a specialized helper proc. - /// The helper procs themselves are to be generated later with `generate_procs` - pub fn call_specialized_equals( - &mut self, - ident_ids: &mut IdentIds, - layout: &Layout<'a>, - arguments: &'a [Symbol], - ) -> (Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) { - let mut ctx = Context { - new_linker_data: Vec::new_in(self.arena), - recursive_union: None, - op: HelperOp::Eq, - }; - - let expr = self.call_specialized_op(ident_ids, &mut ctx, *layout, arguments); - - (expr, ctx.new_linker_data) - } - - // ============================================================================ - // - // CALL SPECIALIZED OP - // - // ============================================================================ - - fn call_specialized_op( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - called_layout: Layout<'a>, - arguments: &[Symbol], - ) -> Expr<'a> { - use HelperOp::*; - - debug_assert!(self.debug_recursion_depth < 10); - self.debug_recursion_depth += 1; - - let layout = if matches!(called_layout, Layout::RecursivePointer) { - let union_layout = ctx.recursive_union.unwrap(); - Layout::Union(union_layout) - } else { - called_layout - }; - - if layout_needs_helper_proc(&layout, ctx.op) { - let proc_name = self.find_or_create_proc(ident_ids, ctx, layout); - - let (ret_layout, arg_layouts): (&'a Layout<'a>, &'a [Layout<'a>]) = { - match ctx.op { - Dec | DecRef => (&LAYOUT_UNIT, self.arena.alloc([layout])), - Inc => (&LAYOUT_UNIT, self.arena.alloc([layout, self.layout_isize])), - Eq => (&LAYOUT_BOOL, self.arena.alloc([layout, layout])), - } - }; - - Expr::Call(Call { - call_type: CallType::ByName { - name: proc_name, - ret_layout, - arg_layouts, - specialization_id: CallSpecId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc_slice_copy(arguments), - }) - } else { - Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::Eq, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc_slice_copy(arguments), - }) - } - } - - fn find_or_create_proc( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - layout: Layout<'a>, - ) -> Symbol { - use HelperOp::*; - - let found = self - .specializations - .iter() - .find(|spec| spec.op == ctx.op && spec.layout == layout); - - if let Some(spec) = found { - return spec.symbol; - } - - // Procs can be recursive, so we need to create the symbol before the body is complete - // But with nested recursion, that means Symbols and Procs can end up in different orders. - // We want the same order, especially for function indices in Wasm. So create an empty slot and fill it in later. - let (proc_symbol, proc_layout) = self.create_proc_symbol(ident_ids, ctx, &layout); - ctx.new_linker_data.push((proc_symbol, proc_layout)); - let spec_index = self.specializations.len(); - self.specializations.push(Specialization { - op: ctx.op, - layout, - symbol: proc_symbol, - proc: None, - }); - - // Recursively generate the body of the Proc and sub-procs - let (ret_layout, body) = match ctx.op { - Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, ctx, layout)), - Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, ctx, layout)), - }; - - let args: &'a [(Layout<'a>, Symbol)] = { - let roc_value = (layout, ARG_1); - match ctx.op { - Inc => { - let inc_amount = (self.layout_isize, ARG_2); - self.arena.alloc([roc_value, inc_amount]) - } - Dec | DecRef => self.arena.alloc([roc_value]), - Eq => self.arena.alloc([roc_value, (layout, ARG_2)]), - } - }; - - self.specializations[spec_index].proc = Some(Proc { - name: proc_symbol, - args, - body, - closure_data_layout: None, - ret_layout, - is_self_recursive: SelfRecursive::NotSelfRecursive, - must_own_arguments: false, - host_exposed_layouts: HostExposedLayouts::NotHostExposed, - }); - - proc_symbol - } - - fn create_proc_symbol( - &self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - layout: &Layout<'a>, - ) -> (Symbol, ProcLayout<'a>) { - let debug_name = format!( - "#help{}_{:?}_{:?}", - self.specializations.len(), - ctx.op, - layout - ) - .replace("Builtin", ""); - let proc_symbol: Symbol = self.create_symbol(ident_ids, &debug_name); - - let proc_layout = match ctx.op { - HelperOp::Inc => ProcLayout { - arguments: self.arena.alloc([*layout, self.layout_isize]), - result: LAYOUT_UNIT, - }, - HelperOp::Dec => ProcLayout { - arguments: self.arena.alloc([*layout]), - result: LAYOUT_UNIT, - }, - HelperOp::DecRef => unreachable!("No generated Proc for DecRef"), - HelperOp::Eq => ProcLayout { - arguments: self.arena.alloc([*layout, *layout]), - result: LAYOUT_BOOL, - }, - }; - - (proc_symbol, proc_layout) - } - - fn create_symbol(&self, ident_ids: &mut IdentIds, debug_name: &str) -> Symbol { - let ident_id = ident_ids.add(Ident::from(debug_name)); - Symbol::new(self.home, ident_id) - } - - // ============================================================================ - // - // GENERATE REFCOUNTING - // - // ============================================================================ - - fn refcount_generic( - &self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - layout: Layout<'a>, - ) -> Stmt<'a> { - debug_assert!(Self::is_rc_implemented_yet(&layout)); - let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); - - match layout { - Layout::Builtin( - Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, - ) => unreachable!("Not refcounted: {:?}", layout), - Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, ctx), - Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(), - Layout::Struct(_) => rc_todo(), - Layout::Union(_) => rc_todo(), - Layout::LambdaSet(_) => { - unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.") - } - Layout::RecursivePointer => rc_todo(), - } - } - - fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> { - let unit = self.create_symbol(ident_ids, "unit"); - let ret_stmt = self.arena.alloc(Stmt::Ret(unit)); - Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt) - } - - // Subtract a constant from a pointer to find the refcount - // Also does some type casting, so that we have different Symbols and Layouts - // for the 'pointer' and 'integer' versions of the address. - // This helps to avoid issues with the backends Symbol->Layout mapping. - fn rc_ptr_from_struct( - &self, - ident_ids: &mut IdentIds, - structure: Symbol, - rc_ptr_sym: Symbol, - following: &'a Stmt<'a>, - ) -> Stmt<'a> { - // Typecast the structure pointer to an integer - // Backends expect a number Layout to choose the right "subtract" instruction - let addr_sym = self.create_symbol(ident_ids, "addr"); - let addr_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::PtrCast, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([structure]), - }); - let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, self.layout_isize, next); - - // Pointer size constant - let ptr_size_sym = self.create_symbol(ident_ids, "ptr_size"); - let ptr_size_expr = Expr::Literal(Literal::Int(self.ptr_size as i128)); - let ptr_size_stmt = |next| Stmt::Let(ptr_size_sym, ptr_size_expr, self.layout_isize, next); - - // Refcount address - let rc_addr_sym = self.create_symbol(ident_ids, "rc_addr"); - let rc_addr_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::NumSub, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([structure, ptr_size_sym]), - }); - let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, self.layout_isize, next); - - // Typecast the refcount address from integer to pointer - let rc_ptr_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::PtrCast, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([rc_addr_sym]), - }); - let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next); - - addr_stmt(self.arena.alloc( - // - ptr_size_stmt(self.arena.alloc( - // - rc_addr_stmt(self.arena.alloc( - // - rc_ptr_stmt(self.arena.alloc( - // - following, - )), - )), - )), - )) - } - - /// Generate a procedure to modify the reference count of a Str - fn refcount_str(&self, ident_ids: &mut IdentIds, ctx: &mut Context<'a>) -> Stmt<'a> { - let op = ctx.op; - - let string = ARG_1; - let layout_isize = self.layout_isize; - - // Get the string length as a signed int - let len = self.create_symbol(ident_ids, "len"); - let len_expr = Expr::StructAtIndex { - index: 1, - field_layouts: self.arena.alloc([LAYOUT_PTR, layout_isize]), - structure: string, - }; - let len_stmt = |next| Stmt::Let(len, len_expr, layout_isize, next); - - // Zero - let zero = self.create_symbol(ident_ids, "zero"); - let zero_expr = Expr::Literal(Literal::Int(0)); - let zero_stmt = |next| Stmt::Let(zero, zero_expr, layout_isize, next); - - // is_big_str = (len >= 0); - // Treat len as isize so that the small string flag is the same as the sign bit - let is_big_str = self.create_symbol(ident_ids, "is_big_str"); - let is_big_str_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::NumGte, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([len, zero]), - }); - let is_big_str_stmt = |next| Stmt::Let(is_big_str, is_big_str_expr, LAYOUT_BOOL, next); - - // Get the pointer to the string elements - let elements = self.create_symbol(ident_ids, "elements"); - let elements_expr = Expr::StructAtIndex { - index: 0, - field_layouts: self.arena.alloc([LAYOUT_PTR, layout_isize]), - structure: string, - }; - let elements_stmt = |next| Stmt::Let(elements, elements_expr, layout_isize, next); - - // A pointer to the refcount value itself - let rc_ptr = self.create_symbol(ident_ids, "rc_ptr"); - - // Alignment constant (same value as ptr_size but different layout) - let alignment = self.create_symbol(ident_ids, "alignment"); - let alignment_expr = Expr::Literal(Literal::Int(self.ptr_size as i128)); - let alignment_stmt = |next| Stmt::Let(alignment, alignment_expr, LAYOUT_U32, next); - - // Call the relevant Zig lowlevel to actually modify the refcount - let zig_call_result = self.create_symbol(ident_ids, "zig_call_result"); - let zig_call_expr = match op { - HelperOp::Inc => Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::RefCountInc, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([rc_ptr, ARG_2]), - }), - HelperOp::Dec | HelperOp::DecRef => Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::RefCountDec, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([rc_ptr, alignment]), - }), - _ => unreachable!(), - }; - let zig_call_stmt = |next| Stmt::Let(zig_call_result, zig_call_expr, LAYOUT_UNIT, next); - - // Generate an `if` to skip small strings but modify big strings - let then_branch = elements_stmt(self.arena.alloc( - // - self.rc_ptr_from_struct( - ident_ids, - elements, - rc_ptr, - self.arena.alloc( - // - alignment_stmt(self.arena.alloc( - // - zig_call_stmt(self.arena.alloc( - // - Stmt::Ret(zig_call_result), - )), - )), - ), - ), - )); - let if_stmt = Stmt::Switch { - cond_symbol: is_big_str, - cond_layout: LAYOUT_BOOL, - branches: self.arena.alloc([(1, BranchInfo::None, then_branch)]), - default_branch: ( - BranchInfo::None, - self.arena.alloc(self.return_unit(ident_ids)), - ), - ret_layout: LAYOUT_UNIT, - }; - - // Combine the statements in sequence - len_stmt(self.arena.alloc( - // - zero_stmt(self.arena.alloc( - // - is_big_str_stmt(self.arena.alloc( - // - if_stmt, - )), - )), - )) - } - - // ============================================================================ - // - // GENERATE EQUALS - // - // ============================================================================ - - fn eq_generic( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - layout: Layout<'a>, - ) -> Stmt<'a> { - let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); - - let main_body = match layout { - Layout::Builtin( - Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, - ) => unreachable!( - "No generated proc for `==`. Use direct code gen for {:?}", - layout - ), - Layout::Builtin(Builtin::Str) => { - unreachable!("No generated helper proc for `==` on Str. Use Zig function.") - } - Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(), - Layout::Builtin(Builtin::List(elem_layout)) => { - self.eq_list(ident_ids, ctx, elem_layout) - } - Layout::Struct(field_layouts) => self.eq_struct(ident_ids, ctx, field_layouts), - Layout::Union(union_layout) => self.eq_tag_union(ident_ids, ctx, union_layout), - Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), - Layout::RecursivePointer => { - unreachable!("Can't perform `==` on RecursivePointer. Should have been replaced by a tag union.") - } - }; - - Stmt::Let( - Symbol::BOOL_TRUE, - Expr::Literal(Literal::Int(1)), - LAYOUT_BOOL, - self.arena.alloc(Stmt::Let( - Symbol::BOOL_FALSE, - Expr::Literal(Literal::Int(0)), - LAYOUT_BOOL, - self.arena.alloc(main_body), - )), - ) - } - - fn if_pointers_equal_return_true( - &self, - ident_ids: &mut IdentIds, - operands: [Symbol; 2], - following: &'a Stmt<'a>, - ) -> Stmt<'a> { - let ptr1_addr = self.create_symbol(ident_ids, "addr1"); - let ptr2_addr = self.create_symbol(ident_ids, "addr2"); - let ptr_eq = self.create_symbol(ident_ids, "eq_addr"); - - Stmt::Let( - ptr1_addr, - Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::PtrCast, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([operands[0]]), - }), - self.layout_isize, - self.arena.alloc(Stmt::Let( - ptr2_addr, - Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::PtrCast, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([operands[1]]), - }), - self.layout_isize, - self.arena.alloc(Stmt::Let( - ptr_eq, - Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::Eq, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([ptr1_addr, ptr2_addr]), - }), - LAYOUT_BOOL, - self.arena.alloc(Stmt::Switch { - cond_symbol: ptr_eq, - cond_layout: LAYOUT_BOOL, - branches: self.arena.alloc([( - 1, - BranchInfo::None, - Stmt::Ret(Symbol::BOOL_TRUE), - )]), - default_branch: (BranchInfo::None, following), - ret_layout: LAYOUT_BOOL, - }), - )), - )), - ) - } - - fn if_false_return_false(&self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> { - Stmt::Switch { - cond_symbol: symbol, - cond_layout: LAYOUT_BOOL, - branches: self - .arena - .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]), - default_branch: (BranchInfo::None, following), - ret_layout: LAYOUT_BOOL, - } - } - - fn eq_struct( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - field_layouts: &'a [Layout<'a>], - ) -> Stmt<'a> { - let mut else_stmt = Stmt::Ret(Symbol::BOOL_TRUE); - for (i, layout) in field_layouts.iter().enumerate().rev() { - let field1_sym = self.create_symbol(ident_ids, &format!("field_1_{}", i)); - let field1_expr = Expr::StructAtIndex { - index: i as u64, - field_layouts, - structure: ARG_1, - }; - let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next); - - let field2_sym = self.create_symbol(ident_ids, &format!("field_2_{}", i)); - let field2_expr = Expr::StructAtIndex { - index: i as u64, - field_layouts, - structure: ARG_2, - }; - let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next); - - let eq_call_expr = self.call_specialized_op( - ident_ids, - ctx, - *layout, - self.arena.alloc([field1_sym, field2_sym]), - ); - - let eq_call_name = format!("eq_call_{}", i); - let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name); - let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next); - - else_stmt = field1_stmt(self.arena.alloc( - // - field2_stmt(self.arena.alloc( - // - eq_call_stmt(self.arena.alloc( - // - self.if_false_return_false(eq_call_sym, self.arena.alloc(else_stmt)), - )), - )), - )) - } - - self.if_pointers_equal_return_true(ident_ids, [ARG_1, ARG_2], self.arena.alloc(else_stmt)) - } - - fn eq_tag_union( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - union_layout: UnionLayout<'a>, - ) -> Stmt<'a> { - use UnionLayout::*; - - let parent_rec_ptr_layout = ctx.recursive_union; - if !matches!(union_layout, NonRecursive(_)) { - ctx.recursive_union = Some(union_layout); - } - - let body = match union_layout { - NonRecursive(tags) => self.eq_tag_union_help(ident_ids, ctx, union_layout, tags, None), - - Recursive(tags) => self.eq_tag_union_help(ident_ids, ctx, union_layout, tags, None), - - NonNullableUnwrapped(field_layouts) => { - let tags = self.arena.alloc([field_layouts]); - self.eq_tag_union_help(ident_ids, ctx, union_layout, tags, None) - } - - NullableWrapped { - other_tags, - nullable_id, - } => { - self.eq_tag_union_help(ident_ids, ctx, union_layout, other_tags, Some(nullable_id)) - } - - NullableUnwrapped { - other_fields, - nullable_id, - } => self.eq_tag_union_help( - ident_ids, - ctx, - union_layout, - self.arena.alloc([other_fields]), - Some(nullable_id as TagIdIntType), - ), - }; - - ctx.recursive_union = parent_rec_ptr_layout; - - body - } - - fn eq_tag_union_help( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - union_layout: UnionLayout<'a>, - tag_layouts: &'a [&'a [Layout<'a>]], - nullable_id: Option, - ) -> Stmt<'a> { - let tailrec_loop = JoinPointId(self.create_symbol(ident_ids, "tailrec_loop")); - let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_)); - let operands = if is_non_recursive { - [ARG_1, ARG_2] - } else { - [ - self.create_symbol(ident_ids, "a"), - self.create_symbol(ident_ids, "b"), - ] - }; - - let tag_id_layout = union_layout.tag_id_layout(); - - let tag_id_a = self.create_symbol(ident_ids, "tag_id_a"); - let tag_id_a_stmt = |next| { - Stmt::Let( - tag_id_a, - Expr::GetTagId { - structure: operands[0], - union_layout, - }, - tag_id_layout, - next, - ) - }; - - let tag_id_b = self.create_symbol(ident_ids, "tag_id_b"); - let tag_id_b_stmt = |next| { - Stmt::Let( - tag_id_b, - Expr::GetTagId { - structure: operands[1], - union_layout, - }, - tag_id_layout, - next, - ) - }; - - let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq"); - let tag_ids_expr = Expr::Call(Call { - call_type: CallType::LowLevel { - op: LowLevel::Eq, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: self.arena.alloc([tag_id_a, tag_id_b]), - }); - let tag_ids_eq_stmt = |next| Stmt::Let(tag_ids_eq, tag_ids_expr, LAYOUT_BOOL, next); - - let if_equal_ids_branches = - self.arena - .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]); - - // - // Switch statement by tag ID - // - - let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), self.arena); - - // If there's a null tag, check it first. We might not need to load any data from memory. - if let Some(id) = nullable_id { - tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))) - } - - let mut tag_id: TagIdIntType = 0; - for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) { - if let Some(null_id) = nullable_id { - if tag_id == null_id as TagIdIntType { - tag_id += 1; - } - } - - let tag_stmt = self.eq_tag_fields( - ident_ids, - ctx, - tailrec_loop, - union_layout, - field_layouts, - operands, - tag_id, - ); - tag_branches.push((tag_id as u64, BranchInfo::None, tag_stmt)); - - tag_id += 1; - } - - let tag_switch_stmt = Stmt::Switch { - cond_symbol: tag_id_a, - cond_layout: tag_id_layout, - branches: tag_branches.into_bump_slice(), - default_branch: ( - BranchInfo::None, - self.arena.alloc(self.eq_tag_fields( - ident_ids, - ctx, - tailrec_loop, - union_layout, - tag_layouts.last().unwrap(), - operands, - tag_id, - )), - ), - ret_layout: LAYOUT_BOOL, - }; - - let if_equal_ids_stmt = Stmt::Switch { - cond_symbol: tag_ids_eq, - cond_layout: LAYOUT_BOOL, - branches: if_equal_ids_branches, - default_branch: (BranchInfo::None, self.arena.alloc(tag_switch_stmt)), - ret_layout: LAYOUT_BOOL, - }; - - // - // combine all the statments - // - let compare_values = tag_id_a_stmt(self.arena.alloc( - // - tag_id_b_stmt(self.arena.alloc( - // - tag_ids_eq_stmt(self.arena.alloc( - // - if_equal_ids_stmt, - )), - )), - )); - - let compare_ptr_or_value = self.if_pointers_equal_return_true( - ident_ids, - operands, - self.arena.alloc(compare_values), - ); - - if is_non_recursive { - compare_ptr_or_value - } else { - let loop_params_iter = operands.iter().map(|arg| Param { - symbol: *arg, - borrow: true, - layout: Layout::Union(union_layout), - }); - - let loop_start = Stmt::Jump(tailrec_loop, self.arena.alloc([ARG_1, ARG_2])); - - Stmt::Join { - id: tailrec_loop, - parameters: self.arena.alloc_slice_fill_iter(loop_params_iter), - body: self.arena.alloc(compare_ptr_or_value), - remainder: self.arena.alloc(loop_start), - } - } - } - - #[allow(clippy::too_many_arguments)] - fn eq_tag_fields( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - tailrec_loop: JoinPointId, - union_layout: UnionLayout<'a>, - field_layouts: &'a [Layout<'a>], - operands: [Symbol; 2], - tag_id: TagIdIntType, - ) -> Stmt<'a> { - // Find a RecursivePointer to use in the tail recursion loop - // (If there are more than one, the others will use non-tail recursion) - let rec_ptr_index = field_layouts - .iter() - .position(|field| matches!(field, Layout::RecursivePointer)); - - let (tailrec_index, innermost_stmt) = match rec_ptr_index { - None => { - // This tag has no RecursivePointers. Set tailrec_index out of range. - (field_layouts.len(), Stmt::Ret(Symbol::BOOL_TRUE)) - } - - Some(i) => { - // Implement tail recursion on this RecursivePointer, - // in the innermost `else` clause after all other fields have been checked - let field1_sym = - self.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i)); - let field2_sym = - self.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i)); - - let field1_expr = Expr::UnionAtIndex { - union_layout, - tag_id, - index: i as u64, - structure: operands[0], - }; - - let field2_expr = Expr::UnionAtIndex { - union_layout, - tag_id, - index: i as u64, - structure: operands[1], - }; - - let inner = Stmt::Let( - field1_sym, - field1_expr, - field_layouts[i], - self.arena.alloc( - // - Stmt::Let( - field2_sym, - field2_expr, - field_layouts[i], - self.arena.alloc( - // - Stmt::Jump( - tailrec_loop, - self.arena.alloc([field1_sym, field2_sym]), - ), - ), - ), - ), - ); - - (i, inner) - } - }; - - let mut stmt = innermost_stmt; - for (i, layout) in field_layouts.iter().enumerate().rev() { - if i == tailrec_index { - continue; // the tail-recursive field is handled elsewhere - } - - let field1_sym = self.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i)); - let field2_sym = self.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i)); - - let field1_expr = Expr::UnionAtIndex { - union_layout, - tag_id, - index: i as u64, - structure: operands[0], - }; - - let field2_expr = Expr::UnionAtIndex { - union_layout, - tag_id, - index: i as u64, - structure: operands[1], - }; - - let eq_call_expr = self.call_specialized_op( - ident_ids, - ctx, - *layout, - self.arena.alloc([field1_sym, field2_sym]), - ); - - let eq_call_name = format!("eq_call_{}", i); - let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name); - - stmt = Stmt::Let( - field1_sym, - field1_expr, - field_layouts[i], - self.arena.alloc( - // - Stmt::Let( - field2_sym, - field2_expr, - field_layouts[i], - self.arena.alloc( - // - Stmt::Let( - eq_call_sym, - eq_call_expr, - LAYOUT_BOOL, - self.arena.alloc( - // - self.if_false_return_false( - eq_call_sym, - self.arena.alloc( - // - stmt, - ), - ), - ), - ), - ), - ), - ), - ) - } - stmt - } - - /// List equality - /// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that. - /// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`. - /// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout. - /// Then we can increment the Box pointer in a loop, dereferencing it each time. - /// (An alternative approach would be to create a new lowlevel like ListPeekUnsafe.) - fn eq_list( - &mut self, - ident_ids: &mut IdentIds, - ctx: &mut Context<'a>, - elem_layout: &Layout<'a>, - ) -> Stmt<'a> { - use LowLevel::*; - let layout_isize = self.layout_isize; - let arena = self.arena; - - // A "Box" layout (heap pointer to a single list element) - let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout])); - let box_layout = Layout::Union(box_union_layout); - - // Compare lengths - - let len_1 = self.create_symbol(ident_ids, "len_1"); - let len_2 = self.create_symbol(ident_ids, "len_2"); - let len_1_stmt = |next| let_lowlevel(arena, layout_isize, len_1, ListLen, &[ARG_1], next); - let len_2_stmt = |next| let_lowlevel(arena, layout_isize, len_2, ListLen, &[ARG_2], next); - - let eq_len = self.create_symbol(ident_ids, "eq_len"); - let eq_len_stmt = - |next| let_lowlevel(arena, LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next); - - // if lengths are equal... - - // get element pointers - let elements_1 = self.create_symbol(ident_ids, "elements_1"); - let elements_2 = self.create_symbol(ident_ids, "elements_2"); - let elements_1_expr = Expr::StructAtIndex { - index: 0, - field_layouts: self.arena.alloc([box_layout, layout_isize]), - structure: ARG_1, - }; - let elements_2_expr = Expr::StructAtIndex { - index: 0, - field_layouts: self.arena.alloc([box_layout, layout_isize]), - structure: ARG_2, - }; - let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next); - let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next); - - // Cast to integers - let start_1 = self.create_symbol(ident_ids, "start_1"); - let start_2 = self.create_symbol(ident_ids, "start_2"); - let start_1_stmt = - |next| let_lowlevel(arena, layout_isize, start_1, PtrCast, &[elements_1], next); - let start_2_stmt = - |next| let_lowlevel(arena, layout_isize, start_2, PtrCast, &[elements_2], next); - - // - // Loop initialisation - // - - // let size = literal int - let size = self.create_symbol(ident_ids, "size"); - let size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(self.ptr_size) as i128)); - let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next); - - // let list_size = len_1 * size - let list_size = self.create_symbol(ident_ids, "list_size"); - let list_size_stmt = - |next| let_lowlevel(arena, layout_isize, list_size, NumMul, &[len_1, size], next); - - // let end_1 = start_1 + list_size - let end_1 = self.create_symbol(ident_ids, "end_1"); - let end_1_stmt = |next| { - let_lowlevel( - arena, - layout_isize, - end_1, - NumAdd, - &[start_1, list_size], - next, - ) - }; - - // - // Loop name & parameters - // - - let elems_loop = JoinPointId(self.create_symbol(ident_ids, "elems_loop")); - let addr1 = self.create_symbol(ident_ids, "addr1"); - let addr2 = self.create_symbol(ident_ids, "addr2"); - - let param_addr1 = Param { - symbol: addr1, - borrow: false, - layout: layout_isize, - }; - - let param_addr2 = Param { - symbol: addr2, - borrow: false, - layout: layout_isize, - }; - - // - // if we haven't reached the end yet... - // - - // Cast integers to box pointers - let box1 = self.create_symbol(ident_ids, "box1"); - let box2 = self.create_symbol(ident_ids, "box2"); - let box1_stmt = |next| let_lowlevel(arena, box_layout, box1, PtrCast, &[addr1], next); - let box2_stmt = |next| let_lowlevel(arena, box_layout, box2, PtrCast, &[addr2], next); - - // Dereference the box pointers to get the current elements - let elem1 = self.create_symbol(ident_ids, "elem1"); - let elem2 = self.create_symbol(ident_ids, "elem2"); - let elem1_expr = Expr::UnionAtIndex { - structure: box1, - union_layout: box_union_layout, - tag_id: 0, - index: 0, - }; - let elem2_expr = Expr::UnionAtIndex { - structure: box2, - union_layout: box_union_layout, - tag_id: 0, - index: 0, - }; - let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next); - let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next); - - // Compare the two current elements - let eq_elems = self.create_symbol(ident_ids, "eq_elems"); - let eq_elems_expr = self.call_specialized_op(ident_ids, ctx, *elem_layout, &[elem1, elem2]); - - let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next); - - // If current elements are equal, loop back again - let next_1 = self.create_symbol(ident_ids, "next_1"); - let next_2 = self.create_symbol(ident_ids, "next_2"); - let next_1_stmt = - |next| let_lowlevel(arena, layout_isize, next_1, NumAdd, &[addr1, size], next); - let next_2_stmt = - |next| let_lowlevel(arena, layout_isize, next_2, NumAdd, &[addr2, size], next); - - let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_1, next_2])); - - // - // Control flow - // - - let is_end = self.create_symbol(ident_ids, "is_end"); - let is_end_stmt = - |next| let_lowlevel(arena, LAYOUT_BOOL, is_end, NumGte, &[addr1, end_1], next); - - let if_elems_not_equal = self.if_false_return_false( - eq_elems, - // else - self.arena.alloc( - // - next_1_stmt(self.arena.alloc( - // - next_2_stmt(self.arena.alloc( - // - jump_back, - )), - )), - ), - ); - - let if_end_of_list = Stmt::Switch { - cond_symbol: is_end, - cond_layout: LAYOUT_BOOL, - ret_layout: LAYOUT_BOOL, - branches: self - .arena - .alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]), - default_branch: ( - BranchInfo::None, - self.arena.alloc( - // - box1_stmt(self.arena.alloc( - // - box2_stmt(self.arena.alloc( - // - elem1_stmt(self.arena.alloc( - // - elem2_stmt(self.arena.alloc( - // - eq_elems_stmt(self.arena.alloc( - // - if_elems_not_equal, - )), - )), - )), - )), - )), - ), - ), - }; - - let joinpoint_loop = Stmt::Join { - id: elems_loop, - parameters: self.arena.alloc([param_addr1, param_addr2]), - body: self.arena.alloc( - // - is_end_stmt( - // - self.arena.alloc(if_end_of_list), - ), - ), - remainder: self - .arena - .alloc(Stmt::Jump(elems_loop, self.arena.alloc([start_1, start_2]))), - }; - - let if_different_lengths = self.if_false_return_false( - eq_len, - // else - self.arena.alloc( - // - elements_1_stmt(self.arena.alloc( - // - elements_2_stmt(self.arena.alloc( - // - start_1_stmt(self.arena.alloc( - // - start_2_stmt(self.arena.alloc( - // - size_stmt(self.arena.alloc( - // - list_size_stmt(self.arena.alloc( - // - end_1_stmt(self.arena.alloc( - // - joinpoint_loop, - )), - )), - )), - )), - )), - )), - )), - ), - ); - - let pointers_else = len_1_stmt(self.arena.alloc( - // - len_2_stmt(self.arena.alloc( - // - eq_len_stmt(self.arena.alloc( - // - if_different_lengths, - )), - )), - )); - - self.if_pointers_equal_return_true( - ident_ids, - [ARG_1, ARG_2], - self.arena.alloc(pointers_else), - ) - } -} - -fn let_lowlevel<'a>( - arena: &'a Bump, - result_layout: Layout<'a>, - result: Symbol, - op: LowLevel, - arguments: &[Symbol], - next: &'a Stmt<'a>, -) -> Stmt<'a> { - Stmt::Let( - result, - Expr::Call(Call { - call_type: CallType::LowLevel { - op, - update_mode: UpdateModeId::BACKEND_DUMMY, - }, - arguments: arena.alloc_slice_copy(arguments), - }), - result_layout, - next, - ) -} - -fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool { - match layout { - Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { - false - } - - Layout::Builtin(Builtin::Str) => { - // Str type can use either Zig functions or generated IR, since it's not generic. - // Eq uses a Zig function, refcount uses generated IR. - // Both are fine, they were just developed at different times. - matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef) - } - - Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => true, - - Layout::Struct(fields) => !fields.is_empty(), - - Layout::Union(UnionLayout::NonRecursive(tags)) => !tags.is_empty(), - - Layout::Union(_) => true, - - Layout::LambdaSet(_) | Layout::RecursivePointer => false, - } -} diff --git a/compiler/mono/src/code_gen_help/equality.rs b/compiler/mono/src/code_gen_help/equality.rs new file mode 100644 index 0000000000..16fe3b1ba2 --- /dev/null +++ b/compiler/mono/src/code_gen_help/equality.rs @@ -0,0 +1,792 @@ +use bumpalo::collections::vec::Vec; +use roc_module::low_level::LowLevel; +use roc_module::symbol::{IdentIds, Symbol}; + +use crate::ir::{ + BranchInfo, Call, CallType, Expr, JoinPointId, Literal, Param, Stmt, UpdateModeId, +}; +use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout}; + +use super::{let_lowlevel, CodeGenHelp, Context, LAYOUT_BOOL}; + +const ARG_1: Symbol = Symbol::ARG_1; +const ARG_2: Symbol = Symbol::ARG_2; + +pub fn eq_generic<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + layout: Layout<'a>, +) -> Stmt<'a> { + let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); + + let main_body = match layout { + Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { + unreachable!( + "No generated proc for `==`. Use direct code gen for {:?}", + layout + ) + } + Layout::Builtin(Builtin::Str) => { + unreachable!("No generated helper proc for `==` on Str. Use Zig function.") + } + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(), + Layout::Builtin(Builtin::List(elem_layout)) => eq_list(root, ident_ids, ctx, elem_layout), + Layout::Struct(field_layouts) => eq_struct(root, ident_ids, ctx, field_layouts), + Layout::Union(union_layout) => eq_tag_union(root, ident_ids, ctx, union_layout), + Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), + Layout::RecursivePointer => { + unreachable!( + "Can't perform `==` on RecursivePointer. Should have been replaced by a tag union." + ) + } + }; + + Stmt::Let( + Symbol::BOOL_TRUE, + Expr::Literal(Literal::Int(1)), + LAYOUT_BOOL, + root.arena.alloc(Stmt::Let( + Symbol::BOOL_FALSE, + Expr::Literal(Literal::Int(0)), + LAYOUT_BOOL, + root.arena.alloc(main_body), + )), + ) +} + +fn if_pointers_equal_return_true<'a>( + root: &CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + operands: [Symbol; 2], + following: &'a Stmt<'a>, +) -> Stmt<'a> { + let ptr1_addr = root.create_symbol(ident_ids, "addr1"); + let ptr2_addr = root.create_symbol(ident_ids, "addr2"); + let ptr_eq = root.create_symbol(ident_ids, "eq_addr"); + + Stmt::Let( + ptr1_addr, + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrCast, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([operands[0]]), + }), + root.layout_isize, + root.arena.alloc(Stmt::Let( + ptr2_addr, + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrCast, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([operands[1]]), + }), + root.layout_isize, + root.arena.alloc(Stmt::Let( + ptr_eq, + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::Eq, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([ptr1_addr, ptr2_addr]), + }), + LAYOUT_BOOL, + root.arena.alloc(Stmt::Switch { + cond_symbol: ptr_eq, + cond_layout: LAYOUT_BOOL, + branches: root.arena.alloc([( + 1, + BranchInfo::None, + Stmt::Ret(Symbol::BOOL_TRUE), + )]), + default_branch: (BranchInfo::None, following), + ret_layout: LAYOUT_BOOL, + }), + )), + )), + ) +} + +fn if_false_return_false<'a>( + root: &CodeGenHelp<'a>, + symbol: Symbol, + following: &'a Stmt<'a>, +) -> Stmt<'a> { + Stmt::Switch { + cond_symbol: symbol, + cond_layout: LAYOUT_BOOL, + branches: root + .arena + .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]), + default_branch: (BranchInfo::None, following), + ret_layout: LAYOUT_BOOL, + } +} + +fn eq_struct<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + field_layouts: &'a [Layout<'a>], +) -> Stmt<'a> { + let mut else_stmt = Stmt::Ret(Symbol::BOOL_TRUE); + for (i, layout) in field_layouts.iter().enumerate().rev() { + let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}", i)); + let field1_expr = Expr::StructAtIndex { + index: i as u64, + field_layouts, + structure: ARG_1, + }; + let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next); + + let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}", i)); + let field2_expr = Expr::StructAtIndex { + index: i as u64, + field_layouts, + structure: ARG_2, + }; + let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next); + + let eq_call_expr = root.call_specialized_op( + ident_ids, + ctx, + *layout, + root.arena.alloc([field1_sym, field2_sym]), + ); + + let eq_call_name = format!("eq_call_{}", i); + let eq_call_sym = root.create_symbol(ident_ids, &eq_call_name); + let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next); + + else_stmt = field1_stmt(root.arena.alloc( + // + field2_stmt(root.arena.alloc( + // + eq_call_stmt(root.arena.alloc( + // + if_false_return_false(root, eq_call_sym, root.arena.alloc(else_stmt)), + )), + )), + )) + } + + if_pointers_equal_return_true(root, ident_ids, [ARG_1, ARG_2], root.arena.alloc(else_stmt)) +} + +fn eq_tag_union<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, +) -> Stmt<'a> { + use UnionLayout::*; + + let parent_rec_ptr_layout = ctx.recursive_union; + if !matches!(union_layout, NonRecursive(_)) { + ctx.recursive_union = Some(union_layout); + } + + let body = match union_layout { + NonRecursive(tags) => eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None), + + Recursive(tags) => eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None), + + NonNullableUnwrapped(field_layouts) => { + let tags = root.arena.alloc([field_layouts]); + eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None) + } + + NullableWrapped { + other_tags, + nullable_id, + } => eq_tag_union_help( + root, + ident_ids, + ctx, + union_layout, + other_tags, + Some(nullable_id), + ), + + NullableUnwrapped { + other_fields, + nullable_id, + } => eq_tag_union_help( + root, + ident_ids, + ctx, + union_layout, + root.arena.alloc([other_fields]), + Some(nullable_id as TagIdIntType), + ), + }; + + ctx.recursive_union = parent_rec_ptr_layout; + + body +} + +fn eq_tag_union_help<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + nullable_id: Option, +) -> Stmt<'a> { + let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop")); + let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_)); + let operands = if is_non_recursive { + [ARG_1, ARG_2] + } else { + [ + root.create_symbol(ident_ids, "a"), + root.create_symbol(ident_ids, "b"), + ] + }; + + let tag_id_layout = union_layout.tag_id_layout(); + + let tag_id_a = root.create_symbol(ident_ids, "tag_id_a"); + let tag_id_a_stmt = |next| { + Stmt::Let( + tag_id_a, + Expr::GetTagId { + structure: operands[0], + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let tag_id_b = root.create_symbol(ident_ids, "tag_id_b"); + let tag_id_b_stmt = |next| { + Stmt::Let( + tag_id_b, + Expr::GetTagId { + structure: operands[1], + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let tag_ids_eq = root.create_symbol(ident_ids, "tag_ids_eq"); + let tag_ids_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::Eq, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([tag_id_a, tag_id_b]), + }); + let tag_ids_eq_stmt = |next| Stmt::Let(tag_ids_eq, tag_ids_expr, LAYOUT_BOOL, next); + + let if_equal_ids_branches = + root.arena + .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]); + + // + // Switch statement by tag ID + // + + let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), root.arena); + + // If there's a null tag, check it first. We might not need to load any data from memory. + if let Some(id) = nullable_id { + tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))) + } + + let mut tag_id: TagIdIntType = 0; + for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) { + if let Some(null_id) = nullable_id { + if tag_id == null_id as TagIdIntType { + tag_id += 1; + } + } + + let tag_stmt = eq_tag_fields( + root, + ident_ids, + ctx, + tailrec_loop, + union_layout, + field_layouts, + operands, + tag_id, + ); + tag_branches.push((tag_id as u64, BranchInfo::None, tag_stmt)); + + tag_id += 1; + } + + let tag_switch_stmt = Stmt::Switch { + cond_symbol: tag_id_a, + cond_layout: tag_id_layout, + branches: tag_branches.into_bump_slice(), + default_branch: ( + BranchInfo::None, + root.arena.alloc(eq_tag_fields( + root, + ident_ids, + ctx, + tailrec_loop, + union_layout, + tag_layouts.last().unwrap(), + operands, + tag_id, + )), + ), + ret_layout: LAYOUT_BOOL, + }; + + let if_equal_ids_stmt = Stmt::Switch { + cond_symbol: tag_ids_eq, + cond_layout: LAYOUT_BOOL, + branches: if_equal_ids_branches, + default_branch: (BranchInfo::None, root.arena.alloc(tag_switch_stmt)), + ret_layout: LAYOUT_BOOL, + }; + + // + // combine all the statments + // + let compare_values = tag_id_a_stmt(root.arena.alloc( + // + tag_id_b_stmt(root.arena.alloc( + // + tag_ids_eq_stmt(root.arena.alloc( + // + if_equal_ids_stmt, + )), + )), + )); + + let compare_ptr_or_value = + if_pointers_equal_return_true(root, ident_ids, operands, root.arena.alloc(compare_values)); + + if is_non_recursive { + compare_ptr_or_value + } else { + let loop_params_iter = operands.iter().map(|arg| Param { + symbol: *arg, + borrow: true, + layout: Layout::Union(union_layout), + }); + + let loop_start = Stmt::Jump(tailrec_loop, root.arena.alloc([ARG_1, ARG_2])); + + Stmt::Join { + id: tailrec_loop, + parameters: root.arena.alloc_slice_fill_iter(loop_params_iter), + body: root.arena.alloc(compare_ptr_or_value), + remainder: root.arena.alloc(loop_start), + } + } +} + +#[allow(clippy::too_many_arguments)] +fn eq_tag_fields<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + tailrec_loop: JoinPointId, + union_layout: UnionLayout<'a>, + field_layouts: &'a [Layout<'a>], + operands: [Symbol; 2], + tag_id: TagIdIntType, +) -> Stmt<'a> { + // Find a RecursivePointer to use in the tail recursion loop + // (If there are more than one, the others will use non-tail recursion) + let rec_ptr_index = field_layouts + .iter() + .position(|field| matches!(field, Layout::RecursivePointer)); + + let (tailrec_index, innermost_stmt) = match rec_ptr_index { + None => { + // This tag has no RecursivePointers. Set tailrec_index out of range. + (field_layouts.len(), Stmt::Ret(Symbol::BOOL_TRUE)) + } + + Some(i) => { + // Implement tail recursion on this RecursivePointer, + // in the innermost `else` clause after all other fields have been checked + let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i)); + let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i)); + + let field1_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure: operands[0], + }; + + let field2_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure: operands[1], + }; + + let inner = Stmt::Let( + field1_sym, + field1_expr, + field_layouts[i], + root.arena.alloc( + // + Stmt::Let( + field2_sym, + field2_expr, + field_layouts[i], + root.arena.alloc( + // + Stmt::Jump(tailrec_loop, root.arena.alloc([field1_sym, field2_sym])), + ), + ), + ), + ); + + (i, inner) + } + }; + + let mut stmt = innermost_stmt; + for (i, layout) in field_layouts.iter().enumerate().rev() { + if i == tailrec_index { + continue; // the tail-recursive field is handled elsewhere + } + + let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i)); + let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i)); + + let field1_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure: operands[0], + }; + + let field2_expr = Expr::UnionAtIndex { + union_layout, + tag_id, + index: i as u64, + structure: operands[1], + }; + + let eq_call_expr = root.call_specialized_op( + ident_ids, + ctx, + *layout, + root.arena.alloc([field1_sym, field2_sym]), + ); + + let eq_call_name = format!("eq_call_{}", i); + let eq_call_sym = root.create_symbol(ident_ids, &eq_call_name); + + stmt = Stmt::Let( + field1_sym, + field1_expr, + field_layouts[i], + root.arena.alloc( + // + Stmt::Let( + field2_sym, + field2_expr, + field_layouts[i], + root.arena.alloc( + // + Stmt::Let( + eq_call_sym, + eq_call_expr, + LAYOUT_BOOL, + root.arena.alloc( + // + if_false_return_false( + root, + eq_call_sym, + root.arena.alloc( + // + stmt, + ), + ), + ), + ), + ), + ), + ), + ) + } + stmt +} + +/// List equality +/// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that. +/// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`. +/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout. +/// Then we can increment the Box pointer in a loop, dereferencing it each time. +/// (An alternative approach would be to create a new lowlevel like ListPeekUnsafe.) +fn eq_list<'a>( + root: &mut CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + elem_layout: &Layout<'a>, +) -> Stmt<'a> { + use LowLevel::*; + let layout_isize = root.layout_isize; + let arena = root.arena; + + // A "Box" layout (heap pointer to a single list element) + let box_union_layout = UnionLayout::NonNullableUnwrapped(root.arena.alloc([*elem_layout])); + let box_layout = Layout::Union(box_union_layout); + + // Compare lengths + + let len_1 = root.create_symbol(ident_ids, "len_1"); + let len_2 = root.create_symbol(ident_ids, "len_2"); + let len_1_stmt = |next| let_lowlevel(arena, layout_isize, len_1, ListLen, &[ARG_1], next); + let len_2_stmt = |next| let_lowlevel(arena, layout_isize, len_2, ListLen, &[ARG_2], next); + + let eq_len = root.create_symbol(ident_ids, "eq_len"); + let eq_len_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next); + + // if lengths are equal... + + // get element pointers + let elements_1 = root.create_symbol(ident_ids, "elements_1"); + let elements_2 = root.create_symbol(ident_ids, "elements_2"); + let elements_1_expr = Expr::StructAtIndex { + index: 0, + field_layouts: root.arena.alloc([box_layout, layout_isize]), + structure: ARG_1, + }; + let elements_2_expr = Expr::StructAtIndex { + index: 0, + field_layouts: root.arena.alloc([box_layout, layout_isize]), + structure: ARG_2, + }; + let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next); + let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next); + + // Cast to integers + let start_1 = root.create_symbol(ident_ids, "start_1"); + let start_2 = root.create_symbol(ident_ids, "start_2"); + let start_1_stmt = + |next| let_lowlevel(arena, layout_isize, start_1, PtrCast, &[elements_1], next); + let start_2_stmt = + |next| let_lowlevel(arena, layout_isize, start_2, PtrCast, &[elements_2], next); + + // + // Loop initialisation + // + + // let size = literal int + let size = root.create_symbol(ident_ids, "size"); + let size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128)); + let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next); + + // let list_size = len_1 * size + let list_size = root.create_symbol(ident_ids, "list_size"); + let list_size_stmt = + |next| let_lowlevel(arena, layout_isize, list_size, NumMul, &[len_1, size], next); + + // let end_1 = start_1 + list_size + let end_1 = root.create_symbol(ident_ids, "end_1"); + let end_1_stmt = |next| { + let_lowlevel( + arena, + layout_isize, + end_1, + NumAdd, + &[start_1, list_size], + next, + ) + }; + + // + // Loop name & parameters + // + + let elems_loop = JoinPointId(root.create_symbol(ident_ids, "elems_loop")); + let addr1 = root.create_symbol(ident_ids, "addr1"); + let addr2 = root.create_symbol(ident_ids, "addr2"); + + let param_addr1 = Param { + symbol: addr1, + borrow: false, + layout: layout_isize, + }; + + let param_addr2 = Param { + symbol: addr2, + borrow: false, + layout: layout_isize, + }; + + // + // if we haven't reached the end yet... + // + + // Cast integers to box pointers + let box1 = root.create_symbol(ident_ids, "box1"); + let box2 = root.create_symbol(ident_ids, "box2"); + let box1_stmt = |next| let_lowlevel(arena, box_layout, box1, PtrCast, &[addr1], next); + let box2_stmt = |next| let_lowlevel(arena, box_layout, box2, PtrCast, &[addr2], next); + + // Dereference the box pointers to get the current elements + let elem1 = root.create_symbol(ident_ids, "elem1"); + let elem2 = root.create_symbol(ident_ids, "elem2"); + let elem1_expr = Expr::UnionAtIndex { + structure: box1, + union_layout: box_union_layout, + tag_id: 0, + index: 0, + }; + let elem2_expr = Expr::UnionAtIndex { + structure: box2, + union_layout: box_union_layout, + tag_id: 0, + index: 0, + }; + let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next); + let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next); + + // Compare the two current elements + let eq_elems = root.create_symbol(ident_ids, "eq_elems"); + let eq_elems_expr = root.call_specialized_op(ident_ids, ctx, *elem_layout, &[elem1, elem2]); + + let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next); + + // If current elements are equal, loop back again + let next_1 = root.create_symbol(ident_ids, "next_1"); + let next_2 = root.create_symbol(ident_ids, "next_2"); + let next_1_stmt = + |next| let_lowlevel(arena, layout_isize, next_1, NumAdd, &[addr1, size], next); + let next_2_stmt = + |next| let_lowlevel(arena, layout_isize, next_2, NumAdd, &[addr2, size], next); + + let jump_back = Stmt::Jump(elems_loop, root.arena.alloc([next_1, next_2])); + + // + // Control flow + // + + let is_end = root.create_symbol(ident_ids, "is_end"); + let is_end_stmt = + |next| let_lowlevel(arena, LAYOUT_BOOL, is_end, NumGte, &[addr1, end_1], next); + + let if_elems_not_equal = if_false_return_false( + root, + eq_elems, + // else + root.arena.alloc( + // + next_1_stmt(root.arena.alloc( + // + next_2_stmt(root.arena.alloc( + // + jump_back, + )), + )), + ), + ); + + let if_end_of_list = Stmt::Switch { + cond_symbol: is_end, + cond_layout: LAYOUT_BOOL, + ret_layout: LAYOUT_BOOL, + branches: root + .arena + .alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]), + default_branch: ( + BranchInfo::None, + root.arena.alloc( + // + box1_stmt(root.arena.alloc( + // + box2_stmt(root.arena.alloc( + // + elem1_stmt(root.arena.alloc( + // + elem2_stmt(root.arena.alloc( + // + eq_elems_stmt(root.arena.alloc( + // + if_elems_not_equal, + )), + )), + )), + )), + )), + ), + ), + }; + + let joinpoint_loop = Stmt::Join { + id: elems_loop, + parameters: root.arena.alloc([param_addr1, param_addr2]), + body: root.arena.alloc( + // + is_end_stmt( + // + root.arena.alloc(if_end_of_list), + ), + ), + remainder: root + .arena + .alloc(Stmt::Jump(elems_loop, root.arena.alloc([start_1, start_2]))), + }; + + let if_different_lengths = if_false_return_false( + root, + eq_len, + // else + root.arena.alloc( + // + elements_1_stmt(root.arena.alloc( + // + elements_2_stmt(root.arena.alloc( + // + start_1_stmt(root.arena.alloc( + // + start_2_stmt(root.arena.alloc( + // + size_stmt(root.arena.alloc( + // + list_size_stmt(root.arena.alloc( + // + end_1_stmt(root.arena.alloc( + // + joinpoint_loop, + )), + )), + )), + )), + )), + )), + )), + ), + ); + + let pointers_else = len_1_stmt(root.arena.alloc( + // + len_2_stmt(root.arena.alloc( + // + eq_len_stmt(root.arena.alloc( + // + if_different_lengths, + )), + )), + )); + + if_pointers_equal_return_true( + root, + ident_ids, + [ARG_1, ARG_2], + root.arena.alloc(pointers_else), + ) +} diff --git a/compiler/mono/src/code_gen_help/mod.rs b/compiler/mono/src/code_gen_help/mod.rs new file mode 100644 index 0000000000..f5943ba935 --- /dev/null +++ b/compiler/mono/src/code_gen_help/mod.rs @@ -0,0 +1,439 @@ +use bumpalo::collections::vec::Vec; +use bumpalo::Bump; +use roc_builtins::bitcode::IntWidth; +use roc_module::ident::Ident; +use roc_module::low_level::LowLevel; +use roc_module::symbol::{IdentIds, ModuleId, Symbol}; + +use crate::ir::{ + Call, CallSpecId, CallType, Expr, HostExposedLayouts, Literal, ModifyRc, Proc, ProcLayout, + SelfRecursive, Stmt, UpdateModeId, +}; +use crate::layout::{Builtin, Layout, UnionLayout}; + +mod equality; +mod refcount; + +const LAYOUT_BOOL: Layout = Layout::Builtin(Builtin::Bool); +const LAYOUT_UNIT: Layout = Layout::Struct(&[]); + +const ARG_1: Symbol = Symbol::ARG_1; +const ARG_2: Symbol = Symbol::ARG_2; + +/// "Infinite" reference count, for static values +/// Ref counts are encoded as negative numbers where isize::MIN represents 1 +pub const REFCOUNT_MAX: usize = 0; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum HelperOp { + Inc, + Dec, + DecRef, + Eq, +} + +impl From<&ModifyRc> for HelperOp { + fn from(modify: &ModifyRc) -> Self { + match modify { + ModifyRc::Inc(..) => Self::Inc, + ModifyRc::Dec(_) => Self::Dec, + ModifyRc::DecRef(_) => Self::DecRef, + } + } +} + +#[derive(Debug)] +struct Specialization<'a> { + op: HelperOp, + layout: Layout<'a>, + symbol: Symbol, + proc: Option>, +} + +#[derive(Debug)] +pub struct Context<'a> { + new_linker_data: Vec<'a, (Symbol, ProcLayout<'a>)>, + recursive_union: Option>, + op: HelperOp, +} + +/// Generate specialized helper procs for code gen +/// ---------------------------------------------- +/// +/// Some low level operations need specialized helper procs to traverse data structures at runtime. +/// This includes refcounting, hashing, and equality checks. +/// +/// For example, when checking List equality, we need to visit each element and compare them. +/// Depending on the type of the list elements, we may need to recurse deeper into each element. +/// For tag unions, we may need branches for different tag IDs, etc. +/// +/// This module creates specialized helper procs for all such operations and types used in the program. +/// +/// The backend drives the process, in two steps: +/// 1) When it sees the relevant node, it calls CodeGenHelp to get the replacement IR. +/// CodeGenHelp returns IR for a call to the helper proc, and remembers the specialization. +/// 2) After the backend has generated code for all user procs, it takes the IR for all of the +/// specialized helpers procs, and generates target code for them too. +/// +pub struct CodeGenHelp<'a> { + arena: &'a Bump, + home: ModuleId, + ptr_size: u32, + layout_isize: Layout<'a>, + specializations: Vec<'a, Specialization<'a>>, + debug_recursion_depth: usize, +} + +impl<'a> CodeGenHelp<'a> { + pub fn new(arena: &'a Bump, intwidth_isize: IntWidth, home: ModuleId) -> Self { + CodeGenHelp { + arena, + home, + ptr_size: intwidth_isize.stack_size(), + layout_isize: Layout::Builtin(Builtin::Int(intwidth_isize)), + specializations: Vec::with_capacity_in(16, arena), + debug_recursion_depth: 0, + } + } + + pub fn take_procs(&mut self) -> Vec<'a, Proc<'a>> { + let procs_iter = self + .specializations + .drain(0..) + .map(|spec| spec.proc.unwrap()); + Vec::from_iter_in(procs_iter, self.arena) + } + + // ============================================================================ + // + // CALL GENERATED PROCS + // + // ============================================================================ + + /// Expand a `Refcounting` node to a `Let` node that calls a specialized helper proc. + /// The helper procs themselves are to be generated later with `generate_procs` + pub fn expand_refcount_stmt( + &mut self, + ident_ids: &mut IdentIds, + layout: Layout<'a>, + modify: &ModifyRc, + following: &'a Stmt<'a>, + ) -> (&'a Stmt<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) { + if !refcount::is_rc_implemented_yet(&layout) { + // Just a warning, so we can decouple backend development from refcounting development. + // When we are closer to completion, we can change it to a panic. + println!( + "WARNING! MEMORY LEAK! Refcounting not yet implemented for Layout {:?}", + layout + ); + return (following, Vec::new_in(self.arena)); + } + + let arena = self.arena; + + let mut ctx = Context { + new_linker_data: Vec::new_in(self.arena), + recursive_union: None, + op: HelperOp::from(modify), + }; + + match modify { + ModifyRc::Inc(structure, amount) => { + let layout_isize = self.layout_isize; + + // Define a constant for the amount to increment + let amount_sym = self.create_symbol(ident_ids, "amount"); + let amount_expr = Expr::Literal(Literal::Int(*amount as i128)); + let amount_stmt = |next| Stmt::Let(amount_sym, amount_expr, layout_isize, next); + + // Call helper proc, passing the Roc structure and constant amount + let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); + let call_expr = self.call_specialized_op( + ident_ids, + &mut ctx, + layout, + arena.alloc([*structure, amount_sym]), + ); + let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following); + let rc_stmt = arena.alloc(amount_stmt(arena.alloc(call_stmt))); + + (rc_stmt, ctx.new_linker_data) + } + + ModifyRc::Dec(structure) => { + // Call helper proc, passing the Roc structure + let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); + let call_expr = self.call_specialized_op( + ident_ids, + &mut ctx, + layout, + arena.alloc([*structure]), + ); + + let rc_stmt = arena.alloc(Stmt::Let( + call_result_empty, + call_expr, + LAYOUT_UNIT, + following, + )); + + (rc_stmt, ctx.new_linker_data) + } + + ModifyRc::DecRef(structure) => { + // No generated procs for DecRef, just lowlevel ops + let rc_ptr_sym = self.create_symbol(ident_ids, "rc_ptr"); + + // Pass the refcount pointer to the lowlevel call (see utils.zig) + let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); + let call_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::RefCountDec, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: arena.alloc([rc_ptr_sym]), + }); + let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following); + + let rc_stmt = arena.alloc(refcount::rc_ptr_from_struct( + self, + ident_ids, + *structure, + rc_ptr_sym, + arena.alloc(call_stmt), + )); + + (rc_stmt, ctx.new_linker_data) + } + } + } + + /// Replace a generic `Lowlevel::Eq` call with a specialized helper proc. + /// The helper procs themselves are to be generated later with `generate_procs` + pub fn call_specialized_equals( + &mut self, + ident_ids: &mut IdentIds, + layout: &Layout<'a>, + arguments: &'a [Symbol], + ) -> (Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) { + let mut ctx = Context { + new_linker_data: Vec::new_in(self.arena), + recursive_union: None, + op: HelperOp::Eq, + }; + + let expr = self.call_specialized_op(ident_ids, &mut ctx, *layout, arguments); + + (expr, ctx.new_linker_data) + } + + // ============================================================================ + // + // CALL SPECIALIZED OP + // + // ============================================================================ + + fn call_specialized_op( + &mut self, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + called_layout: Layout<'a>, + arguments: &[Symbol], + ) -> Expr<'a> { + use HelperOp::*; + + debug_assert!(self.debug_recursion_depth < 10); + self.debug_recursion_depth += 1; + + let layout = if matches!(called_layout, Layout::RecursivePointer) { + let union_layout = ctx.recursive_union.unwrap(); + Layout::Union(union_layout) + } else { + called_layout + }; + + if layout_needs_helper_proc(&layout, ctx.op) { + let proc_name = self.find_or_create_proc(ident_ids, ctx, layout); + + let (ret_layout, arg_layouts): (&'a Layout<'a>, &'a [Layout<'a>]) = { + match ctx.op { + Dec | DecRef => (&LAYOUT_UNIT, self.arena.alloc([layout])), + Inc => (&LAYOUT_UNIT, self.arena.alloc([layout, self.layout_isize])), + Eq => (&LAYOUT_BOOL, self.arena.alloc([layout, layout])), + } + }; + + Expr::Call(Call { + call_type: CallType::ByName { + name: proc_name, + ret_layout, + arg_layouts, + specialization_id: CallSpecId::BACKEND_DUMMY, + }, + arguments: self.arena.alloc_slice_copy(arguments), + }) + } else { + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::Eq, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: self.arena.alloc_slice_copy(arguments), + }) + } + } + + fn find_or_create_proc( + &mut self, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + layout: Layout<'a>, + ) -> Symbol { + use HelperOp::*; + + let found = self + .specializations + .iter() + .find(|spec| spec.op == ctx.op && spec.layout == layout); + + if let Some(spec) = found { + return spec.symbol; + } + + // Procs can be recursive, so we need to create the symbol before the body is complete + // But with nested recursion, that means Symbols and Procs can end up in different orders. + // We want the same order, especially for function indices in Wasm. So create an empty slot and fill it in later. + let (proc_symbol, proc_layout) = self.create_proc_symbol(ident_ids, ctx, &layout); + ctx.new_linker_data.push((proc_symbol, proc_layout)); + let spec_index = self.specializations.len(); + self.specializations.push(Specialization { + op: ctx.op, + layout, + symbol: proc_symbol, + proc: None, + }); + + // Recursively generate the body of the Proc and sub-procs + let (ret_layout, body) = match ctx.op { + Inc | Dec | DecRef => ( + LAYOUT_UNIT, + refcount::refcount_generic(self, ident_ids, ctx, layout), + ), + Eq => ( + LAYOUT_BOOL, + equality::eq_generic(self, ident_ids, ctx, layout), + ), + }; + + let args: &'a [(Layout<'a>, Symbol)] = { + let roc_value = (layout, ARG_1); + match ctx.op { + Inc => { + let inc_amount = (self.layout_isize, ARG_2); + self.arena.alloc([roc_value, inc_amount]) + } + Dec | DecRef => self.arena.alloc([roc_value]), + Eq => self.arena.alloc([roc_value, (layout, ARG_2)]), + } + }; + + self.specializations[spec_index].proc = Some(Proc { + name: proc_symbol, + args, + body, + closure_data_layout: None, + ret_layout, + is_self_recursive: SelfRecursive::NotSelfRecursive, + must_own_arguments: false, + host_exposed_layouts: HostExposedLayouts::NotHostExposed, + }); + + proc_symbol + } + + fn create_proc_symbol( + &self, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + layout: &Layout<'a>, + ) -> (Symbol, ProcLayout<'a>) { + let debug_name = format!( + "#help{}_{:?}_{:?}", + self.specializations.len(), + ctx.op, + layout + ) + .replace("Builtin", ""); + let proc_symbol: Symbol = self.create_symbol(ident_ids, &debug_name); + + let proc_layout = match ctx.op { + HelperOp::Inc => ProcLayout { + arguments: self.arena.alloc([*layout, self.layout_isize]), + result: LAYOUT_UNIT, + }, + HelperOp::Dec => ProcLayout { + arguments: self.arena.alloc([*layout]), + result: LAYOUT_UNIT, + }, + HelperOp::DecRef => unreachable!("No generated Proc for DecRef"), + HelperOp::Eq => ProcLayout { + arguments: self.arena.alloc([*layout, *layout]), + result: LAYOUT_BOOL, + }, + }; + + (proc_symbol, proc_layout) + } + + fn create_symbol(&self, ident_ids: &mut IdentIds, debug_name: &str) -> Symbol { + let ident_id = ident_ids.add(Ident::from(debug_name)); + Symbol::new(self.home, ident_id) + } +} + +fn let_lowlevel<'a>( + arena: &'a Bump, + result_layout: Layout<'a>, + result: Symbol, + op: LowLevel, + arguments: &[Symbol], + next: &'a Stmt<'a>, +) -> Stmt<'a> { + Stmt::Let( + result, + Expr::Call(Call { + call_type: CallType::LowLevel { + op, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: arena.alloc_slice_copy(arguments), + }), + result_layout, + next, + ) +} + +fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool { + match layout { + Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { + false + } + + Layout::Builtin(Builtin::Str) => { + // Str type can use either Zig functions or generated IR, since it's not generic. + // Eq uses a Zig function, refcount uses generated IR. + // Both are fine, they were just developed at different times. + matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef) + } + + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => true, + + Layout::Struct(fields) => !fields.is_empty(), + + Layout::Union(UnionLayout::NonRecursive(tags)) => !tags.is_empty(), + + Layout::Union(_) => true, + + Layout::LambdaSet(_) | Layout::RecursivePointer => false, + } +} diff --git a/compiler/mono/src/code_gen_help/refcount.rs b/compiler/mono/src/code_gen_help/refcount.rs new file mode 100644 index 0000000000..f5f0e8cfe6 --- /dev/null +++ b/compiler/mono/src/code_gen_help/refcount.rs @@ -0,0 +1,236 @@ +use roc_builtins::bitcode::IntWidth; +use roc_module::low_level::LowLevel; +use roc_module::symbol::{IdentIds, Symbol}; + +use crate::ir::{BranchInfo, Call, CallType, Expr, Literal, Stmt, UpdateModeId}; +use crate::layout::{Builtin, Layout}; + +use super::{CodeGenHelp, Context, HelperOp}; + +const LAYOUT_BOOL: Layout = Layout::Builtin(Builtin::Bool); +const LAYOUT_UNIT: Layout = Layout::Struct(&[]); +const LAYOUT_PTR: Layout = Layout::RecursivePointer; +const LAYOUT_U32: Layout = Layout::Builtin(Builtin::Int(IntWidth::U32)); + +const ARG_1: Symbol = Symbol::ARG_1; +const ARG_2: Symbol = Symbol::ARG_2; + +pub fn refcount_generic<'a>( + root: &CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, + layout: Layout<'a>, +) -> Stmt<'a> { + debug_assert!(is_rc_implemented_yet(&layout)); + let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); + + match layout { + Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { + unreachable!("Not refcounted: {:?}", layout) + } + Layout::Builtin(Builtin::Str) => refcount_str(root, ident_ids, ctx), + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(), + Layout::Struct(_) => rc_todo(), + Layout::Union(_) => rc_todo(), + Layout::LambdaSet(_) => { + unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.") + } + Layout::RecursivePointer => rc_todo(), + } +} + +// Check if refcounting is implemented yet. In the long term, this will be deleted. +// In the short term, it helps us to skip refcounting and let it leak, so we can make +// progress incrementally. Kept in sync with generate_procs using assertions. +pub fn is_rc_implemented_yet(layout: &Layout) -> bool { + matches!(layout, Layout::Builtin(Builtin::Str)) +} + +fn return_unit<'a>(root: &CodeGenHelp<'a>, ident_ids: &mut IdentIds) -> Stmt<'a> { + let unit = root.create_symbol(ident_ids, "unit"); + let ret_stmt = root.arena.alloc(Stmt::Ret(unit)); + Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt) +} + +// Subtract a constant from a pointer to find the refcount +// Also does some type casting, so that we have different Symbols and Layouts +// for the 'pointer' and 'integer' versions of the address. +// This helps to avoid issues with the backends Symbol->Layout mapping. +pub fn rc_ptr_from_struct<'a>( + root: &CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + structure: Symbol, + rc_ptr_sym: Symbol, + following: &'a Stmt<'a>, +) -> Stmt<'a> { + // Typecast the structure pointer to an integer + // Backends expect a number Layout to choose the right "subtract" instruction + let addr_sym = root.create_symbol(ident_ids, "addr"); + let addr_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrCast, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([structure]), + }); + let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, root.layout_isize, next); + + // Pointer size constant + let ptr_size_sym = root.create_symbol(ident_ids, "ptr_size"); + let ptr_size_expr = Expr::Literal(Literal::Int(root.ptr_size as i128)); + let ptr_size_stmt = |next| Stmt::Let(ptr_size_sym, ptr_size_expr, root.layout_isize, next); + + // Refcount address + let rc_addr_sym = root.create_symbol(ident_ids, "rc_addr"); + let rc_addr_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::NumSub, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([structure, ptr_size_sym]), + }); + let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, root.layout_isize, next); + + // Typecast the refcount address from integer to pointer + let rc_ptr_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrCast, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([rc_addr_sym]), + }); + let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next); + + addr_stmt(root.arena.alloc( + // + ptr_size_stmt(root.arena.alloc( + // + rc_addr_stmt(root.arena.alloc( + // + rc_ptr_stmt(root.arena.alloc( + // + following, + )), + )), + )), + )) +} + +/// Generate a procedure to modify the reference count of a Str +fn refcount_str<'a>( + root: &CodeGenHelp<'a>, + ident_ids: &mut IdentIds, + ctx: &mut Context<'a>, +) -> Stmt<'a> { + let op = ctx.op; + + let string = ARG_1; + let layout_isize = root.layout_isize; + + // Get the string length as a signed int + let len = root.create_symbol(ident_ids, "len"); + let len_expr = Expr::StructAtIndex { + index: 1, + field_layouts: root.arena.alloc([LAYOUT_PTR, layout_isize]), + structure: string, + }; + let len_stmt = |next| Stmt::Let(len, len_expr, layout_isize, next); + + // Zero + let zero = root.create_symbol(ident_ids, "zero"); + let zero_expr = Expr::Literal(Literal::Int(0)); + let zero_stmt = |next| Stmt::Let(zero, zero_expr, layout_isize, next); + + // is_big_str = (len >= 0); + // Treat len as isize so that the small string flag is the same as the sign bit + let is_big_str = root.create_symbol(ident_ids, "is_big_str"); + let is_big_str_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::NumGte, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([len, zero]), + }); + let is_big_str_stmt = |next| Stmt::Let(is_big_str, is_big_str_expr, LAYOUT_BOOL, next); + + // Get the pointer to the string elements + let elements = root.create_symbol(ident_ids, "elements"); + let elements_expr = Expr::StructAtIndex { + index: 0, + field_layouts: root.arena.alloc([LAYOUT_PTR, layout_isize]), + structure: string, + }; + let elements_stmt = |next| Stmt::Let(elements, elements_expr, layout_isize, next); + + // A pointer to the refcount value itself + let rc_ptr = root.create_symbol(ident_ids, "rc_ptr"); + + // Alignment constant (same value as ptr_size but different layout) + let alignment = root.create_symbol(ident_ids, "alignment"); + let alignment_expr = Expr::Literal(Literal::Int(root.ptr_size as i128)); + let alignment_stmt = |next| Stmt::Let(alignment, alignment_expr, LAYOUT_U32, next); + + // Call the relevant Zig lowlevel to actually modify the refcount + let zig_call_result = root.create_symbol(ident_ids, "zig_call_result"); + let zig_call_expr = match op { + HelperOp::Inc => Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::RefCountInc, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([rc_ptr, ARG_2]), + }), + HelperOp::Dec | HelperOp::DecRef => Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::RefCountDec, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: root.arena.alloc([rc_ptr, alignment]), + }), + _ => unreachable!(), + }; + let zig_call_stmt = |next| Stmt::Let(zig_call_result, zig_call_expr, LAYOUT_UNIT, next); + + // Generate an `if` to skip small strings but modify big strings + let then_branch = elements_stmt(root.arena.alloc( + // + rc_ptr_from_struct( + root, + ident_ids, + elements, + rc_ptr, + root.arena.alloc( + // + alignment_stmt(root.arena.alloc( + // + zig_call_stmt(root.arena.alloc( + // + Stmt::Ret(zig_call_result), + )), + )), + ), + ), + )); + let if_stmt = Stmt::Switch { + cond_symbol: is_big_str, + cond_layout: LAYOUT_BOOL, + branches: root.arena.alloc([(1, BranchInfo::None, then_branch)]), + default_branch: ( + BranchInfo::None, + root.arena.alloc(return_unit(root, ident_ids)), + ), + ret_layout: LAYOUT_UNIT, + }; + + // Combine the statements in sequence + len_stmt(root.arena.alloc( + // + zero_stmt(root.arena.alloc( + // + is_big_str_stmt(root.arena.alloc( + // + if_stmt, + )), + )), + )) +}