diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 1fec3347d1..97f2d62839 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -1049,9 +1049,19 @@ impl<'a> WasmBackend<'a> { } SpecializedEq | SpecializedNotEq => { let layout = self.symbol_layouts[&arguments[0]]; + let layout_rhs = self.symbol_layouts[&arguments[1]]; + debug_assert!( + layout == layout_rhs, + "Cannot do `==` comparison on different types" + ); if layout == Layout::Builtin(Builtin::Str) { self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); + } else if layout.stack_size(PTR_SIZE) == 0 { + // If the layout has zero size, and it type-checks, the values must be equal + let value = matches!(build_result, SpecializedEq); + self.code_builder.i32_const(value as i32); + return; } else { let ident_ids = self .interns @@ -1068,7 +1078,8 @@ impl<'a> WasmBackend<'a> { self.register_helper_proc(spec); } - self.build_expr(&return_sym, replacement_expr, &layout, storage); + let bool_layout = Layout::Builtin(Builtin::Bool); + self.build_expr(&return_sym, replacement_expr, &bool_layout, storage); } if matches!(build_result, SpecializedNotEq) { diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index 0ccd8d96af..db6af39b11 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -7,6 +7,7 @@ use crate::layout::{StackMemoryFormat::*, WasmLayout}; use crate::storage::{Storage, StoredValue}; use crate::wasm_module::{Align, CodeBuilder, ValueType::*}; +#[derive(Debug)] pub enum LowlevelBuildResult { Done, BuiltinCall(&'static str), diff --git a/compiler/mono/src/code_gen_help.rs b/compiler/mono/src/code_gen_help.rs index b09906f136..cb2c418730 100644 --- a/compiler/mono/src/code_gen_help.rs +++ b/compiler/mono/src/code_gen_help.rs @@ -227,7 +227,7 @@ impl<'a> CodeGenHelp<'a> { // ============================================================================ // - // TRAVERSE LAYOUT & CREATE PROC NAMES + // CREATE SPECIALIZATIONS // // ============================================================================ @@ -270,7 +270,9 @@ impl<'a> CodeGenHelp<'a> { new_procs_info.push((symbol, proc_layout)); let mut visit_child = |child| { - self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child); + if layout_needs_helper_proc(child, op) { + self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child); + } }; let mut visit_children = |children: &'a [Layout]| { @@ -369,82 +371,97 @@ impl<'a> CodeGenHelp<'a> { ) -> Vec<'a, Proc<'a>> { use HelperOp::*; - // Move the vector out of self, so we can loop over it safely - let mut specs = std::mem::replace(&mut self.specs, Vec::with_capacity_in(0, arena)); + // Clone the specializations so we can loop over them safely + // We need to keep self.specs for lookups of sub-procedures during generation + // Maybe could avoid this by separating specs vector from CodeGenHelp, letting backend own both. + let mut specs = self.specs.clone(); - let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| match op { - Inc | Dec | DecRef => { - debug_assert!(Self::is_rc_implemented_yet(&layout)); - let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); + let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| { + let (ret_layout, body) = match op { + Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, layout, op)), + Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, layout)), + }; - match layout { - Layout::Builtin( - Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, - ) => unreachable!("Not refcounted: {:?}", layout), - Layout::Builtin(Builtin::Str) => { - self.gen_modify_str(ident_ids, op, proc_symbol) - } - Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => { - rc_todo() - } - Layout::Struct(_) => rc_todo(), - Layout::Union(union_layout) => match union_layout { - UnionLayout::NonRecursive(_) => rc_todo(), - UnionLayout::Recursive(_) => rc_todo(), - UnionLayout::NonNullableUnwrapped(_) => rc_todo(), - UnionLayout::NullableWrapped { .. } => rc_todo(), - UnionLayout::NullableUnwrapped { .. } => rc_todo(), - }, - Layout::LambdaSet(_) => unreachable!( - "Refcounting on LambdaSet is invalid. Should be a Union at runtime." - ), - Layout::RecursivePointer => rc_todo(), + let roc_value = (layout, Symbol::ARG_1); + let args: &'a [(Layout<'a>, Symbol)] = match op { + HelperOp::Inc => { + let inc_amount = (self.layout_isize, Symbol::ARG_2); + self.arena.alloc([roc_value, inc_amount]) } - } - Eq => { - let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); + HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]), + HelperOp::Eq => self.arena.alloc([roc_value, (layout, Symbol::ARG_2)]), + }; - match layout { - Layout::Builtin( - Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, - ) => unreachable!( - "No generated proc for `==`. Use direct code gen for {:?}", - layout - ), - Layout::Builtin(Builtin::Str) => { - unreachable!("No generated helper proc for `==` on Str. Use Zig function.") - } - Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => { - eq_todo() - } - Layout::Struct(_) => eq_todo(), - Layout::Union(union_layout) => { - self.eq_tag_union(ident_ids, proc_symbol, union_layout) - } - Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), - Layout::RecursivePointer => eq_todo(), - } + Proc { + name: proc_symbol, + args, + body, + closure_data_layout: None, + ret_layout, + is_self_recursive: SelfRecursive::NotSelfRecursive, + must_own_arguments: false, + host_exposed_layouts: HostExposedLayouts::NotHostExposed, } }); Vec::from_iter_in(procs_iter, arena) } - fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> { - let unit = self.create_symbol(ident_ids, "unit"); - let ret_stmt = self.arena.alloc(Stmt::Ret(unit)); - Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt) - } + /// Apply the HelperOp to a field of a data structure + /// Only called while generating bodies of helper procs + /// The list of specializations should be complete by this time + fn apply_op_to_sub_layout( + &mut self, + op: HelperOp, + sub_layout: &Layout<'a>, + arguments: &'a [Symbol], + ) -> Expr<'a> { + let found = self + .specs + .iter() + .find(|(l, o, _)| l == sub_layout && *o == op); - fn gen_args(&self, op: HelperOp, layout: Layout<'a>) -> &'a [(Layout<'a>, Symbol)] { - let roc_value = (layout, Symbol::ARG_1); - match op { - HelperOp::Inc => { - let inc_amount = (self.layout_isize, Symbol::ARG_2); - self.arena.alloc([roc_value, inc_amount]) - } - HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]), - HelperOp::Eq => self.arena.alloc([roc_value, (layout, Symbol::ARG_2)]), + if let Some((_, _, proc_name)) = found { + let arg_layouts: &[Layout<'a>] = match op { + HelperOp::Eq => self.arena.alloc([*sub_layout, *sub_layout]), + HelperOp::Inc => self.arena.alloc([*sub_layout, self.layout_isize]), + HelperOp::Dec => self.arena.alloc([*sub_layout]), + HelperOp::DecRef => unreachable!("DecRef is not recursive"), + }; + let ret_layout = if matches!(op, HelperOp::Eq) { + &LAYOUT_BOOL + } else { + &LAYOUT_UNIT + }; + + Expr::Call(Call { + call_type: CallType::ByName { + name: *proc_name, + ret_layout, + arg_layouts, + specialization_id: CallSpecId::BACKEND_DUMMY, + }, + arguments, + }) + } else { + // By the time we get here (generating helper procs), the list of specializations is complete. + // So if we didn't find one, we must be at a leaf of the layout tree. + debug_assert!(!layout_needs_helper_proc(sub_layout, op)); + + let lowlevel = match op { + HelperOp::Eq => LowLevel::Eq, + HelperOp::Inc => LowLevel::RefCountInc, + HelperOp::Dec => LowLevel::RefCountDec, + HelperOp::DecRef => unreachable!("DecRef is not recursive"), + }; + + Expr::Call(Call { + call_type: CallType::LowLevel { + op: lowlevel, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments, + }) } } @@ -454,13 +471,38 @@ impl<'a> CodeGenHelp<'a> { // // ============================================================================ - /// Generate a procedure to modify the reference count of a Str - fn gen_modify_str( + fn refcount_generic( &mut self, ident_ids: &mut IdentIds, + layout: Layout<'a>, op: HelperOp, - proc_name: Symbol, - ) -> Proc<'a> { + ) -> Stmt<'a> { + debug_assert!(Self::is_rc_implemented_yet(&layout)); + let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); + + match layout { + Layout::Builtin( + Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, + ) => unreachable!("Not refcounted: {:?}", layout), + Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, op), + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(), + Layout::Struct(_) => rc_todo(), + Layout::Union(_) => rc_todo(), + Layout::LambdaSet(_) => { + unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.") + } + Layout::RecursivePointer => rc_todo(), + } + } + + fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> { + let unit = self.create_symbol(ident_ids, "unit"); + let ret_stmt = self.arena.alloc(Stmt::Ret(unit)); + Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt) + } + + /// Generate a procedure to modify the reference count of a Str + fn refcount_str(&mut self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> { let string = Symbol::ARG_1; let layout_isize = self.layout_isize; @@ -562,7 +604,7 @@ impl<'a> CodeGenHelp<'a> { }; // Combine the statements in sequence - let body = len_stmt(self.arena.alloc( + len_stmt(self.arena.alloc( // zero_stmt(self.arena.alloc( // @@ -571,20 +613,7 @@ impl<'a> CodeGenHelp<'a> { if_stmt, )), )), - )); - - let args = self.gen_args(op, Layout::Builtin(Builtin::Str)); - - Proc { - name: proc_name, - args, - body, - closure_data_layout: None, - ret_layout: LAYOUT_UNIT, - is_self_recursive: SelfRecursive::NotSelfRecursive, - must_own_arguments: false, - host_exposed_layouts: HostExposedLayouts::NotHostExposed, - } + )) } // ============================================================================ @@ -593,14 +622,103 @@ impl<'a> CodeGenHelp<'a> { // // ============================================================================ - fn eq_tag_union( + fn eq_generic(&mut self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> { + let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); + + let arguments = &[Symbol::ARG_1, Symbol::ARG_2]; + + let main_body = match layout { + Layout::Builtin( + Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, + ) => unreachable!( + "No generated proc for `==`. Use direct code gen for {:?}", + layout + ), + Layout::Builtin(Builtin::Str) => { + unreachable!("No generated helper proc for `==` on Str. Use Zig function.") + } + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(), + Layout::Struct(field_layouts) => self.eq_struct( + ident_ids, + field_layouts, + arguments, + Stmt::Ret(Symbol::BOOL_TRUE), + ), + Layout::Union(_) => eq_todo(), + Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), + Layout::RecursivePointer => eq_todo(), + }; + + Stmt::Let( + Symbol::BOOL_TRUE, + Expr::Literal(Literal::Int(1)), + LAYOUT_BOOL, + self.arena.alloc(Stmt::Let( + Symbol::BOOL_FALSE, + Expr::Literal(Literal::Int(0)), + LAYOUT_BOOL, + self.arena.alloc(main_body), + )), + ) + } + + fn if_false_return_false(&mut self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> { + Stmt::Switch { + cond_symbol: symbol, + cond_layout: LAYOUT_BOOL, + branches: self + .arena + .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]), + default_branch: (BranchInfo::None, following), + ret_layout: LAYOUT_BOOL, + } + } + + fn eq_struct( &mut self, ident_ids: &mut IdentIds, - proc_name: Symbol, - union_layout: UnionLayout<'a>, - ) -> Proc<'a> { - // - todo!("return something") + field_layouts: &'a [Layout<'a>], + arguments: &'a [Symbol], + following: Stmt<'a>, + ) -> Stmt<'a> { + let mut stmt = following; + for (i, layout) in field_layouts.iter().enumerate().rev() { + let field1_name = format!("{:?}_field_{}", arguments[0], i); + let field1_sym = self.create_symbol(ident_ids, &field1_name); + let field1_expr = Expr::StructAtIndex { + index: i as u64, + field_layouts, + structure: arguments[0], + }; + let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next); + + let field2_name = format!("{:?}_field_{}", arguments[1], i); + let field2_sym = self.create_symbol(ident_ids, &field2_name); + let field2_expr = Expr::StructAtIndex { + index: i as u64, + field_layouts, + structure: arguments[1], + }; + let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next); + + let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]); + let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args); + let eq_call_name = format!("eq_call_{}", i); + let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name); + let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next); + + stmt = field1_stmt(self.arena.alloc( + // + field2_stmt(self.arena.alloc( + // + eq_call_stmt(self.arena.alloc( + // + self.if_false_return_false(eq_call_sym, self.arena.alloc(stmt)), + )), + )), + )) + } + stmt } } @@ -611,13 +729,25 @@ fn layout_debug_name<'a>(layout: &Layout<'a>) -> &'static str { Layout::Builtin(Builtin::Set(_)) => "set", Layout::Builtin(Builtin::Dict(_, _)) => "dict", Layout::Builtin(Builtin::Str) => "str", - Layout::Builtin(builtin) => { - debug_assert!(!builtin.is_refcounted()); - unreachable!("Builtin {:?} is not refcounted", builtin); - } Layout::Struct(_) => "struct", Layout::Union(_) => "union", Layout::LambdaSet(_) => "lambdaset", - Layout::RecursivePointer => "recursive_pointer", + _ => unreachable!("Can't create helper proc name for {:?}", layout), + } +} + +fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool { + match layout { + Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { + false + } + Layout::Builtin(Builtin::Str) => { + matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef) + } + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) + | Layout::Struct(_) + | Layout::Union(_) + | Layout::LambdaSet(_) + | Layout::RecursivePointer => true, } } diff --git a/compiler/test_gen/src/gen_compare.rs b/compiler/test_gen/src/gen_compare.rs index f58ce79ff9..4286610086 100644 --- a/compiler/test_gen/src/gen_compare.rs +++ b/compiler/test_gen/src/gen_compare.rs @@ -146,12 +146,38 @@ fn neq_bool_tag() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn empty_record() { assert_evals_to!("{} == {}", true, bool); assert_evals_to!("{} != {}", false, bool); } +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] +fn record() { + assert_evals_to!( + "{ x: 123, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }", + true, + bool + ); + + assert_evals_to!( + "{ x: 234, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }", + false, + bool + ); + assert_evals_to!( + "{ x: 123, y: \"World\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }", + false, + bool + ); + assert_evals_to!( + "{ x: 123, y: \"Hello\", z: 1.11 } == { x: 123, y: \"Hello\", z: 3.14 }", + false, + bool + ); +} + #[test] #[cfg(any(feature = "gen-llvm"))] fn unit() {