diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 046b00886e..5b99c62223 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, - list_map2, list_map3, list_map_new, list_map_with_index, list_prepend, list_range, list_repeat, - list_reverse, list_set, list_single, list_sort_with, list_walk_help, + list_map2, list_map3, list_map_with_index, list_prepend, list_range, list_repeat, list_reverse, + list_set, list_single, list_sort_with, list_walk_help, }; use crate::llvm::build_str::{ empty_str, str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, @@ -3716,7 +3716,7 @@ fn run_low_level<'a, 'ctx, 'env>( match list_layout { Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(_, element_layout)) => list_map_new( + Layout::Builtin(Builtin::List(_, element_layout)) => list_map( env, layout_ids, function, @@ -3730,12 +3730,13 @@ fn run_low_level<'a, 'ctx, 'env>( } } ListMap2 => { - debug_assert_eq!(args.len(), 3); + debug_assert_eq!(args.len(), 4); let (list1, list1_layout) = load_symbol_and_layout(scope, &args[0]); let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[2]); + let (function_layout, function) = scope.function_pointers[&args[2]]; + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[3]); match (list1_layout, list2_layout) { ( @@ -3744,8 +3745,10 @@ fn run_low_level<'a, 'ctx, 'env>( ) => list_map2( env, layout_ids, - func, - func_layout, + function, + function_layout, + closure, + *closure_layout, list1, list2, element1_layout, @@ -3757,13 +3760,14 @@ fn run_low_level<'a, 'ctx, 'env>( } } ListMap3 => { - debug_assert_eq!(args.len(), 4); + debug_assert_eq!(args.len(), 5); let (list1, list1_layout) = load_symbol_and_layout(scope, &args[0]); let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]); let (list3, list3_layout) = load_symbol_and_layout(scope, &args[2]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[3]); + let (function_layout, function) = scope.function_pointers[&args[3]]; + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[4]); match (list1_layout, list2_layout, list3_layout) { ( @@ -3773,8 +3777,10 @@ fn run_low_level<'a, 'ctx, 'env>( ) => list_map3( env, layout_ids, - func, - func_layout, + function, + function_layout, + closure, + *closure_layout, list1, list2, list3, @@ -3789,44 +3795,64 @@ fn run_low_level<'a, 'ctx, 'env>( } } ListMapWithIndex => { - // List.map : List before, (before -> after) -> List after - debug_assert_eq!(args.len(), 2); + // List.mapWithIndex : List before, (Nat, before -> after) -> List after + debug_assert_eq!(args.len(), 3); let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + let (function_layout, function) = scope.function_pointers[&args[1]]; + + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); match list_layout { Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(_, element_layout)) => { - list_map_with_index(env, layout_ids, func, func_layout, list, element_layout) - } + Layout::Builtin(Builtin::List(_, element_layout)) => list_map_with_index( + env, + layout_ids, + function, + function_layout, + closure, + *closure_layout, + list, + element_layout, + ), _ => unreachable!("invalid list layout"), } } ListKeepIf => { // List.keepIf : List elem, (elem -> Bool) -> List elem - debug_assert_eq!(args.len(), 2); + debug_assert_eq!(args.len(), 3); let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + let (function_layout, function) = scope.function_pointers[&args[1]]; + + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); match list_layout { Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(_, element_layout)) => { - list_keep_if(env, layout_ids, func, func_layout, list, element_layout) - } + Layout::Builtin(Builtin::List(_, element_layout)) => list_keep_if( + env, + layout_ids, + function, + function_layout, + closure, + *closure_layout, + list, + element_layout, + ), _ => unreachable!("invalid list layout"), } } ListKeepOks => { // List.keepOks : List before, (before -> Result after *) -> List after - debug_assert_eq!(args.len(), 2); + debug_assert_eq!(args.len(), 3); let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + let (function_layout, function) = scope.function_pointers[&args[1]]; + + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); match (list_layout, layout) { (_, Layout::Builtin(Builtin::EmptyList)) @@ -3837,8 +3863,10 @@ fn run_low_level<'a, 'ctx, 'env>( ) => list_keep_oks( env, layout_ids, - func, - func_layout, + function, + function_layout, + closure, + *closure_layout, list, before_layout, after_layout, @@ -3854,7 +3882,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + let (function_layout, function) = scope.function_pointers[&args[1]]; + + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); match (list_layout, layout) { (_, Layout::Builtin(Builtin::EmptyList)) @@ -3865,8 +3895,10 @@ fn run_low_level<'a, 'ctx, 'env>( ) => list_keep_errs( env, layout_ids, - func, - func_layout, + function, + function_layout, + closure, + *closure_layout, list, before_layout, after_layout, diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 549134c120..1354f6af9f 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -652,20 +652,27 @@ pub fn list_contains<'a, 'ctx, 'env>( pub fn list_keep_if<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { let builder = env.builder; - let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); - env.builder.build_store(transform_ptr, transform); + let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); + env.builder.build_store(closure_data_ptr, closure_data); - let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, &[*element_layout]) - .as_global_value() - .as_pointer_value(); + let stepper_caller = build_transform_caller_new( + env, + layout_ids, + transform, + closure_data_layout, + &[*element_layout], + ) + .as_global_value() + .as_pointer_value(); let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout); let dec_element_fn = build_dec_wrapper(env, layout_ids, element_layout); @@ -674,7 +681,7 @@ pub fn list_keep_if<'a, 'ctx, 'env>( env, &[ pass_list_as_i128(env, list), - pass_as_opaque(env, transform_ptr), + pass_as_opaque(env, closure_data_ptr), stepper_caller.into(), alignment_intvalue(env, &element_layout), layout_width(env, element_layout), @@ -689,8 +696,10 @@ pub fn list_keep_if<'a, 'ctx, 'env>( pub fn list_keep_oks<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list: BasicValueEnum<'ctx>, before_layout: &Layout<'a>, after_layout: &Layout<'a>, @@ -700,6 +709,8 @@ pub fn list_keep_oks<'a, 'ctx, 'env>( layout_ids, transform, transform_layout, + closure_data, + closure_data_layout, list, before_layout, after_layout, @@ -711,8 +722,10 @@ pub fn list_keep_oks<'a, 'ctx, 'env>( pub fn list_keep_errs<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list: BasicValueEnum<'ctx>, before_layout: &Layout<'a>, after_layout: &Layout<'a>, @@ -722,6 +735,8 @@ pub fn list_keep_errs<'a, 'ctx, 'env>( layout_ids, transform, transform_layout, + closure_data, + closure_data_layout, list, before_layout, after_layout, @@ -732,8 +747,10 @@ pub fn list_keep_errs<'a, 'ctx, 'env>( pub fn list_keep_result<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list: BasicValueEnum<'ctx>, before_layout: &Layout<'a>, after_layout: &Layout<'a>, @@ -747,13 +764,18 @@ pub fn list_keep_result<'a, 'ctx, 'env>( _ => unreachable!("not a callable layout"), }; - let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); - env.builder.build_store(transform_ptr, transform); + let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); + env.builder.build_store(closure_data_ptr, closure_data); - let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, &[*before_layout]) - .as_global_value() - .as_pointer_value(); + let stepper_caller = build_transform_caller_new( + env, + layout_ids, + transform, + closure_data_layout, + &[*before_layout], + ) + .as_global_value() + .as_pointer_value(); let before_width = env .ptr_int() @@ -767,14 +789,14 @@ pub fn list_keep_result<'a, 'ctx, 'env>( .ptr_int() .const_int(result_layout.stack_size(env.ptr_bytes) as u64, false); - let inc_closure = build_inc_wrapper(env, layout_ids, transform_layout); + let inc_closure = build_inc_wrapper(env, layout_ids, &transform_layout); let dec_result_fn = build_dec_wrapper(env, layout_ids, result_layout); call_bitcode_fn( env, &[ pass_list_as_i128(env, list), - pass_as_opaque(env, transform_ptr), + pass_as_opaque(env, closure_data_ptr), stepper_caller.into(), alignment_intvalue(env, &before_layout), before_width.into(), @@ -814,8 +836,8 @@ pub fn list_sort_with<'a, 'ctx, 'env>( ) } -/// List.map : List before, (before -> after) -> List after -pub fn list_map_new<'a, 'ctx, 'env>( +/// List.mapWithIndex : List before, (Nat, before -> after) -> List after +pub fn list_map_with_index<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, transform: FunctionValue<'ctx>, @@ -825,7 +847,32 @@ pub fn list_map_new<'a, 'ctx, 'env>( list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - list_map_generic_new( + list_map_generic( + env, + layout_ids, + transform, + transform_layout, + list, + element_layout, + closure_data, + closure_data_layout, + bitcode::LIST_MAP_WITH_INDEX, + &[Layout::Builtin(Builtin::Usize), *element_layout], + ) +} + +/// List.map : List before, (before -> after) -> List after +pub fn list_map<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_map_generic( env, layout_ids, transform, @@ -839,7 +886,7 @@ pub fn list_map_new<'a, 'ctx, 'env>( ) } -fn list_map_generic_new<'a, 'ctx, 'env>( +fn list_map_generic<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, transform: FunctionValue<'ctx>, @@ -872,8 +919,6 @@ fn list_map_generic_new<'a, 'ctx, 'env>( .as_global_value() .as_pointer_value(); - dbg!(element_layout, layout_width(env, element_layout)); - call_bitcode_fn_returns_list( env, &[ @@ -888,93 +933,13 @@ fn list_map_generic_new<'a, 'ctx, 'env>( ) } -/// List.map : List before, (before -> after) -> List after -pub fn list_map<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, -) -> BasicValueEnum<'ctx> { - list_map_generic( - env, - layout_ids, - transform, - transform_layout, - list, - element_layout, - bitcode::LIST_MAP, - &[*element_layout], - ) -} - -/// List.mapWithIndex : List before, (Nat, before -> after) -> List after -pub fn list_map_with_index<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, -) -> BasicValueEnum<'ctx> { - list_map_generic( - env, - layout_ids, - transform, - transform_layout, - list, - element_layout, - bitcode::LIST_MAP_WITH_INDEX, - &[Layout::Builtin(Builtin::Usize), *element_layout], - ) -} - -fn list_map_generic<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, - op: &str, - argument_layouts: &[Layout<'a>], -) -> BasicValueEnum<'ctx> { - let builder = env.builder; - - let return_layout = match transform_layout { - Layout::FunctionPointer(_, ret) => ret, - Layout::Closure(_, _, ret) => ret, - _ => unreachable!("not a callable layout"), - }; - - let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); - env.builder.build_store(transform_ptr, transform); - - let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, argument_layouts) - .as_global_value() - .as_pointer_value(); - - call_bitcode_fn_returns_list( - env, - &[ - pass_list_as_i128(env, list), - pass_as_opaque(env, transform_ptr), - stepper_caller.into(), - alignment_intvalue(env, &element_layout), - layout_width(env, element_layout), - layout_width(env, return_layout), - ], - op, - ) -} - pub fn list_map2<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list1: BasicValueEnum<'ctx>, list2: BasicValueEnum<'ctx>, element1_layout: &Layout<'a>, @@ -988,14 +953,18 @@ pub fn list_map2<'a, 'ctx, 'env>( _ => unreachable!("not a callable layout"), }; - let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); - env.builder.build_store(transform_ptr, transform); + let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); + env.builder.build_store(closure_data_ptr, closure_data); - let argument_layouts = [*element1_layout, *element2_layout]; - let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, &argument_layouts) - .as_global_value() - .as_pointer_value(); + let stepper_caller = build_transform_caller_new( + env, + layout_ids, + transform, + closure_data_layout, + &[*element1_layout, *element2_layout], + ) + .as_global_value() + .as_pointer_value(); let a_width = env .ptr_int() @@ -1017,7 +986,7 @@ pub fn list_map2<'a, 'ctx, 'env>( &[ pass_list_as_i128(env, list1), pass_list_as_i128(env, list2), - pass_as_opaque(env, transform_ptr), + pass_as_opaque(env, closure_data_ptr), stepper_caller.into(), alignment_intvalue(env, &transform_layout), a_width.into(), @@ -1033,8 +1002,10 @@ pub fn list_map2<'a, 'ctx, 'env>( pub fn list_map3<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - transform: BasicValueEnum<'ctx>, - transform_layout: &Layout<'a>, + transform: FunctionValue<'ctx>, + transform_layout: Layout<'a>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, list1: BasicValueEnum<'ctx>, list2: BasicValueEnum<'ctx>, list3: BasicValueEnum<'ctx>, @@ -1050,14 +1021,18 @@ pub fn list_map3<'a, 'ctx, 'env>( _ => unreachable!("not a callable layout"), }; - let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); - env.builder.build_store(transform_ptr, transform); + let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); + env.builder.build_store(closure_data_ptr, closure_data); - let argument_layouts = [*element1_layout, *element2_layout, *element3_layout]; - let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, &argument_layouts) - .as_global_value() - .as_pointer_value(); + let stepper_caller = build_transform_caller_new( + env, + layout_ids, + transform, + closure_data_layout, + &[*element1_layout, *element2_layout, *element3_layout], + ) + .as_global_value() + .as_pointer_value(); let a_width = env .ptr_int() @@ -1085,9 +1060,9 @@ pub fn list_map3<'a, 'ctx, 'env>( pass_list_as_i128(env, list1), pass_list_as_i128(env, list2), pass_list_as_i128(env, list3), - pass_as_opaque(env, transform_ptr), + pass_as_opaque(env, closure_data_ptr), stepper_caller.into(), - alignment_intvalue(env, transform_layout), + alignment_intvalue(env, &transform_layout), a_width.into(), b_width.into(), c_width.into(), diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index cec5d26086..00157e72e3 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -631,6 +631,8 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { // TODO is true or false more efficient for non-refcounted layouts? let irrelevant = false; + let function = irrelevant; + let closure_data = irrelevant; let owned = false; let borrowed = true; @@ -653,16 +655,18 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListPrepend => arena.alloc_slice_copy(&[owned, owned]), StrJoinWith => arena.alloc_slice_copy(&[borrowed, borrowed]), ListJoin => arena.alloc_slice_copy(&[irrelevant]), - ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), - ListMap2 => arena.alloc_slice_copy(&[owned, owned, irrelevant]), - ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, irrelevant]), - ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]), + ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, function, closure_data]), + ListMap2 => arena.alloc_slice_copy(&[owned, owned, function, closure_data]), + ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, function, closure_data]), + ListKeepIf | ListKeepOks | ListKeepErrs => { + arena.alloc_slice_copy(&[owned, function, closure_data]) + } ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListRange => arena.alloc_slice_copy(&[irrelevant, irrelevant]), ListWalk | ListWalkUntil | ListWalkBackwards => { - arena.alloc_slice_copy(&[owned, irrelevant, owned]) + arena.alloc_slice_copy(&[owned, owned, function, closure_data]) } - ListSortWith => arena.alloc_slice_copy(&[owned, irrelevant]), + ListSortWith => arena.alloc_slice_copy(&[owned, function, closure_data]), // TODO when we have lists with capacity (if ever) // List.append should own its first argument @@ -695,7 +699,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { DictUnion | DictDifference | DictIntersection => arena.alloc_slice_copy(&[owned, borrowed]), // borrow function argument so we don't have to worry about RC of the closure - DictWalk => arena.alloc_slice_copy(&[owned, borrowed, owned]), + DictWalk => arena.alloc_slice_copy(&[owned, owned, function, closure_data]), SetFromList => arena.alloc_slice_copy(&[owned]), diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index bbd2cdfac1..87d75159ad 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -2550,6 +2550,38 @@ fn specialize_naked_symbol<'a>( ) } +macro_rules! match_on_closure_argument { + ($env:expr, $procs:expr, $layout_cache:expr, $closure_data_symbol:expr, $closure_data_var:expr, $op:expr, [$($x:expr),* $(,)?], $layout: expr, $assigned:expr, $hole:expr) => {{ + let closure_data_layout = return_on_layout_error!( + $env, + $layout_cache.from_var($env.arena, $closure_data_var, $env.subs) + ); + + let top_level = TopLevelFunctionLayout::from_layout($env.arena, closure_data_layout); + + let arena = $env.arena; + + match closure_data_layout { + Layout::Closure(_, lambda_set, _) => { + lowlevel_match_on_lambda_set( + $env, + lambda_set, + $closure_data_symbol, + |top_level_function, closure_data| self::Call { + call_type: CallType::LowLevel { op: $op }, + arguments: arena.alloc([$($x,)* top_level_function, closure_data]), + }, + arena.alloc(top_level).full(), + $layout, + $assigned, + $hole, + ) + } + _ => unreachable!(), + } + }}; +} + pub fn with_hole<'a>( env: &mut Env<'a, '_>, can_expr: roc_can::expr::Expr, @@ -3869,7 +3901,7 @@ pub fn with_hole<'a>( } Call(boxed, loc_args, _) => { - let (fn_var, loc_expr, _closure_var, ret_var) = *boxed; + let (fn_var, loc_expr, lambda_set_var, ret_var) = *boxed; // even if a call looks like it's by name, it may in fact be by-pointer. // E.g. in `(\f, x -> f x)` the call is in fact by pointer. @@ -3986,6 +4018,9 @@ pub fn with_hole<'a>( Layout::Closure(_, lambda_set, _) => { let closure_data_symbol = env.unique_symbol(); + let top_level = + TopLevelFunctionLayout::from_layout(env.arena, full_layout); + result = match_on_lambda_set( env, lambda_set, @@ -4099,55 +4134,86 @@ pub fn with_hole<'a>( use LowLevel::*; match op { - ListMap => { + ListMap | ListMapWithIndex | ListKeepIf | ListKeepOks | ListKeepErrs + | ListSortWith => { debug_assert_eq!(arg_symbols.len(), 2); - let list_symbol = arg_symbols[0]; - let closure_data_symbol = arg_symbols[1]; + let closure_index = 1; + let closure_data_symbol = arg_symbols[closure_index]; + let closure_data_var = args[closure_index].0; - let closure_data_layout = return_on_layout_error!( + match_on_closure_argument!( env, - layout_cache.from_var(env.arena, args[1].0, env.subs) - ); + procs, + layout_cache, + closure_data_symbol, + closure_data_var, + op, + [arg_symbols[0]], + layout, + assigned, + hole + ) + } + ListWalk | ListWalkUntil | ListWalkBackwards | DictWalk => { + debug_assert_eq!(arg_symbols.len(), 3); - let top_level = - TopLevelFunctionLayout::from_layout(env.arena, closure_data_layout); + let closure_index = 1; + let closure_data_symbol = arg_symbols[closure_index]; + let closure_data_var = args[closure_index].0; - let arena = env.arena; + match_on_closure_argument!( + env, + procs, + layout_cache, + closure_data_symbol, + closure_data_var, + op, + [arg_symbols[0], arg_symbols[2]], + layout, + assigned, + hole + ) + } + ListMap2 => { + debug_assert_eq!(arg_symbols.len(), 3); - match closure_data_layout { - Layout::Closure(argument_layouts, lambda_set, return_layout) => { - // specialize the possible options for the function - for (function_symbol, _) in lambda_set.set { - procs.insert_passed_by_name( - env, - args[1].0, - *function_symbol, - top_level, - layout_cache, - ); - } + let closure_index = 2; + let closure_data_symbol = arg_symbols[closure_index]; + let closure_data_var = args[closure_index].0; - lowlevel_match_on_lambda_set( - env, - lambda_set, - closure_data_symbol, - |top_level_function, closure_data| self::Call { - call_type: CallType::LowLevel { op }, - arguments: arena.alloc([ - list_symbol, - top_level_function, - closure_data, - ]), - }, - arena.alloc(top_level).full(), - layout, - assigned, - hole, - ) - } - _ => unreachable!(), - } + match_on_closure_argument!( + env, + procs, + layout_cache, + closure_data_symbol, + closure_data_var, + op, + [arg_symbols[0], arg_symbols[1]], + layout, + assigned, + hole + ) + } + ListMap3 => { + debug_assert_eq!(arg_symbols.len(), 4); + + let closure_index = 3; + let closure_data_symbol = arg_symbols[closure_index]; + let closure_data_var = args[closure_index].0; + + match_on_closure_argument!( + env, + procs, + layout_cache, + closure_data_symbol, + closure_data_var, + op, + [arg_symbols[0], arg_symbols[1], arg_symbols[2]], + layout, + assigned, + hole + ) } _ => { let call = self::Call { @@ -5648,6 +5714,7 @@ fn reuse_function_symbol<'a>( env.arena.alloc(result), ); } + _ => { // danger: a foreign symbol may not be specialized! debug_assert!( @@ -5677,34 +5744,23 @@ fn reuse_function_symbol<'a>( match res_layout { Ok(Layout::Closure(argument_layouts, lambda_set, ret_layout)) => { + // define the function pointer + let function_ptr_layout = + TopLevelFunctionLayout::from_layout(env.arena, res_layout.unwrap()); + + procs.insert_passed_by_name( + env, + arg_var, + original, + function_ptr_layout, + layout_cache, + ); + if captures { // this is a closure by capture, meaning it itself captures local variables. let closure_data = symbol; - // let closure_data_layout = closure_layout.as_named_layout(original); - let closure_data_layout = lambda_set.runtime_representation(); - - // define the function pointer - let function_ptr_layout = TopLevelFunctionLayout::from_layout( - env.arena, - lambda_set.extend_function_layout( - env.arena, - argument_layouts, - ret_layout, - ), - ); - - procs.insert_passed_by_name( - env, - arg_var, - original, - function_ptr_layout, - layout_cache, - ); - - // define the closure data - let symbols = match captured { CapturedSymbols::Captured(captured_symbols) => { Vec::from_iter_in(captured_symbols.iter().map(|x| x.0), env.arena) @@ -6143,8 +6199,6 @@ fn call_by_name_help<'a>( maybe_closure_layout, layout ); - dbg!(maybe_closure_layout, field_symbols, &loc_args); - call_specialized_proc( env, procs, @@ -7497,6 +7551,38 @@ where env.arena.alloc(stmt), ) } + Layout::Builtin(Builtin::Int1) => { + let closure_tag_id_symbol = closure_data_symbol; + + lowlevel_enum_lambda_set_to_switch( + env, + lambda_set.set, + closure_tag_id_symbol, + Layout::Builtin(Builtin::Int1), + closure_data_symbol, + to_lowlevel_call, + function_layout, + return_layout, + assigned, + hole, + ) + } + Layout::Builtin(Builtin::Int8) => { + let closure_tag_id_symbol = closure_data_symbol; + + lowlevel_enum_lambda_set_to_switch( + env, + lambda_set.set, + closure_tag_id_symbol, + Layout::Builtin(Builtin::Int8), + closure_data_symbol, + to_lowlevel_call, + function_layout, + return_layout, + assigned, + hole, + ) + } other => todo!("{:?}", other), } } @@ -7538,7 +7624,7 @@ where ); let stmt = Stmt::Let( - assigned, + bound, Expr::FunctionPointer(*function_symbol, function_layout), function_layout, env.arena.alloc(stmt), @@ -7989,3 +8075,103 @@ fn enum_lambda_set_branch_help<'a>( hole, ) } + +fn lowlevel_enum_lambda_set_to_switch<'a, F>( + env: &mut Env<'a, '_>, + lambda_set: &'a [(Symbol, &'a [Layout<'a>])], + closure_tag_id_symbol: Symbol, + closure_tag_id_layout: Layout<'a>, + closure_data_symbol: Symbol, + to_lowlevel_call: F, + function_layout: Layout<'a>, + return_layout: Layout<'a>, + assigned: Symbol, + hole: &'a Stmt<'a>, +) -> Stmt<'a> +where + F: Fn(Symbol, Symbol) -> Call<'a> + Copy, +{ + debug_assert!(!lambda_set.is_empty()); + + let join_point_id = JoinPointId(env.unique_symbol()); + + let mut branches = Vec::with_capacity_in(lambda_set.len(), env.arena); + + let closure_layout = closure_tag_id_layout; + + for (i, (function_symbol, _)) in lambda_set.iter().enumerate() { + let stmt = lowlevel_enum_lambda_set_branch( + env, + join_point_id, + *function_symbol, + closure_data_symbol, + closure_layout, + to_lowlevel_call, + function_layout, + return_layout, + ); + branches.push((i as u64, BranchInfo::None, stmt)); + } + + let default_branch = { + let (_, info, stmt) = branches.pop().unwrap(); + + (info, &*env.arena.alloc(stmt)) + }; + + let switch = Stmt::Switch { + cond_symbol: closure_tag_id_symbol, + cond_layout: closure_tag_id_layout, + branches: branches.into_bump_slice(), + default_branch, + ret_layout: return_layout, + }; + + let param = Param { + symbol: assigned, + layout: return_layout, + borrow: false, + }; + + Stmt::Join { + id: join_point_id, + parameters: &*env.arena.alloc([param]), + continuation: hole, + remainder: env.arena.alloc(switch), + } +} + +fn lowlevel_enum_lambda_set_branch<'a, F>( + env: &mut Env<'a, '_>, + join_point_id: JoinPointId, + function_symbol: Symbol, + closure_data_symbol: Symbol, + closure_data_layout: Layout<'a>, + to_lowlevel_call: F, + function_layout: Layout<'a>, + return_layout: Layout<'a>, +) -> Stmt<'a> +where + F: Fn(Symbol, Symbol) -> Call<'a> + Copy, +{ + let result_symbol = env.unique_symbol(); + + let hole = Stmt::Jump(join_point_id, env.arena.alloc([result_symbol])); + + let bound = env.unique_symbol(); + + // build the call + let stmt = Stmt::Let( + result_symbol, + Expr::Call(to_lowlevel_call(bound, closure_data_symbol)), + return_layout, + env.arena.alloc(hole), + ); + + Stmt::Let( + bound, + Expr::FunctionPointer(function_symbol, function_layout), + function_layout, + env.arena.alloc(stmt), + ) +} diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index 124f719a6b..96cddca5c8 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -629,7 +629,8 @@ fn list_map2_pair() { assert_evals_to!( indoc!( r#" - List.map2 [1,2,3] [3,2,1] (\a,b -> Pair a b) + f = (\a,b -> Pair a b) + List.map2 [1,2,3] [3,2,1] f "# ), RocList::from_slice(&[(1, 3), (2, 2), (3, 1)]), @@ -645,7 +646,7 @@ fn list_map2_different_lengths() { List.map2 ["a", "b", "lllllllllllllongnggg" ] ["b"] - Str.concat + (\a, b -> Str.concat a b) "# ), RocList::from_slice(&[RocStr::from_slice("ab".as_bytes()),]),