diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 97f2d62839..eace9af756 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -14,9 +14,9 @@ use roc_mono::ir::{ use roc_mono::layout::{Builtin, Layout, LayoutIds, TagIdIntType, UnionLayout}; use roc_reporting::internal_error; -use crate::layout::{CallConv, ReturnMethod, WasmLayout}; -use crate::low_level::{decode_low_level, LowlevelBuildResult}; -use crate::storage::{Storage, StoredValue, StoredValueKind}; +use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout}; +use crate::low_level::{dispatch_low_level, LowlevelBuildResult}; +use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind}; use crate::wasm_module::linking::{ DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK, WASM_SYM_UNDEFINED, @@ -272,6 +272,7 @@ impl<'a> WasmBackend<'a> { self.start_block(BlockType::from(ret_type)); for (layout, symbol) in proc.args { + self.symbol_layouts.insert(*symbol, *layout); let arg_layout = WasmLayout::new(layout); self.storage .allocate(&arg_layout, *symbol, StoredValueKind::Parameter); @@ -480,6 +481,8 @@ impl<'a> WasmBackend<'a> { // make locals for join pointer parameters let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena); for parameter in parameters.iter() { + self.symbol_layouts + .insert(parameter.symbol, parameter.layout); let wasm_layout = WasmLayout::new(¶meter.layout); let mut param_storage = self.storage.allocate( &wasm_layout, @@ -645,21 +648,34 @@ impl<'a> WasmBackend<'a> { field_layouts, structure, } => { - if let StoredValue::StackMemory { location, .. } = self.storage.get(structure) { - let (local_id, mut offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - for field in field_layouts.iter().take(*index as usize) { - offset += field.stack_size(PTR_SIZE); + self.storage.ensure_value_has_local( + &mut self.code_builder, + *sym, + storage.to_owned(), + ); + let (local_id, mut offset) = match self.storage.get(structure) { + StoredValue::StackMemory { location, .. } => { + location.local_and_offset(self.storage.stack_frame_pointer) } - self.storage.copy_value_from_memory( - &mut self.code_builder, - *sym, + + StoredValue::Local { + value_type, local_id, - offset, - ); - } else { - internal_error!("Unexpected storage for {:?}", structure) + .. + } => { + debug_assert!(matches!(value_type, ValueType::I32)); + (*local_id, 0) + } + + StoredValue::VirtualMachineStack { .. } => { + internal_error!("ensure_value_has_local didn't work") + } + }; + for field in field_layouts.iter().take(*index as usize) { + offset += field.stack_size(PTR_SIZE); } + self.storage + .copy_value_from_memory(&mut self.code_builder, *sym, local_id, offset); } Expr::Array { elems, elem_layout } => { @@ -1024,77 +1040,247 @@ impl<'a> WasmBackend<'a> { return_layout: WasmLayout, storage: &StoredValue, ) { - let (param_types, ret_type) = self.storage.load_symbols_for_call( - self.env.arena, - &mut self.code_builder, - arguments, - return_sym, - &return_layout, - CallConv::Zig, - ); + use LowLevel::*; - let build_result = decode_low_level( - &mut self.code_builder, - &mut self.storage, - lowlevel, - arguments, - &return_layout, - ); - use LowlevelBuildResult::*; - - match build_result { - Done => {} - BuiltinCall(name) => { - self.call_zig_builtin(name, param_types, ret_type); + match lowlevel { + Eq | NotEq => self.build_eq(lowlevel, arguments, return_sym, return_layout, storage), + PtrCast => { + // Don't want Zig calling convention when casting pointers. + self.storage.load_symbols(&mut self.code_builder, arguments); } - SpecializedEq | SpecializedNotEq => { - let layout = self.symbol_layouts[&arguments[0]]; - let layout_rhs = self.symbol_layouts[&arguments[1]]; - debug_assert!( - layout == layout_rhs, - "Cannot do `==` comparison on different types" + Hash => todo!("Generic hash function generation"), + + // Almost all lowlevels take this branch, except for the special cases above + _ => { + // Load the arguments using Zig calling convention + let (param_types, ret_type) = self.storage.load_symbols_for_call( + self.env.arena, + &mut self.code_builder, + arguments, + return_sym, + &return_layout, + CallConv::Zig, ); - if layout == Layout::Builtin(Builtin::Str) { - self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); - } else if layout.stack_size(PTR_SIZE) == 0 { - // If the layout has zero size, and it type-checks, the values must be equal - let value = matches!(build_result, SpecializedEq); - self.code_builder.i32_const(value as i32); - return; - } else { - let ident_ids = self - .interns - .all_ident_ids - .get_mut(&self.env.module_id) - .unwrap(); + // Generate instructions OR decide which Zig function to call + let build_result = dispatch_low_level( + &mut self.code_builder, + &mut self.storage, + lowlevel, + arguments, + &return_layout, + ); - let (replacement_expr, new_specializations) = self - .helper_proc_gen - .specialize_equals(ident_ids, &layout, arguments); - - // If any new specializations were created, register their symbol data - for spec in new_specializations.into_iter() { - self.register_helper_proc(spec); + // Handle the result + use LowlevelBuildResult::*; + match build_result { + Done => {} + BuiltinCall(name) => { + self.call_zig_builtin(name, param_types, ret_type); + } + NotImplemented => { + todo!("Low level operation {:?}", lowlevel) } - - let bool_layout = Layout::Builtin(Builtin::Bool); - self.build_expr(&return_sym, replacement_expr, &bool_layout, storage); } + } + } + } - if matches!(build_result, SpecializedNotEq) { + fn build_eq( + &mut self, + lowlevel: LowLevel, + arguments: &'a [Symbol], + return_sym: Symbol, + return_layout: WasmLayout, + storage: &StoredValue, + ) { + let arg_layout = self.symbol_layouts[&arguments[0]]; + let other_arg_layout = self.symbol_layouts[&arguments[1]]; + debug_assert!( + arg_layout == other_arg_layout, + "Cannot do `==` comparison on different types" + ); + + match arg_layout { + Layout::Builtin( + Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, + ) => self.build_eq_number(lowlevel, arguments, return_layout), + + Layout::Builtin(Builtin::Str) => { + let (param_types, ret_type) = self.storage.load_symbols_for_call( + self.env.arena, + &mut self.code_builder, + arguments, + return_sym, + &return_layout, + CallConv::Zig, + ); + self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); + if matches!(lowlevel, LowLevel::NotEq) { self.code_builder.i32_eqz(); } } - SpecializedHash => { - todo!("Specialized hash functions") + + Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) + | Layout::Struct(_) + | Layout::Union(_) + | Layout::LambdaSet(_) => { + if arg_layout.stack_size(PTR_SIZE) == 0 { + // A zero-size type has only one possible value, like `{}` or `Unit` + // The arguments don't exist at runtime. Just emit True (Eq) or False (NotEq). + let result = matches!(lowlevel, LowLevel::Eq); + self.code_builder.i32_const(result as i32); + } else { + self.build_eq_specialized(&arg_layout, arguments, return_sym, storage); + if matches!(lowlevel, LowLevel::NotEq) { + self.code_builder.i32_eqz(); + } + } } - NotImplemented => { - todo!("Low level operation {:?}", lowlevel) + + Layout::RecursivePointer => { + internal_error!("`==` on RecursivePointer should be converted to the parent layout") } } } + fn build_eq_number( + &mut self, + lowlevel: LowLevel, + arguments: &'a [Symbol], + return_layout: WasmLayout, + ) { + use StoredValue::*; + match self.storage.get(&arguments[0]).to_owned() { + VirtualMachineStack { value_type, .. } | Local { value_type, .. } => { + self.storage.load_symbols(&mut self.code_builder, arguments); + match lowlevel { + LowLevel::Eq => match value_type { + ValueType::I32 => self.code_builder.i32_eq(), + ValueType::I64 => self.code_builder.i64_eq(), + ValueType::F32 => self.code_builder.f32_eq(), + ValueType::F64 => self.code_builder.f64_eq(), + }, + LowLevel::NotEq => match value_type { + ValueType::I32 => self.code_builder.i32_ne(), + ValueType::I64 => self.code_builder.i64_ne(), + ValueType::F32 => self.code_builder.f32_ne(), + ValueType::F64 => self.code_builder.f64_ne(), + }, + _ => internal_error!("Low-level op {:?} handled in the wrong place", lowlevel), + } + } + StackMemory { + format, + location: location0, + .. + } => { + if let StackMemory { + location: location1, + .. + } = self.storage.get(&arguments[1]).to_owned() + { + self.build_eq_num128(format, [location0, location1], arguments, return_layout); + if matches!(lowlevel, LowLevel::NotEq) { + self.code_builder.i32_eqz(); + } + } + } + } + } + + fn build_eq_num128( + &mut self, + format: StackMemoryFormat, + locations: [StackMemoryLocation; 2], + arguments: &'a [Symbol], + return_layout: WasmLayout, + ) { + match format { + StackMemoryFormat::Decimal => { + // Both args are finite + let first = [arguments[0]]; + let second = [arguments[1]]; + dispatch_low_level( + &mut self.code_builder, + &mut self.storage, + LowLevel::NumIsFinite, + &first, + &return_layout, + ); + dispatch_low_level( + &mut self.code_builder, + &mut self.storage, + LowLevel::NumIsFinite, + &second, + &return_layout, + ); + self.code_builder.i32_and(); + + // AND they have the same bytes + self.build_eq_num128_bytes(locations); + self.code_builder.i32_and(); + } + + StackMemoryFormat::Int128 => self.build_eq_num128_bytes(locations), + + StackMemoryFormat::Float128 => todo!("equality for f128"), + + StackMemoryFormat::DataStructure => { + internal_error!("Data structure equality is handled elsewhere") + } + } + } + + /// Check that two 128-bit numbers contain the same bytes + fn build_eq_num128_bytes(&mut self, locations: [StackMemoryLocation; 2]) { + let (local0, offset0) = locations[0].local_and_offset(self.storage.stack_frame_pointer); + let (local1, offset1) = locations[1].local_and_offset(self.storage.stack_frame_pointer); + + self.code_builder.get_local(local0); + self.code_builder.i64_load(Align::Bytes8, offset0); + self.code_builder.get_local(local1); + self.code_builder.i64_load(Align::Bytes8, offset1); + self.code_builder.i64_eq(); + + self.code_builder.get_local(local0); + self.code_builder.i64_load(Align::Bytes8, offset0 + 8); + self.code_builder.get_local(local1); + self.code_builder.i64_load(Align::Bytes8, offset1 + 8); + self.code_builder.i64_eq(); + + self.code_builder.i32_and(); + } + + /// Call a helper procedure that implements `==` for a specific data structure + fn build_eq_specialized( + &mut self, + arg_layout: &Layout<'a>, + arguments: &'a [Symbol], + return_sym: Symbol, + storage: &StoredValue, + ) { + let ident_ids = self + .interns + .all_ident_ids + .get_mut(&self.env.module_id) + .unwrap(); + + // Get an IR expression for the call to the specialized procedure + let (specialized_call_expr, new_specializations) = self + .helper_proc_gen + .call_specialized_equals(ident_ids, arg_layout, arguments); + + // If any new specializations were created, register their symbol data + for spec in new_specializations.into_iter() { + self.register_helper_proc(spec); + } + + // Generate Wasm code for the IR call expression + let bool_layout = Layout::Builtin(Builtin::Bool); + self.build_expr(&return_sym, specialized_call_expr, &bool_layout, storage); + } + fn load_literal( &mut self, lit: &Literal<'a>, diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index 914b8241bd..cc5cff66f6 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -11,13 +11,10 @@ use crate::wasm_module::{Align, CodeBuilder, ValueType::*}; pub enum LowlevelBuildResult { Done, BuiltinCall(&'static str), - SpecializedEq, - SpecializedNotEq, - SpecializedHash, NotImplemented, } -pub fn decode_low_level<'a>( +pub fn dispatch_low_level<'a>( code_builder: &mut CodeBuilder<'a>, storage: &mut Storage<'a>, lowlevel: LowLevel, @@ -525,109 +522,15 @@ pub fn decode_low_level<'a>( WasmLayout::StackMemory { .. } => return NotImplemented, } } - Eq => { - use StoredValue::*; - match storage.get(&args[0]).to_owned() { - VirtualMachineStack { value_type, .. } | Local { value_type, .. } => { - match value_type { - I32 => code_builder.i32_eq(), - I64 => code_builder.i64_eq(), - F32 => code_builder.f32_eq(), - F64 => code_builder.f64_eq(), - } - } - StackMemory { - format, - location: location0, - .. - } => { - if let StackMemory { - location: location1, - .. - } = storage.get(&args[1]).to_owned() - { - let stack_frame_pointer = storage.stack_frame_pointer; - let compare_bytes = |code_builder: &mut CodeBuilder| { - let (local0, offset0) = location0.local_and_offset(stack_frame_pointer); - let (local1, offset1) = location1.local_and_offset(stack_frame_pointer); - - code_builder.get_local(local0); - code_builder.i64_load(Align::Bytes8, offset0); - code_builder.get_local(local1); - code_builder.i64_load(Align::Bytes8, offset1); - code_builder.i64_eq(); - - code_builder.get_local(local0); - code_builder.i64_load(Align::Bytes8, offset0 + 8); - code_builder.get_local(local1); - code_builder.i64_load(Align::Bytes8, offset1 + 8); - code_builder.i64_eq(); - - code_builder.i32_and(); - }; - - match format { - Decimal => { - // Both args are finite - let first = [args[0]]; - let second = [args[1]]; - decode_low_level( - code_builder, - storage, - LowLevel::NumIsFinite, - &first, - ret_layout, - ); - decode_low_level( - code_builder, - storage, - LowLevel::NumIsFinite, - &second, - ret_layout, - ); - code_builder.i32_and(); - - // AND they have the same bytes - compare_bytes(code_builder); - code_builder.i32_and(); - } - Int128 => compare_bytes(code_builder), - Float128 => return NotImplemented, - DataStructure => return SpecializedEq, - } - } - } - } - } - NotEq => match storage.get(&args[0]) { - StoredValue::VirtualMachineStack { value_type, .. } - | StoredValue::Local { value_type, .. } => match value_type { - I32 => code_builder.i32_ne(), - I64 => code_builder.i64_ne(), - F32 => code_builder.f32_ne(), - F64 => code_builder.f64_ne(), - }, - StoredValue::StackMemory { format, .. } => { - if matches!(format, DataStructure) { - return SpecializedNotEq; - } else { - decode_low_level(code_builder, storage, LowLevel::Eq, args, ret_layout); - code_builder.i32_eqz(); - } - } - }, And => code_builder.i32_and(), Or => code_builder.i32_or(), Not => code_builder.i32_eqz(), - Hash => return SpecializedHash, ExpectTrue => return NotImplemented, - PtrCast => { - // We don't need any instructions here, since we've already loaded the value. - // PtrCast just creates separate Symbols and Layouts for the argument and return value. - // This is used for pointer math in refcounting and for pointer equality - } RefCountInc => return BuiltinCall(bitcode::UTILS_INCREF), RefCountDec => return BuiltinCall(bitcode::UTILS_DECREF), + Eq | NotEq | Hash | PtrCast => { + internal_error!("{:?} should be handled in backend.rs", lowlevel) + } } Done } diff --git a/compiler/gen_wasm/src/storage.rs b/compiler/gen_wasm/src/storage.rs index a46c25766f..65e71ed9b2 100644 --- a/compiler/gen_wasm/src/storage.rs +++ b/compiler/gen_wasm/src/storage.rs @@ -319,9 +319,11 @@ impl<'a> Storage<'a> { code_builder.i64_load(align, offset); } else if *size <= 12 && BUILTINS_ZIG_VERSION == ZigVersion::Zig9 { code_builder.i64_load(align, offset); + code_builder.get_local(local_id); code_builder.i32_load(align, offset + 8); } else { code_builder.i64_load(align, offset); + code_builder.get_local(local_id); code_builder.i64_load(align, offset + 8); } } diff --git a/compiler/mono/src/code_gen_help.rs b/compiler/mono/src/code_gen_help.rs index 0cdeb7a933..535ea7265b 100644 --- a/compiler/mono/src/code_gen_help.rs +++ b/compiler/mono/src/code_gen_help.rs @@ -196,7 +196,7 @@ impl<'a> CodeGenHelp<'a> { /// Replace a generic `Lowlevel::Eq` call with a specialized helper proc. /// The helper procs themselves are to be generated later with `generate_procs` - pub fn specialize_equals( + pub fn call_specialized_equals( &mut self, ident_ids: &mut IdentIds, layout: &Layout<'a>, @@ -690,7 +690,7 @@ impl<'a> CodeGenHelp<'a> { } Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(), Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts), - Layout::Union(_) => eq_todo(), + Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout), Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), Layout::RecursivePointer => eq_todo(), }; @@ -715,9 +715,9 @@ impl<'a> CodeGenHelp<'a> { ptr2: Symbol, following: &'a Stmt<'a>, ) -> Stmt<'a> { - let ptr1_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr1)); - let ptr2_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr2)); - let ptr_eq = self.create_symbol(ident_ids, &format!("eq_{:?}_{:?}", ptr1_addr, ptr2_addr)); + let ptr1_addr = self.create_symbol(ident_ids, "addr1"); + let ptr2_addr = self.create_symbol(ident_ids, "addr2"); + let ptr_eq = self.create_symbol(ident_ids, "eq_addr"); Stmt::Let( ptr1_addr, @@ -780,7 +780,9 @@ impl<'a> CodeGenHelp<'a> { fn eq_struct(&self, ident_ids: &mut IdentIds, field_layouts: &'a [Layout<'a>]) -> Stmt<'a> { let else_clause = self.eq_fields( ident_ids, + 0, field_layouts, + None, &[Symbol::ARG_1, Symbol::ARG_2], Stmt::Ret(Symbol::BOOL_TRUE), ); @@ -795,14 +797,15 @@ impl<'a> CodeGenHelp<'a> { fn eq_fields( &self, ident_ids: &mut IdentIds, + tag_id: u64, field_layouts: &'a [Layout<'a>], + rec_ptr_layout: Option>, arguments: &'a [Symbol], following: Stmt<'a>, ) -> Stmt<'a> { let mut stmt = following; for (i, layout) in field_layouts.iter().enumerate().rev() { - let field1_name = format!("{:?}_field_{}", arguments[0], i); - let field1_sym = self.create_symbol(ident_ids, &field1_name); + let field1_sym = self.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i)); let field1_expr = Expr::StructAtIndex { index: i as u64, field_layouts, @@ -810,8 +813,7 @@ impl<'a> CodeGenHelp<'a> { }; let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next); - let field2_name = format!("{:?}_field_{}", arguments[1], i); - let field2_sym = self.create_symbol(ident_ids, &field2_name); + let field2_sym = self.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i)); let field2_expr = Expr::StructAtIndex { index: i as u64, field_layouts, @@ -820,7 +822,13 @@ impl<'a> CodeGenHelp<'a> { let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next); let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]); - let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args); + let sub_layout = match (layout, rec_ptr_layout) { + (Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout), + _ => layout, + }; + + let eq_call_expr = + self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args); let eq_call_name = format!("eq_call_{}", i); let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name); let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next); @@ -838,6 +846,181 @@ impl<'a> CodeGenHelp<'a> { } stmt } + + fn eq_tag_union(&self, ident_ids: &mut IdentIds, union_layout: UnionLayout<'a>) -> Stmt<'a> { + use UnionLayout::*; + + let main_stmt = match union_layout { + NonRecursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None), + + Recursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None), + + NonNullableUnwrapped(field_layouts) => self.eq_fields( + ident_ids, + 0, + field_layouts, + Some(Layout::Union(union_layout)), + &[Symbol::ARG_1, Symbol::ARG_2], + Stmt::Ret(Symbol::BOOL_TRUE), + ), + + NullableWrapped { + other_tags, + nullable_id, + } => self.eq_tag_union_help(ident_ids, union_layout, other_tags, Some(nullable_id)), + + NullableUnwrapped { + other_fields, + nullable_id: n, + } => self.eq_tag_union_help( + ident_ids, + union_layout, + self.arena.alloc([other_fields]), + Some(n as u16), + ), + }; + + self.if_pointers_equal_return_true( + ident_ids, + Symbol::ARG_1, + Symbol::ARG_2, + self.arena.alloc(main_stmt), + ) + } + + fn eq_tag_union_help( + &self, + ident_ids: &mut IdentIds, + union_layout: UnionLayout<'a>, + tag_layouts: &'a [&'a [Layout<'a>]], + nullable_id: Option, + ) -> Stmt<'a> { + let tag_id_layout = union_layout.tag_id_layout(); + + let tag_id_a = self.create_symbol(ident_ids, "tag_id_a"); + let tag_id_a_stmt = |next| { + Stmt::Let( + tag_id_a, + Expr::GetTagId { + structure: Symbol::ARG_1, + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let tag_id_b = self.create_symbol(ident_ids, "tag_id_b"); + let tag_id_b_stmt = |next| { + Stmt::Let( + tag_id_b, + Expr::GetTagId { + structure: Symbol::ARG_2, + union_layout, + }, + tag_id_layout, + next, + ) + }; + + let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq"); + let tag_ids_eq_stmt = |next| { + Stmt::Let( + tag_ids_eq, + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::Eq, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: self.arena.alloc([tag_id_a, tag_id_b]), + }), + LAYOUT_BOOL, + next, + ) + }; + + let if_equal_ids_stmt = |next| Stmt::Switch { + cond_symbol: tag_ids_eq, + cond_layout: LAYOUT_BOOL, + branches: self + .arena + .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]), + default_branch: (BranchInfo::None, next), + ret_layout: LAYOUT_BOOL, + }; + + // + // Switch statement by tag ID + // + + let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), self.arena); + + // If there's a null tag, check it first. We might not need to load any data from memory. + if let Some(id) = nullable_id { + tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))) + } + + let recursive_ptr_layout = Some(Layout::Union(union_layout)); + + let mut tag_id: u64 = 0; + for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) { + if let Some(null_id) = nullable_id { + if tag_id == null_id as u64 { + tag_id += 1; + } + } + + tag_branches.push(( + tag_id, + BranchInfo::None, + self.eq_fields( + ident_ids, + tag_id, + field_layouts, + recursive_ptr_layout, + &[Symbol::ARG_1, Symbol::ARG_2], + Stmt::Ret(Symbol::BOOL_TRUE), + ), + )); + + tag_id += 1; + } + + let tag_switch_stmt = Stmt::Switch { + cond_symbol: tag_id_a, + cond_layout: tag_id_layout, + branches: tag_branches.into_bump_slice(), + default_branch: ( + BranchInfo::None, + self.arena.alloc(self.eq_fields( + ident_ids, + tag_id, + tag_layouts.last().unwrap(), + recursive_ptr_layout, + &[Symbol::ARG_1, Symbol::ARG_2], + Stmt::Ret(Symbol::BOOL_TRUE), + )), + ), + ret_layout: LAYOUT_BOOL, + }; + + // + // combine all the statments + // + tag_id_a_stmt(self.arena.alloc( + // + tag_id_b_stmt(self.arena.alloc( + // + tag_ids_eq_stmt(self.arena.alloc( + // + if_equal_ids_stmt(self.arena.alloc( + // + tag_switch_stmt, + )), + )), + )), + )) + } } /// Helper to derive a debug function name from a layout @@ -865,7 +1048,7 @@ fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool { Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) | Layout::Struct(_) | Layout::Union(_) - | Layout::LambdaSet(_) - | Layout::RecursivePointer => true, + | Layout::LambdaSet(_) => true, + Layout::RecursivePointer => false, } } diff --git a/compiler/test_gen/src/gen_compare.rs b/compiler/test_gen/src/gen_compare.rs index 4286610086..02f8364566 100644 --- a/compiler/test_gen/src/gen_compare.rs +++ b/compiler/test_gen/src/gen_compare.rs @@ -179,7 +179,7 @@ fn record() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn unit() { assert_evals_to!("Unit == Unit", true, bool); assert_evals_to!("Unit != Unit", false, bool); @@ -231,7 +231,7 @@ fn large_str() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_result_tag_true() { assert_evals_to!( indoc!( @@ -251,7 +251,7 @@ fn eq_result_tag_true() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_result_tag_false() { assert_evals_to!( indoc!( @@ -271,7 +271,7 @@ fn eq_result_tag_false() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_expr() { assert_evals_to!( indoc!( @@ -293,7 +293,7 @@ fn eq_expr() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_linked_list() { assert_evals_to!( indoc!( @@ -351,7 +351,7 @@ fn eq_linked_list() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_linked_list_false() { assert_evals_to!( indoc!( @@ -373,7 +373,7 @@ fn eq_linked_list_false() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn eq_nullable_expr() { assert_evals_to!( indoc!( @@ -502,7 +502,7 @@ fn list_neq_nested() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn compare_union_same_content() { assert_evals_to!( indoc!( @@ -524,7 +524,7 @@ fn compare_union_same_content() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn compare_recursive_union_same_content() { assert_evals_to!( indoc!( @@ -546,7 +546,7 @@ fn compare_recursive_union_same_content() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn compare_nullable_recursive_union_same_content() { assert_evals_to!( indoc!(