diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 8317ce1170..5fee5730ac 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -11,8 +11,7 @@ use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_int, str_join_with, str_number_of_bytes, str_split, str_starts_with, CHAR_LAYOUT, }; - -use crate::llvm::compare::{build_eq, build_neq}; +use crate::llvm::compare::{generic_eq, generic_neq}; use crate::llvm::convert::{ basic_type_from_builtin, basic_type_from_layout, block_of_memory, block_of_memory_slices, collection, get_fn_type, get_ptr_type, ptr_int, @@ -3860,7 +3859,7 @@ fn run_low_level<'a, 'ctx, 'env>( let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]); let (rhs_arg, rhs_layout) = load_symbol_and_layout(scope, &args[1]); - build_eq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout) + generic_eq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout) } NotEq => { debug_assert_eq!(args.len(), 2); @@ -3868,7 +3867,7 @@ fn run_low_level<'a, 'ctx, 'env>( let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]); let (rhs_arg, rhs_layout) = load_symbol_and_layout(scope, &args[1]); - build_neq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout) + generic_neq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout) } And => { // The (&&) operator diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 74e1df4089..48e7e6797f 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1,7 +1,7 @@ use crate::llvm::build::{ allocate_with_refcount_help, build_num_binop, cast_basic_basic, Env, InPlace, }; -use crate::llvm::compare::build_eq; +use crate::llvm::compare::generic_eq; use crate::llvm::convert::{basic_type_from_layout, collection, get_ptr_type}; use crate::llvm::refcounting::{ decrement_refcount_layout, increment_refcount_layout, refcount_is_one_comparison, @@ -1114,7 +1114,7 @@ pub fn list_contains_help<'a, 'ctx, 'env>( let current_elem = builder.build_load(current_elem_ptr, "load_elem"); - let has_found = build_eq( + let has_found = generic_eq( env, layout_ids, current_elem, diff --git a/compiler/gen/src/llvm/compare.rs b/compiler/gen/src/llvm/compare.rs index 4fad64b031..ec1f0ebd05 100644 --- a/compiler/gen/src/llvm/compare.rs +++ b/compiler/gen/src/llvm/compare.rs @@ -1,14 +1,21 @@ use crate::llvm::build::Env; -use crate::llvm::build::{set_name, FAST_CALL_CONV}; +use crate::llvm::build::{cast_block_of_memory_to_tag, complex_bitcast, set_name, FAST_CALL_CONV}; use crate::llvm::build_list::{list_len, load_list_ptr}; use crate::llvm::build_str::str_equal; use crate::llvm::convert::{basic_type_from_layout, get_ptr_type}; -use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue}; +use bumpalo::collections::Vec; +use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::{AddressSpace, FloatPredicate, IntPredicate}; use roc_module::symbol::Symbol; -use roc_mono::layout::{Builtin, Layout, LayoutIds}; +use roc_mono::layout::{Builtin, Layout, LayoutIds, UnionLayout}; -pub fn build_eq<'a, 'ctx, 'env>( +#[derive(Clone, Debug)] +enum WhenRecursive<'a> { + Unreachable, + Loop(UnionLayout<'a>), +} + +pub fn generic_eq<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, lhs_val: BasicValueEnum<'ctx>, @@ -16,153 +23,336 @@ pub fn build_eq<'a, 'ctx, 'env>( lhs_layout: &Layout<'a>, rhs_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - match (lhs_layout, rhs_layout) { - (Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin)) => { - let int_cmp = |pred, label| { - let int_val = env.builder.build_int_compare( - pred, - lhs_val.into_int_value(), - rhs_val.into_int_value(), - label, - ); + build_eq( + env, + layout_ids, + lhs_val, + rhs_val, + lhs_layout, + rhs_layout, + WhenRecursive::Unreachable, + ) +} - BasicValueEnum::IntValue(int_val) - }; +pub fn generic_neq<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + lhs_layout: &Layout<'a>, + rhs_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + build_neq( + env, + layout_ids, + lhs_val, + rhs_val, + lhs_layout, + rhs_layout, + WhenRecursive::Unreachable, + ) +} - let float_cmp = |pred, label| { - let int_val = env.builder.build_float_compare( - pred, - lhs_val.into_float_value(), - rhs_val.into_float_value(), - label, - ); +fn build_eq_builtin<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: &Builtin<'a>, + when_recursive: WhenRecursive<'a>, +) -> BasicValueEnum<'ctx> { + let int_cmp = |pred, label| { + let int_val = env.builder.build_int_compare( + pred, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + label, + ); - BasicValueEnum::IntValue(int_val) - }; + BasicValueEnum::IntValue(int_val) + }; - match (lhs_builtin, rhs_builtin) { - (Builtin::Int128, Builtin::Int128) => int_cmp(IntPredicate::EQ, "eq_i128"), - (Builtin::Int64, Builtin::Int64) => int_cmp(IntPredicate::EQ, "eq_i64"), - (Builtin::Int32, Builtin::Int32) => int_cmp(IntPredicate::EQ, "eq_i32"), - (Builtin::Int16, Builtin::Int16) => int_cmp(IntPredicate::EQ, "eq_i16"), - (Builtin::Int8, Builtin::Int8) => int_cmp(IntPredicate::EQ, "eq_i8"), - (Builtin::Int1, Builtin::Int1) => int_cmp(IntPredicate::EQ, "eq_i1"), - (Builtin::Float64, Builtin::Float64) => float_cmp(FloatPredicate::OEQ, "eq_f64"), - (Builtin::Float32, Builtin::Float32) => float_cmp(FloatPredicate::OEQ, "eq_f32"), - (Builtin::Str, Builtin::Str) => str_equal(env, lhs_val, rhs_val), - (Builtin::EmptyList, Builtin::EmptyList) => { - env.context.bool_type().const_int(1, false).into() - } - (Builtin::List(_, _), Builtin::EmptyList) - | (Builtin::EmptyList, Builtin::List(_, _)) => { - unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout") - } - (Builtin::List(_, elem1), Builtin::List(_, elem2)) => { - debug_assert_eq!(elem1, elem2); + let float_cmp = |pred, label| { + let int_val = env.builder.build_float_compare( + pred, + lhs_val.into_float_value(), + rhs_val.into_float_value(), + label, + ); - build_list_eq( - env, - layout_ids, - lhs_layout, - elem1, - lhs_val.into_struct_value(), - rhs_val.into_struct_value(), - ) - } - (b1, b2) => { - todo!("Handle equals for builtin layouts {:?} == {:?}", b1, b2); - } - } + BasicValueEnum::IntValue(int_val) + }; + + match builtin { + Builtin::Int128 => int_cmp(IntPredicate::EQ, "eq_i128"), + Builtin::Int64 => int_cmp(IntPredicate::EQ, "eq_i64"), + Builtin::Int32 => int_cmp(IntPredicate::EQ, "eq_i32"), + Builtin::Int16 => int_cmp(IntPredicate::EQ, "eq_i16"), + Builtin::Int8 => int_cmp(IntPredicate::EQ, "eq_i8"), + Builtin::Int1 => int_cmp(IntPredicate::EQ, "eq_i1"), + + Builtin::Usize => int_cmp(IntPredicate::EQ, "eq_usize"), + + Builtin::Float128 => float_cmp(FloatPredicate::OEQ, "eq_f128"), + Builtin::Float64 => float_cmp(FloatPredicate::OEQ, "eq_f64"), + Builtin::Float32 => float_cmp(FloatPredicate::OEQ, "eq_f32"), + Builtin::Float16 => float_cmp(FloatPredicate::OEQ, "eq_f16"), + + Builtin::Str => str_equal(env, lhs_val, rhs_val), + Builtin::List(_, elem) => build_list_eq( + env, + layout_ids, + &Layout::Builtin(builtin.clone()), + elem, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + when_recursive, + ), + Builtin::Set(_elem) => todo!("equality on Set"), + Builtin::Dict(_key, _value) => todo!("equality on Dict"), + + // empty structures are always equal to themselves + Builtin::EmptyStr => env.context.bool_type().const_int(1, false).into(), + Builtin::EmptyList => env.context.bool_type().const_int(1, false).into(), + Builtin::EmptyDict => env.context.bool_type().const_int(1, false).into(), + Builtin::EmptySet => env.context.bool_type().const_int(1, false).into(), + } +} + +fn build_eq<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + lhs_layout: &Layout<'a>, + rhs_layout: &Layout<'a>, + when_recursive: WhenRecursive<'a>, +) -> BasicValueEnum<'ctx> { + if lhs_layout != rhs_layout { + panic!( + "Equality of different layouts; did you have a type mismatch?\n{:?} == {:?}", + lhs_layout, rhs_layout + ); + } + + match lhs_layout { + Layout::Builtin(builtin) => { + build_eq_builtin(env, layout_ids, lhs_val, rhs_val, builtin, when_recursive) } - (other1, other2) => { - // TODO NOTE: This should ultimately have a _ => todo!("type mismatch!") branch - todo!("implement equals for layouts {:?} == {:?}", other1, other2); + + Layout::Struct(fields) => build_struct_eq( + env, + layout_ids, + fields, + when_recursive, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + ), + + Layout::Union(union_layout) => build_tag_eq( + env, + layout_ids, + when_recursive, + lhs_layout, + union_layout, + lhs_val, + rhs_val, + ), + + Layout::PhantomEmptyStruct => { + // always equal to itself + env.context.bool_type().const_int(1, false).into() + } + + Layout::RecursivePointer => match when_recursive { + WhenRecursive::Unreachable => { + unreachable!("recursion pointers should never be compared directly") + } + + WhenRecursive::Loop(union_layout) => { + let layout = Layout::Union(union_layout.clone()); + + let bt = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); + + // cast the i64 pointer to a pointer to block of memory + let field1_cast = env + .builder + .build_bitcast(lhs_val, bt, "i64_to_opaque") + .into_pointer_value(); + + let field2_cast = env + .builder + .build_bitcast(rhs_val, bt, "i64_to_opaque") + .into_pointer_value(); + + build_tag_eq( + env, + layout_ids, + WhenRecursive::Loop(union_layout.clone()), + &layout, + &union_layout, + field1_cast.into(), + field2_cast.into(), + ) + } + }, + + Layout::Pointer(_) => { + unreachable!("unused") + } + + Layout::FunctionPointer(_, _) | Layout::Closure(_, _, _) => { + unreachable!("the type system will guarantee these are never compared") } } } -pub fn build_neq<'a, 'ctx, 'env>( +fn build_neq_builtin<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: &Builtin<'a>, + when_recursive: WhenRecursive<'a>, +) -> BasicValueEnum<'ctx> { + let int_cmp = |pred, label| { + let int_val = env.builder.build_int_compare( + pred, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + label, + ); + + BasicValueEnum::IntValue(int_val) + }; + + let float_cmp = |pred, label| { + let int_val = env.builder.build_float_compare( + pred, + lhs_val.into_float_value(), + rhs_val.into_float_value(), + label, + ); + + BasicValueEnum::IntValue(int_val) + }; + + match builtin { + Builtin::Int128 => int_cmp(IntPredicate::NE, "neq_i128"), + Builtin::Int64 => int_cmp(IntPredicate::NE, "neq_i64"), + Builtin::Int32 => int_cmp(IntPredicate::NE, "neq_i32"), + Builtin::Int16 => int_cmp(IntPredicate::NE, "neq_i16"), + Builtin::Int8 => int_cmp(IntPredicate::NE, "neq_i8"), + Builtin::Int1 => int_cmp(IntPredicate::NE, "neq_i1"), + + Builtin::Usize => int_cmp(IntPredicate::NE, "neq_usize"), + + Builtin::Float128 => float_cmp(FloatPredicate::ONE, "neq_f128"), + Builtin::Float64 => float_cmp(FloatPredicate::ONE, "neq_f64"), + Builtin::Float32 => float_cmp(FloatPredicate::ONE, "neq_f32"), + Builtin::Float16 => float_cmp(FloatPredicate::ONE, "neq_f16"), + + Builtin::Str => { + let is_equal = str_equal(env, lhs_val, rhs_val).into_int_value(); + let result: IntValue = env.builder.build_not(is_equal, "negate"); + + result.into() + } + Builtin::List(_, elem) => { + let is_equal = build_list_eq( + env, + layout_ids, + &Layout::Builtin(builtin.clone()), + elem, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + when_recursive, + ) + .into_int_value(); + + let result: IntValue = env.builder.build_not(is_equal, "negate"); + + result.into() + } + Builtin::Set(_elem) => todo!("equality on Set"), + Builtin::Dict(_key, _value) => todo!("equality on Dict"), + + // empty structures are always equal to themselves + Builtin::EmptyStr => env.context.bool_type().const_int(0, false).into(), + Builtin::EmptyList => env.context.bool_type().const_int(0, false).into(), + Builtin::EmptyDict => env.context.bool_type().const_int(0, false).into(), + Builtin::EmptySet => env.context.bool_type().const_int(0, false).into(), + } +} + +fn build_neq<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, lhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>, lhs_layout: &Layout<'a>, rhs_layout: &Layout<'a>, + when_recursive: WhenRecursive<'a>, ) -> BasicValueEnum<'ctx> { - match (lhs_layout, rhs_layout) { - (Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin)) => { - let int_cmp = |pred, label| { - let int_val = env.builder.build_int_compare( - pred, - lhs_val.into_int_value(), - rhs_val.into_int_value(), - label, - ); + if lhs_layout != rhs_layout { + panic!( + "Inequality of different layouts; did you have a type mismatch?\n{:?} != {:?}", + lhs_layout, rhs_layout + ); + } - BasicValueEnum::IntValue(int_val) - }; - - let float_cmp = |pred, label| { - let int_val = env.builder.build_float_compare( - pred, - lhs_val.into_float_value(), - rhs_val.into_float_value(), - label, - ); - - BasicValueEnum::IntValue(int_val) - }; - - match (lhs_builtin, rhs_builtin) { - (Builtin::Int128, Builtin::Int128) => int_cmp(IntPredicate::NE, "neq_i128"), - (Builtin::Int64, Builtin::Int64) => int_cmp(IntPredicate::NE, "neq_i64"), - (Builtin::Int32, Builtin::Int32) => int_cmp(IntPredicate::NE, "neq_i32"), - (Builtin::Int16, Builtin::Int16) => int_cmp(IntPredicate::NE, "neq_i16"), - (Builtin::Int8, Builtin::Int8) => int_cmp(IntPredicate::NE, "neq_i8"), - (Builtin::Int1, Builtin::Int1) => int_cmp(IntPredicate::NE, "neq_i1"), - (Builtin::Float64, Builtin::Float64) => float_cmp(FloatPredicate::ONE, "neq_f64"), - (Builtin::Float32, Builtin::Float32) => float_cmp(FloatPredicate::ONE, "neq_f32"), - (Builtin::Str, Builtin::Str) => { - let is_equal = str_equal(env, lhs_val, rhs_val).into_int_value(); - let result: IntValue = env.builder.build_not(is_equal, "negate"); - - result.into() - } - (Builtin::EmptyList, Builtin::EmptyList) => { - env.context.bool_type().const_int(0, false).into() - } - (Builtin::List(_, _), Builtin::EmptyList) - | (Builtin::EmptyList, Builtin::List(_, _)) => { - unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout") - } - (Builtin::List(_, elem1), Builtin::List(_, elem2)) => { - debug_assert_eq!(elem1, elem2); - - let equal = build_list_eq( - env, - layout_ids, - lhs_layout, - elem1, - lhs_val.into_struct_value(), - rhs_val.into_struct_value(), - ); - - let not_equal: IntValue = env.builder.build_not(equal.into_int_value(), "not"); - - not_equal.into() - } - (b1, b2) => { - todo!("Handle not equals for builtin layouts {:?} == {:?}", b1, b2); - } - } + match lhs_layout { + Layout::Builtin(builtin) => { + build_neq_builtin(env, layout_ids, lhs_val, rhs_val, builtin, when_recursive) } - (other1, other2) => { - // TODO NOTE: This should ultimately have a _ => todo!("type mismatch!") branch - todo!( - "implement not equals for layouts {:?} == {:?}", - other1, - other2 - ); + + Layout::Struct(fields) => { + let is_equal = build_struct_eq( + env, + layout_ids, + fields, + when_recursive, + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + ) + .into_int_value(); + + let result: IntValue = env.builder.build_not(is_equal, "negate"); + + result.into() + } + Layout::Union(union_layout) => { + let is_equal = build_tag_eq( + env, + layout_ids, + when_recursive, + lhs_layout, + union_layout, + lhs_val, + rhs_val, + ) + .into_int_value(); + + let result: IntValue = env.builder.build_not(is_equal, "negate"); + + result.into() + } + + Layout::PhantomEmptyStruct => { + // always equal to itself + env.context.bool_type().const_int(1, false).into() + } + + Layout::RecursivePointer => { + unreachable!("recursion pointers should never be compared directly") + } + + Layout::Pointer(_) => { + unreachable!("unused") + } + + Layout::FunctionPointer(_, _) | Layout::Closure(_, _, _) => { + unreachable!("the type system will guarantee these are never compared") } } } @@ -174,7 +364,9 @@ fn build_list_eq<'a, 'ctx, 'env>( element_layout: &Layout<'a>, list1: StructValue<'ctx>, list2: StructValue<'ctx>, + when_recursive: WhenRecursive<'a>, ) -> BasicValueEnum<'ctx> { + dbg!("list", &when_recursive); let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); @@ -196,7 +388,13 @@ fn build_list_eq<'a, 'ctx, 'env>( &[arg_type, arg_type], ); - build_list_eq_help(env, layout_ids, function_value, element_layout); + build_list_eq_help( + env, + layout_ids, + when_recursive, + function_value, + element_layout, + ); function_value } @@ -217,6 +415,7 @@ fn build_list_eq<'a, 'ctx, 'env>( fn build_list_eq_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + when_recursive: WhenRecursive<'a>, parent: FunctionValue<'ctx>, element_layout: &Layout<'a>, ) { @@ -331,6 +530,7 @@ fn build_list_eq_help<'a, 'ctx, 'env>( elem2, element_layout, element_layout, + when_recursive, ) .into_int_value(); @@ -366,3 +566,707 @@ fn build_list_eq_help<'a, 'ctx, 'env>( .build_return(Some(&env.context.bool_type().const_int(0, false))); } } + +fn build_struct_eq<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + field_layouts: &'a [Layout<'a>], + when_recursive: WhenRecursive<'a>, + struct1: StructValue<'ctx>, + struct2: StructValue<'ctx>, +) -> BasicValueEnum<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let struct_layout = Layout::Struct(field_layouts); + + let symbol = Symbol::GENERIC_EQ; + let fn_name = layout_ids + .get(symbol, &struct_layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arena = env.arena; + let arg_type = + basic_type_from_layout(arena, env.context, &struct_layout, env.ptr_bytes); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.bool_type().into(), + &[arg_type, arg_type], + ); + + build_struct_eq_help( + env, + layout_ids, + function_value, + when_recursive, + field_layouts, + ); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + let call = env + .builder + .build_call(function, &[struct1.into(), struct2.into()], "struct_eq"); + + call.set_call_convention(FAST_CALL_CONV); + + call.try_as_basic_value().left().unwrap() +} + +fn build_struct_eq_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + parent: FunctionValue<'ctx>, + when_recursive: WhenRecursive<'a>, + field_layouts: &[Layout<'a>], +) { + let ctx = env.context; + let builder = env.builder; + + { + use inkwell::debug_info::AsDIScope; + + let func_scope = parent.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + ctx, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(&ctx, loc); + } + + // Add args to scope + let mut it = parent.get_param_iter(); + let struct1 = it.next().unwrap().into_struct_value(); + let struct2 = it.next().unwrap().into_struct_value(); + + set_name(struct1.into(), Symbol::ARG_1.ident_string(&env.interns)); + set_name(struct2.into(), Symbol::ARG_2.ident_string(&env.interns)); + + let entry = ctx.append_basic_block(parent, "entry"); + let start = ctx.append_basic_block(parent, "start"); + env.builder.position_at_end(entry); + env.builder.build_unconditional_branch(start); + + let return_true = ctx.append_basic_block(parent, "return_true"); + let return_false = ctx.append_basic_block(parent, "return_false"); + + let mut current = start; + + for (index, field_layout) in field_layouts.iter().enumerate() { + env.builder.position_at_end(current); + + let field1 = env + .builder + .build_extract_value(struct1, index as u32, "eq_field") + .unwrap(); + + let field2 = env + .builder + .build_extract_value(struct2, index as u32, "eq_field") + .unwrap(); + + let are_equal = if let Layout::RecursivePointer = field_layout { + match &when_recursive { + WhenRecursive::Unreachable => { + unreachable!("The current layout should not be recursive, but is") + } + WhenRecursive::Loop(union_layout) => { + let field_layout = Layout::Union(union_layout.clone()); + + let bt = basic_type_from_layout( + env.arena, + env.context, + &field_layout, + env.ptr_bytes, + ); + + // cast the i64 pointer to a pointer to block of memory + let field1_cast = env + .builder + .build_bitcast(field1, bt, "i64_to_opaque") + .into_pointer_value(); + + let field2_cast = env + .builder + .build_bitcast(field2, bt, "i64_to_opaque") + .into_pointer_value(); + + build_eq( + env, + layout_ids, + field1_cast.into(), + field2_cast.into(), + &field_layout, + &field_layout, + WhenRecursive::Loop(union_layout.clone()), + ) + .into_int_value() + } + } + } else { + build_eq( + env, + layout_ids, + field1, + field2, + field_layout, + field_layout, + when_recursive.clone(), + ) + .into_int_value() + }; + + current = ctx.append_basic_block(parent, &format!("eq_step_{}", index)); + + env.builder + .build_conditional_branch(are_equal, current, return_false); + } + + env.builder.position_at_end(current); + env.builder.build_unconditional_branch(return_true); + + { + env.builder.position_at_end(return_true); + env.builder + .build_return(Some(&env.context.bool_type().const_int(1, false))); + } + + { + env.builder.position_at_end(return_false); + env.builder + .build_return(Some(&env.context.bool_type().const_int(0, false))); + } +} + +fn build_tag_eq<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + when_recursive: WhenRecursive<'a>, + tag_layout: &Layout<'a>, + union_layout: &UnionLayout<'a>, + tag1: BasicValueEnum<'ctx>, + tag2: BasicValueEnum<'ctx>, +) -> BasicValueEnum<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::GENERIC_EQ; + let fn_name = layout_ids + .get(symbol, &tag_layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arena = env.arena; + let arg_type = basic_type_from_layout(arena, env.context, &tag_layout, env.ptr_bytes); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.bool_type().into(), + &[arg_type, arg_type], + ); + + build_tag_eq_help( + env, + layout_ids, + when_recursive, + function_value, + union_layout, + ); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + let call = env.builder.build_call(function, &[tag1, tag2], "tag_eq"); + + call.set_call_convention(FAST_CALL_CONV); + + call.try_as_basic_value().left().unwrap() +} + +fn build_tag_eq_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + when_recursive: WhenRecursive<'a>, + parent: FunctionValue<'ctx>, + union_layout: &UnionLayout<'a>, +) { + let ctx = env.context; + let builder = env.builder; + + { + use inkwell::debug_info::AsDIScope; + + let func_scope = parent.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + ctx, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(&ctx, loc); + } + + // Add args to scope + let mut it = parent.get_param_iter(); + let tag1 = it.next().unwrap(); + let tag2 = it.next().unwrap(); + + set_name(tag1, Symbol::ARG_1.ident_string(&env.interns)); + set_name(tag2, Symbol::ARG_2.ident_string(&env.interns)); + + let entry = ctx.append_basic_block(parent, "entry"); + + let return_true = ctx.append_basic_block(parent, "return_true"); + let return_false = ctx.append_basic_block(parent, "return_false"); + + { + env.builder.position_at_end(return_false); + env.builder + .build_return(Some(&env.context.bool_type().const_int(0, false))); + } + + { + env.builder.position_at_end(return_true); + env.builder + .build_return(Some(&env.context.bool_type().const_int(1, false))); + } + + env.builder.position_at_end(entry); + + use UnionLayout::*; + + match union_layout { + NonRecursive(tags) => { + // SAFETY we know that non-recursive tags cannot be NULL + let id1 = nonrec_tag_id(env, tag1.into_struct_value()); + let id2 = nonrec_tag_id(env, tag2.into_struct_value()); + + let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields"); + + let same_tag = + env.builder + .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id"); + + env.builder + .build_conditional_branch(same_tag, compare_tag_fields, return_false); + + env.builder.position_at_end(compare_tag_fields); + + // switch on all the tag ids + + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + let block = env.context.append_basic_block(parent, "tag_id_modify"); + env.builder.position_at_end(block); + + // TODO drop tag id? + let struct_layout = Layout::Struct(field_layouts); + + let wrapper_type = + basic_type_from_layout(env.arena, env.context, &struct_layout, env.ptr_bytes); + debug_assert!(wrapper_type.is_struct_type()); + + let struct1 = cast_block_of_memory_to_tag( + env.builder, + tag1.into_struct_value(), + wrapper_type, + ); + let struct2 = cast_block_of_memory_to_tag( + env.builder, + tag2.into_struct_value(), + wrapper_type, + ); + + let answer = build_struct_eq( + env, + layout_ids, + field_layouts, + when_recursive.clone(), + struct1, + struct2, + ); + + env.builder.build_return(Some(&answer)); + + cases.push(( + env.context.i64_type().const_int(tag_id as u64, false), + block, + )); + } + + env.builder.position_at_end(compare_tag_fields); + + let default = cases.pop().unwrap().1; + + env.builder.build_switch(id1, default, &cases); + } + Recursive(tags) => { + let ptr_equal = env.builder.build_int_compare( + IntPredicate::EQ, + env.builder + .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"), + env.builder + .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"), + "compare_pointers", + ); + + let compare_tag_ids = ctx.append_basic_block(parent, "compare_tag_ids"); + + env.builder + .build_conditional_branch(ptr_equal, return_true, compare_tag_ids); + + env.builder.position_at_end(compare_tag_ids); + + // SAFETY we know that non-recursive tags cannot be NULL + let id1 = unsafe { rec_tag_id_unsafe(env, tag1.into_pointer_value()) }; + let id2 = unsafe { rec_tag_id_unsafe(env, tag2.into_pointer_value()) }; + + let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields"); + + let same_tag = + env.builder + .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id"); + + env.builder + .build_conditional_branch(same_tag, compare_tag_fields, return_false); + + env.builder.position_at_end(compare_tag_fields); + + // switch on all the tag ids + + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + let block = env.context.append_basic_block(parent, "tag_id_modify"); + env.builder.position_at_end(block); + + let answer = eq_ptr_to_struct( + env, + layout_ids, + union_layout, + field_layouts, + tag1.into_pointer_value(), + tag2.into_pointer_value(), + ); + + env.builder.build_return(Some(&answer)); + + cases.push(( + env.context.i64_type().const_int(tag_id as u64, false), + block, + )); + } + + env.builder.position_at_end(compare_tag_fields); + + let default = cases.pop().unwrap().1; + + env.builder.build_switch(id1, default, &cases); + } + NullableUnwrapped { other_fields, .. } => { + // drop the tag id; it is not stored + let other_fields = &other_fields[1..]; + + let ptr_equal = env.builder.build_int_compare( + IntPredicate::EQ, + env.builder + .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"), + env.builder + .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"), + "compare_pointers", + ); + + let check_for_null = ctx.append_basic_block(parent, "check_for_null"); + let compare_other = ctx.append_basic_block(parent, "compare_other"); + + env.builder + .build_conditional_branch(ptr_equal, return_true, check_for_null); + + // check for NULL + + env.builder.position_at_end(check_for_null); + + let is_null_1 = env + .builder + .build_is_null(tag1.into_pointer_value(), "is_null"); + + let is_null_2 = env + .builder + .build_is_null(tag2.into_pointer_value(), "is_null"); + + let either_null = env.builder.build_or(is_null_1, is_null_2, "either_null"); + + // logic: the pointers are not the same, if one is NULL, the other one is not + // therefore the two tags are not equal + env.builder + .build_conditional_branch(either_null, return_false, compare_other); + + // compare the non-null case + + env.builder.position_at_end(compare_other); + + let answer = eq_ptr_to_struct( + env, + layout_ids, + union_layout, + other_fields, + tag1.into_pointer_value(), + tag2.into_pointer_value(), + ); + + env.builder.build_return(Some(&answer)); + } + NullableWrapped { other_tags, .. } => { + let ptr_equal = env.builder.build_int_compare( + IntPredicate::EQ, + env.builder + .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"), + env.builder + .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"), + "compare_pointers", + ); + + let check_for_null = ctx.append_basic_block(parent, "check_for_null"); + let compare_other = ctx.append_basic_block(parent, "compare_other"); + + env.builder + .build_conditional_branch(ptr_equal, return_true, check_for_null); + + // check for NULL + + env.builder.position_at_end(check_for_null); + + let is_null_1 = env + .builder + .build_is_null(tag1.into_pointer_value(), "is_null"); + + let is_null_2 = env + .builder + .build_is_null(tag2.into_pointer_value(), "is_null"); + + // Logic: + // + // NULL and NULL => equal + // NULL and not => not equal + // not and NULL => not equal + // not and not => more work required + + let i8_type = env.context.i8_type(); + + let sum = env.builder.build_int_add( + env.builder.build_int_cast(is_null_1, i8_type, "to_u8"), + env.builder.build_int_cast(is_null_2, i8_type, "to_u8"), + "sum_is_null", + ); + + env.builder.build_switch( + sum, + compare_other, + &[ + (i8_type.const_int(2, false), return_true), + (i8_type.const_int(1, false), return_false), + ], + ); + + // compare the non-null case + + env.builder.position_at_end(compare_other); + + // SAFETY we know at this point that tag1/tag2 are not NULL + let id1 = unsafe { rec_tag_id_unsafe(env, tag1.into_pointer_value()) }; + let id2 = unsafe { rec_tag_id_unsafe(env, tag2.into_pointer_value()) }; + + let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields"); + + let same_tag = + env.builder + .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id"); + + env.builder + .build_conditional_branch(same_tag, compare_tag_fields, return_false); + + env.builder.position_at_end(compare_tag_fields); + + // switch on all the tag ids + + let tags = other_tags; + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + let block = env.context.append_basic_block(parent, "tag_id_modify"); + env.builder.position_at_end(block); + + let answer = eq_ptr_to_struct( + env, + layout_ids, + union_layout, + field_layouts, + tag1.into_pointer_value(), + tag2.into_pointer_value(), + ); + + env.builder.build_return(Some(&answer)); + + cases.push(( + env.context.i64_type().const_int(tag_id as u64, false), + block, + )); + } + + env.builder.position_at_end(compare_tag_fields); + + let default = cases.pop().unwrap().1; + + env.builder.build_switch(id1, default, &cases); + } + NonNullableUnwrapped(field_layouts) => { + let ptr_equal = env.builder.build_int_compare( + IntPredicate::EQ, + env.builder + .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"), + env.builder + .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"), + "compare_pointers", + ); + + let compare_fields = ctx.append_basic_block(parent, "compare_fields"); + + env.builder + .build_conditional_branch(ptr_equal, return_true, compare_fields); + + env.builder.position_at_end(compare_fields); + + let answer = eq_ptr_to_struct( + env, + layout_ids, + union_layout, + field_layouts, + tag1.into_pointer_value(), + tag2.into_pointer_value(), + ); + + env.builder.build_return(Some(&answer)); + } + } +} + +fn eq_ptr_to_struct<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + union_layout: &UnionLayout<'a>, + field_layouts: &'a [Layout<'a>], + tag1: PointerValue<'ctx>, + tag2: PointerValue<'ctx>, +) -> IntValue<'ctx> { + use inkwell::types::BasicType; + + let struct_layout = Layout::Struct(field_layouts); + + let wrapper_type = + basic_type_from_layout(env.arena, env.context, &struct_layout, env.ptr_bytes); + debug_assert!(wrapper_type.is_struct_type()); + + // cast the opaque pointer to a pointer of the correct shape + let struct1_ptr = env + .builder + .build_bitcast( + tag1, + wrapper_type.ptr_type(AddressSpace::Generic), + "opaque_to_correct", + ) + .into_pointer_value(); + + let struct2_ptr = env + .builder + .build_bitcast( + tag2, + wrapper_type.ptr_type(AddressSpace::Generic), + "opaque_to_correct", + ) + .into_pointer_value(); + + let struct1 = env + .builder + .build_load(struct1_ptr, "load_struct1") + .into_struct_value(); + + let struct2 = env + .builder + .build_load(struct2_ptr, "load_struct2") + .into_struct_value(); + + build_struct_eq( + env, + layout_ids, + field_layouts, + WhenRecursive::Loop(union_layout.clone()), + struct1, + struct2, + ) + .into_int_value() +} + +fn nonrec_tag_id<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + tag: StructValue<'ctx>, +) -> IntValue<'ctx> { + complex_bitcast( + env.builder, + tag.into(), + env.context.i64_type().into(), + "load_tag_id", + ) + .into_int_value() +} + +unsafe fn rec_tag_id_unsafe<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + tag: PointerValue<'ctx>, +) -> IntValue<'ctx> { + let ptr = env + .builder + .build_bitcast( + tag, + env.context.i64_type().ptr_type(AddressSpace::Generic), + "cast_for_tag_id", + ) + .into_pointer_value(); + + env.builder.build_load(ptr, "load_tag_id").into_int_value() +} diff --git a/compiler/gen/tests/gen_compare.rs b/compiler/gen/tests/gen_compare.rs new file mode 100644 index 0000000000..2a119d4869 --- /dev/null +++ b/compiler/gen/tests/gen_compare.rs @@ -0,0 +1,419 @@ +#[macro_use] +extern crate pretty_assertions; +#[macro_use] +extern crate indoc; + +extern crate bumpalo; +extern crate inkwell; +extern crate libc; +extern crate roc_gen; + +#[macro_use] +mod helpers; + +#[cfg(test)] +mod gen_num { + + #[test] + fn eq_i64() { + assert_evals_to!( + indoc!( + r#" + i : I64 + i = 1 + + i == i + "# + ), + true, + bool + ); + } + + #[test] + fn neq_i64() { + assert_evals_to!( + indoc!( + r#" + i : I64 + i = 1 + + i != i + "# + ), + false, + bool + ); + } + + #[test] + fn eq_u64() { + assert_evals_to!( + indoc!( + r#" + i : U64 + i = 1 + + i == i + "# + ), + true, + bool + ); + } + + #[test] + fn neq_u64() { + assert_evals_to!( + indoc!( + r#" + i : U64 + i = 1 + + i != i + "# + ), + false, + bool + ); + } + + #[test] + fn eq_f64() { + assert_evals_to!( + indoc!( + r#" + i : F64 + i = 1 + + i == i + "# + ), + true, + bool + ); + } + + #[test] + fn neq_f64() { + assert_evals_to!( + indoc!( + r#" + i : F64 + i = 1 + + i != i + "# + ), + false, + bool + ); + } + + #[test] + fn eq_bool_tag() { + assert_evals_to!( + indoc!( + r#" + true : Bool + true = True + + true == True + "# + ), + true, + bool + ); + } + + #[test] + fn neq_bool_tag() { + assert_evals_to!( + indoc!( + r#" + true : Bool + true = True + + true == False + "# + ), + false, + bool + ); + } + + #[test] + fn empty_record() { + assert_evals_to!("{} == {}", true, bool); + assert_evals_to!("{} != {}", false, bool); + } + + #[test] + fn unit() { + assert_evals_to!("Unit == Unit", true, bool); + assert_evals_to!("Unit != Unit", false, bool); + } + + #[test] + fn newtype() { + assert_evals_to!("Identity 42 == Identity 42", true, bool); + assert_evals_to!("Identity 42 != Identity 42", false, bool); + } + + #[test] + fn small_str() { + assert_evals_to!("\"aaa\" == \"aaa\"", true, bool); + assert_evals_to!("\"aaa\" == \"bbb\"", false, bool); + assert_evals_to!("\"aaa\" != \"aaa\"", false, bool); + } + + #[test] + fn large_str() { + assert_evals_to!( + indoc!( + r#" + x = "Unicode can represent text values which span multiple languages" + y = "Unicode can represent text values which span multiple languages" + + x == y + "# + ), + true, + bool + ); + + assert_evals_to!( + indoc!( + r#" + x = "Unicode can represent text values which span multiple languages" + y = "Here are some valid Roc strings" + + x != y + "# + ), + true, + bool + ); + } + + #[test] + fn eq_result_tag_true() { + assert_evals_to!( + indoc!( + r#" + x : Result I64 I64 + x = Ok 1 + + y : Result I64 I64 + y = Ok 1 + + x == y + "# + ), + true, + bool + ); + } + + #[test] + fn eq_result_tag_false() { + assert_evals_to!( + indoc!( + r#" + x : Result I64 I64 + x = Ok 1 + + y : Result I64 I64 + y = Err 1 + + x == y + "# + ), + false, + bool + ); + } + + #[test] + fn eq_expr() { + assert_evals_to!( + indoc!( + r#" + Expr : [ Add Expr Expr, Mul Expr Expr, Val I64, Var I64 ] + + x : Expr + x = Val 0 + + y : Expr + y = Val 0 + + x == y + "# + ), + true, + bool + ); + } + + #[test] + fn eq_linked_list() { + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + x : LinkedList I64 + x = Nil + + y : LinkedList I64 + y = Nil + + x == y + "# + ), + true, + bool + ); + + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + x : LinkedList I64 + x = Cons 1 Nil + + y : LinkedList I64 + y = Cons 1 Nil + + x == y + "# + ), + true, + bool + ); + + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + x : LinkedList I64 + x = Cons 1 (Cons 2 Nil) + + y : LinkedList I64 + y = Cons 1 (Cons 2 Nil) + + x == y + "# + ), + true, + bool + ); + } + + #[test] + fn eq_linked_list_false() { + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + x : LinkedList I64 + x = Cons 1 Nil + + y : LinkedList I64 + y = Cons 1 (Cons 2 Nil) + + y == x + "# + ), + false, + bool + ); + } + + #[test] + fn eq_nullable_expr() { + assert_evals_to!( + indoc!( + r#" + Expr : [ Add Expr Expr, Mul Expr Expr, Val I64, Empty ] + + x : Expr + x = Val 0 + + y : Expr + y = Add x x + + x != y + "# + ), + true, + bool + ); + } + + #[test] + fn eq_rosetree() { + assert_evals_to!( + indoc!( + r#" + Rose a : [ Rose (List (Rose a)) ] + + x : Rose I64 + x = Rose [] + + y : Rose I64 + y = Rose [] + + x == y + "# + ), + true, + bool + ); + + assert_evals_to!( + indoc!( + r#" + Rose a : [ Rose (List (Rose a)) ] + + x : Rose I64 + x = Rose [] + + y : Rose I64 + y = Rose [] + + x != y + "# + ), + false, + bool + ); + } + + #[test] + #[ignore] + fn rosetree_with_tag() { + // currently stack overflows in type checking + + assert_evals_to!( + indoc!( + r#" + Rose a : [ Rose (Result (List (Rose a)) I64) ] + + x : Rose I64 + x = (Rose (Ok [])) + + y : Rose I64 + y = (Rose (Ok [])) + + x == y + "# + ), + true, + bool + ); + } +} diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 4f3049346c..0a745dd6e3 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -741,10 +741,14 @@ define_builtins! { 11 DEC: "#dec" // internal function that increments the refcount 12 ARG_CLOSURE: "#arg_closure" // symbol used to store the closure record 13 LIST_EQ: "#list_eq" // internal function that checks list equality + 14 GENERIC_HASH: "#generic_hash" // hash of arbitrary layouts 15 GENERIC_HASH_REF: "#generic_hash_by_ref" // hash of arbitrary layouts, passed as an opaque pointer + 16 GENERIC_EQ_REF: "#generic_eq_by_ref" // equality of arbitrary layouts, passed as an opaque pointer 17 GENERIC_RC_REF: "#generic_rc_by_ref" // refcount of arbitrary layouts, passed as an opaque pointer + + 18 GENERIC_EQ: "#generic_eq" // internal function that checks generic equality } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 54b1c16798..7b68fc593b 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -19,6 +19,28 @@ use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder}; pub const PRETTY_PRINT_IR_SYMBOLS: bool = false; +macro_rules! return_on_layout_error { + ($env:expr, $layout_result:expr) => { + match $layout_result { + Ok(cached) => cached, + Err(LayoutProblem::UnresolvedTypeVar(_)) => { + return Stmt::RuntimeError($env.arena.alloc(format!( + "UnresolvedTypeVar {} line {}", + file!(), + line!() + ))); + } + Err(LayoutProblem::Erroneous) => { + return Stmt::RuntimeError($env.arena.alloc(format!( + "Erroneous {} line {}", + file!(), + line!() + ))); + } + } + }; +} + #[derive(Clone, Debug, PartialEq)] pub enum MonoProblem { PatternProblem(crate::exhaustive::Error), @@ -3790,11 +3812,10 @@ pub fn with_hole<'a>( ) .into_bump_slice(); - let full_layout = layout_cache - .from_var(env.arena, fn_var, env.subs) - .unwrap_or_else(|err| { - panic!("TODO turn fn_var into a RuntimeError {:?}", err) - }); + let full_layout = return_on_layout_error!( + env, + layout_cache.from_var(env.arena, fn_var, env.subs) + ); let arg_layouts = match full_layout { Layout::FunctionPointer(args, _) => args, @@ -3802,11 +3823,10 @@ pub fn with_hole<'a>( _ => unreachable!("function has layout that is not function pointer"), }; - let ret_layout = layout_cache - .from_var(env.arena, ret_var, env.subs) - .unwrap_or_else(|err| { - panic!("TODO turn fn_var into a RuntimeError {:?}", err) - }); + let ret_layout = return_on_layout_error!( + env, + layout_cache.from_var(env.arena, ret_var, env.subs) + ); // if the function expression (loc_expr) is already a symbol, // re-use that symbol, and don't define its value again