diff --git a/compiler/gen/src/llvm/compare.rs b/compiler/gen/src/llvm/compare.rs index 4fad64b031..a422966aee 100644 --- a/compiler/gen/src/llvm/compare.rs +++ b/compiler/gen/src/llvm/compare.rs @@ -74,6 +74,17 @@ pub fn build_eq<'a, 'ctx, 'env>( } } } + (Layout::Struct(f1), Layout::Struct(f2)) => { + debug_assert_eq!(f1, f2); + + build_struct_eq( + env, + layout_ids, + f1, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + ) + } (other1, other2) => { // TODO NOTE: This should ultimately have a _ => todo!("type mismatch!") branch todo!("implement equals for layouts {:?} == {:?}", other1, other2); @@ -156,6 +167,21 @@ pub fn build_neq<'a, 'ctx, 'env>( } } } + (Layout::Struct(f1), Layout::Struct(f2)) => { + debug_assert_eq!(f1, f2); + + let equal = build_struct_eq( + env, + layout_ids, + f1, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + ); + + let not_equal: IntValue = env.builder.build_not(equal.into_int_value(), "not"); + + not_equal.into() + } (other1, other2) => { // TODO NOTE: This should ultimately have a _ => todo!("type mismatch!") branch todo!( @@ -366,3 +392,137 @@ fn build_list_eq_help<'a, 'ctx, 'env>( .build_return(Some(&env.context.bool_type().const_int(0, false))); } } + +fn build_struct_eq<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + field_layouts: &'a [Layout<'a>], + struct1: StructValue<'ctx>, + struct2: StructValue<'ctx>, +) -> BasicValueEnum<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let struct_layout = Layout::Struct(field_layouts); + + let symbol = Symbol::GENERIC_EQ; + let fn_name = layout_ids + .get(symbol, &struct_layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arena = env.arena; + let arg_type = + basic_type_from_layout(arena, env.context, &struct_layout, env.ptr_bytes); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.bool_type().into(), + &[arg_type, arg_type], + ); + + build_struct_eq_help(env, layout_ids, function_value, field_layouts); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + let call = env + .builder + .build_call(function, &[struct1.into(), struct2.into()], "struct_eq"); + + call.set_call_convention(FAST_CALL_CONV); + + call.try_as_basic_value().left().unwrap() +} + +fn build_struct_eq_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + parent: FunctionValue<'ctx>, + field_layouts: &[Layout<'a>], +) { + let ctx = env.context; + let builder = env.builder; + + { + use inkwell::debug_info::AsDIScope; + + let func_scope = parent.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + ctx, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(&ctx, loc); + } + + // Add args to scope + let mut it = parent.get_param_iter(); + let struct1 = it.next().unwrap().into_struct_value(); + let struct2 = it.next().unwrap().into_struct_value(); + + set_name(struct1.into(), Symbol::ARG_1.ident_string(&env.interns)); + set_name(struct2.into(), Symbol::ARG_2.ident_string(&env.interns)); + + let entry = ctx.append_basic_block(parent, "entry"); + let start = ctx.append_basic_block(parent, "start"); + env.builder.position_at_end(entry); + env.builder.build_unconditional_branch(start); + + let return_true = ctx.append_basic_block(parent, "return_true"); + let return_false = ctx.append_basic_block(parent, "return_false"); + + let mut current = start; + + for (index, field_layout) in field_layouts.iter().enumerate() { + env.builder.position_at_end(current); + + let field1 = env + .builder + .build_extract_value(struct1, index as u32, "eq_field") + .unwrap(); + let field2 = env + .builder + .build_extract_value(struct2, index as u32, "eq_field") + .unwrap(); + + let are_equal = + build_eq(env, layout_ids, field1, field2, field_layout, field_layout).into_int_value(); + + current = ctx.append_basic_block(parent, &format!("eq_step_{}", index)); + + env.builder + .build_conditional_branch(are_equal, current, return_false); + } + + env.builder.position_at_end(current); + env.builder.build_unconditional_branch(return_true); + + { + env.builder.position_at_end(return_true); + env.builder + .build_return(Some(&env.context.bool_type().const_int(1, false))); + } + + { + env.builder.position_at_end(return_false); + env.builder + .build_return(Some(&env.context.bool_type().const_int(0, false))); + } +} diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 3550948a69..d3327b122d 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -741,6 +741,7 @@ define_builtins! { 11 DEC: "#dec" // internal function that increments the refcount 12 ARG_CLOSURE: "#arg_closure" // symbol used to store the closure record 13 LIST_EQ: "#list_eq" // internal function that checks list equality + 14 GENERIC_EQ: "#generic_eq" // internal function that checks generic equality } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias