diff --git a/cli/tests/cli_run.rs b/cli/tests/cli_run.rs index fc2afe6bbb..42d55e4ec4 100644 --- a/cli/tests/cli_run.rs +++ b/cli/tests/cli_run.rs @@ -275,6 +275,18 @@ mod cli_run { ); } + #[test] + #[serial(hof)] + fn hof_closures() { + check_output( + &example_file("benchmarks", "HofClosures.roc"), + "hof-closures", + &["--optimize"], + "", + false, + ); + } + // #[test] // #[serial(effect)] // fn run_effect_unoptimized() { diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index b5aeb5ce4d..c9a058e653 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -184,7 +184,10 @@ pub const RocFunctionCall1 = extern struct { pub fn listMap( list: RocList, - call: RocFunctionCall1, + caller: Caller1, + data: Opaque, + inc_n_data: IncN, + data_is_owned: bool, alignment: usize, old_element_width: usize, new_element_width: usize, @@ -195,12 +198,12 @@ pub fn listMap( const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width); const target_ptr = output.bytes orelse unreachable; - if (call.data_is_owned) { - call.inc_n_data(call.data, size); + if (data_is_owned) { + inc_n_data(data, size); } while (i < size) : (i += 1) { - call.caller(call.data, source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + caller(data, source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); } return output; @@ -209,18 +212,29 @@ pub fn listMap( } } -pub fn listMapWithIndex(list: RocList, transform: Opaque, caller: Caller2, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList { +pub fn listMapWithIndex( + list: RocList, + caller: Caller2, + data: Opaque, + inc_n_data: IncN, + data_is_owned: bool, + alignment: usize, + old_element_width: usize, + new_element_width: usize, +) callconv(.C) RocList { if (list.bytes) |source_ptr| { const size = list.len(); var i: usize = 0; const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width); const target_ptr = output.bytes orelse unreachable; - while (i < size) : (i += 1) { - caller(transform, @ptrCast(?[*]u8, &i), source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + if (data_is_owned) { + inc_n_data(data, size); } - utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width); + while (i < size) : (i += 1) { + caller(data, @ptrCast(?[*]u8, &i), source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + } return output; } else { diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 48bea6a52c..1796103b0c 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -558,7 +558,7 @@ pub fn construct_optimization_passes<'a>( pmb.set_optimization_level(OptimizationLevel::None); } OptLevel::Optimize => { - pmb.set_optimization_level(OptimizationLevel::Aggressive); + pmb.set_optimization_level(OptimizationLevel::Less); // this threshold seems to do what we want pmb.set_inliner_with_threshold(275); @@ -3586,13 +3586,58 @@ pub static C_CALL_CONV: u32 = 0; pub static FAST_CALL_CONV: u32 = 8; pub static COLD_CALL_CONV: u32 = 9; +pub struct RocFunctionCall<'ctx> { + pub caller: PointerValue<'ctx>, + pub data: PointerValue<'ctx>, + pub inc_n_data: PointerValue<'ctx>, + pub data_is_owned: IntValue<'ctx>, +} + +fn roc_function_call<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: FunctionValue<'ctx>, + closure_data: BasicValueEnum<'ctx>, + closure_data_layout: Layout<'a>, + closure_data_is_owned: bool, + argument_layouts: &[Layout<'a>], +) -> RocFunctionCall<'ctx> { + use crate::llvm::bitcode::{build_inc_n_wrapper, build_transform_caller_new}; + + let closure_data_ptr = env + .builder + .build_alloca(closure_data.get_type(), "closure_data_ptr"); + env.builder.build_store(closure_data_ptr, closure_data); + + let stepper_caller = + build_transform_caller_new(env, transform, closure_data_layout, argument_layouts) + .as_global_value() + .as_pointer_value(); + + let inc_closure_data = build_inc_n_wrapper(env, layout_ids, &closure_data_layout) + .as_global_value() + .as_pointer_value(); + + let closure_data_is_owned = env + .context + .bool_type() + .const_int(closure_data_is_owned as u64, false); + + RocFunctionCall { + caller: stepper_caller, + inc_n_data: inc_closure_data, + data_is_owned: closure_data_is_owned, + data: closure_data_ptr, + } +} + #[allow(clippy::too_many_arguments)] fn run_higher_order_low_level<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, - layout: &Layout<'a>, + return_layout: &Layout<'a>, op: LowLevel, function_layout: Layout<'a>, function_owns_closure_data: bool, @@ -3667,17 +3712,21 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>( match list_layout { Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(_, element_layout)) => list_map( - env, - layout_ids, - function, - function_layout, - closure, - *closure_layout, - function_owns_closure_data, - list, - element_layout, - ), + Layout::Builtin(Builtin::List(_, element_layout)) => { + let argument_layouts = &[**element_layout]; + + let roc_function_call = roc_function_call( + env, + layout_ids, + function, + closure, + *closure_layout, + function_owns_closure_data, + argument_layouts, + ); + + list_map(env, roc_function_call, list, element_layout, return_layout) + } _ => unreachable!("invalid list layout"), } } @@ -3758,15 +3807,21 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>( match list_layout { Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(_, element_layout)) => list_map_with_index( - env, - function, - function_layout, - closure, - *closure_layout, - list, - element_layout, - ), + Layout::Builtin(Builtin::List(_, element_layout)) => { + let argument_layouts = &[Layout::Builtin(Builtin::Usize), **element_layout]; + + let roc_function_call = roc_function_call( + env, + layout_ids, + function, + closure, + *closure_layout, + function_owns_closure_data, + argument_layouts, + ); + + list_map_with_index(env, roc_function_call, list, element_layout, return_layout) + } _ => unreachable!("invalid list layout"), } } @@ -3804,7 +3859,7 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>( let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); - match (list_layout, layout) { + match (list_layout, return_layout) { (_, Layout::Builtin(Builtin::EmptyList)) | (Layout::Builtin(Builtin::EmptyList), _) => empty_list(env), ( @@ -3836,7 +3891,7 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>( let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); - match (list_layout, layout) { + match (list_layout, return_layout) { (_, Layout::Builtin(Builtin::EmptyList)) | (Layout::Builtin(Builtin::EmptyList), _) => empty_list(env), ( diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 1c7796e852..cdd4dd160d 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -4,7 +4,7 @@ use crate::llvm::bitcode::{ build_inc_wrapper, build_transform_caller_new, call_bitcode_fn, call_void_bitcode_fn, }; use crate::llvm::build::{ - allocate_with_refcount_help, cast_basic_basic, complex_bitcast, Env, InPlace, + allocate_with_refcount_help, cast_basic_basic, complex_bitcast, Env, InPlace, RocFunctionCall, }; use crate::llvm::convert::{basic_type_from_layout, get_ptr_type}; use crate::llvm::refcounting::{ @@ -838,115 +838,43 @@ pub fn list_sort_with<'a, 'ctx, 'env>( /// List.mapWithIndex : List before, (Nat, before -> after) -> List after pub fn list_map_with_index<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - transform: FunctionValue<'ctx>, - transform_layout: Layout<'a>, - closure_data: BasicValueEnum<'ctx>, - closure_data_layout: Layout<'a>, + roc_function_call: RocFunctionCall<'ctx>, list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, + return_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - list_map_generic( + call_bitcode_fn_returns_list( env, - transform, - transform_layout, - list, - element_layout, - closure_data, - closure_data_layout, + &[ + pass_list_as_i128(env, list), + roc_function_call.caller.into(), + pass_as_opaque(env, roc_function_call.data), + roc_function_call.inc_n_data.into(), + roc_function_call.data_is_owned.into(), + alignment_intvalue(env, &element_layout), + layout_width(env, element_layout), + layout_width(env, return_layout), + ], bitcode::LIST_MAP_WITH_INDEX, - &[Layout::Builtin(Builtin::Usize), *element_layout], ) } -fn roc_function_call_1<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - transform: FunctionValue<'ctx>, - closure_data: BasicValueEnum<'ctx>, - closure_data_layout: Layout<'a>, - closure_data_is_owned: bool, - argument_layouts: &[Layout<'a>], -) -> PointerValue<'ctx> { - // %list.RocFunctionCall1 = type { void (i8*, i8*, i8*)*, i8*, void (i8*, i64)*, i1 } - let struct_type = env.module.get_struct_type("list.RocFunctionCall1").unwrap(); - - let builder = env.builder; - - let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); - env.builder.build_store(closure_data_ptr, closure_data); - - let stepper_caller = - build_transform_caller_new(env, transform, closure_data_layout, argument_layouts) - .as_global_value() - .as_pointer_value(); - - let inc_closure_data = build_inc_n_wrapper(env, layout_ids, &closure_data_layout) - .as_global_value() - .as_pointer_value(); - - let closure_data_is_owned = env - .context - .bool_type() - .const_int(closure_data_is_owned as u64, false); - - let mut struct_value = builder - .build_insert_value(struct_type.get_undef(), stepper_caller, 0, "") - .unwrap(); - - struct_value = builder - .build_insert_value(struct_value, pass_as_opaque(env, closure_data_ptr), 1, "") - .unwrap(); - - struct_value = builder - .build_insert_value(struct_value, inc_closure_data, 2, "") - .unwrap(); - - struct_value = builder - .build_insert_value(struct_value, closure_data_is_owned, 3, "") - .unwrap(); - - let ptr = env.builder.build_alloca(struct_type, "roc_function_call_1"); - - env.builder.build_store(ptr, struct_value); - - ptr -} - /// List.map : List before, (before -> after) -> List after pub fn list_map<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - transform: FunctionValue<'ctx>, - transform_layout: Layout<'a>, - closure_data: BasicValueEnum<'ctx>, - closure_data_layout: Layout<'a>, - closure_data_is_owned: bool, + roc_function_call: RocFunctionCall<'ctx>, list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, + return_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - let argument_layouts = &[*element_layout]; - - let return_layout = match transform_layout { - Layout::FunctionPointer(_, ret) => ret, - Layout::Closure(_, _, ret) => ret, - _ => unreachable!("not a callable layout"), - }; - - let roc_function_call = roc_function_call_1( - env, - layout_ids, - transform, - closure_data, - closure_data_layout, - closure_data_is_owned, - argument_layouts, - ); - call_bitcode_fn_returns_list( env, &[ pass_list_as_i128(env, list), - roc_function_call.into(), + roc_function_call.caller.into(), + pass_as_opaque(env, roc_function_call.data), + roc_function_call.inc_n_data.into(), + roc_function_call.data_is_owned.into(), alignment_intvalue(env, &element_layout), layout_width(env, element_layout), layout_width(env, return_layout), @@ -955,47 +883,6 @@ pub fn list_map<'a, 'ctx, 'env>( ) } -fn list_map_generic<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - transform: FunctionValue<'ctx>, - transform_layout: Layout<'a>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, - closure_data: BasicValueEnum<'ctx>, - closure_data_layout: Layout<'a>, - op: &str, - argument_layouts: &[Layout<'a>], -) -> BasicValueEnum<'ctx> { - let builder = env.builder; - - let return_layout = match transform_layout { - Layout::FunctionPointer(_, ret) => ret, - Layout::Closure(_, _, ret) => ret, - _ => unreachable!("not a callable layout"), - }; - - let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr"); - env.builder.build_store(closure_data_ptr, closure_data); - - let stepper_caller = - build_transform_caller_new(env, transform, closure_data_layout, argument_layouts) - .as_global_value() - .as_pointer_value(); - - call_bitcode_fn_returns_list( - env, - &[ - pass_list_as_i128(env, list), - pass_as_opaque(env, closure_data_ptr), - stepper_caller.into(), - alignment_intvalue(env, &element_layout), - layout_width(env, element_layout), - layout_width(env, return_layout), - ], - op, - ) -} - pub fn list_map2<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, diff --git a/compiler/load/src/file.rs b/compiler/load/src/file.rs index 29a34f449a..816d57ad17 100644 --- a/compiler/load/src/file.rs +++ b/compiler/load/src/file.rs @@ -2048,7 +2048,6 @@ fn update<'a>( && state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { - Proc::insert_refcount_operations(arena, &mut state.procedures); // display the mono IR of the module, for debug purposes if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS { let procs_string = state @@ -2068,6 +2067,8 @@ fn update<'a>( // &mut state.procedures, // ); + Proc::insert_refcount_operations(arena, &mut state.procedures); + state.constrained_ident_ids.insert(module_id, ident_ids); for (module_id, requested) in external_specializations_requested { diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index ca1596ab57..9704fb6fd4 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -425,15 +425,74 @@ impl<'a> BorrowInfState<'a> { debug_assert!(op.is_higher_order()); match op { - ListMap => match self.param_map.get_symbol(arguments[1], *closure_layout) { - Some(ps) => { - if !ps[0].borrow { + ListMap | ListSortWith => { + match self.param_map.get_symbol(arguments[1], *closure_layout) { + Some(function_ps) => { + // own the list if the function wants to own the element + if !function_ps[0].borrow { + self.own_var(arguments[0]); + } + + // own the closure environment if the function needs to own it + if let Some(false) = function_ps.get(1).map(|p| p.borrow) { + self.own_var(arguments[2]); + } + } + None => unreachable!(), + } + } + ListMapWithIndex => { + match self.param_map.get_symbol(arguments[1], *closure_layout) { + Some(function_ps) => { + // own the list if the function wants to own the element + if !function_ps[1].borrow { + self.own_var(arguments[0]); + } + + // own the closure environment if the function needs to own it + if let Some(false) = function_ps.get(2).map(|p| p.borrow) { + self.own_var(arguments[2]); + } + } + None => unreachable!(), + } + } + ListMap2 => match self.param_map.get_symbol(arguments[2], *closure_layout) { + Some(function_ps) => { + // own the lists if the function wants to own the element + if !function_ps[0].borrow { self.own_var(arguments[0]); } - if ps.len() > 1 && !ps[1].borrow { + if !function_ps[1].borrow { + self.own_var(arguments[1]); + } + + // own the closure environment if the function needs to own it + if let Some(false) = function_ps.get(2).map(|p| p.borrow) { + self.own_var(arguments[3]); + } + } + None => unreachable!(), + }, + ListMap3 => match self.param_map.get_symbol(arguments[3], *closure_layout) { + Some(function_ps) => { + // own the lists if the function wants to own the element + if !function_ps[0].borrow { + self.own_var(arguments[0]); + } + + if !function_ps[1].borrow { + self.own_var(arguments[1]); + } + if !function_ps[2].borrow { self.own_var(arguments[2]); } + + // own the closure environment if the function needs to own it + if let Some(false) = function_ps.get(3).map(|p| p.borrow) { + self.own_var(arguments[4]); + } } None => unreachable!(), }, diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 3a7dacc1a4..3ae307ee52 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -416,12 +416,6 @@ impl<'a> Context<'a> { and it has been borrowed by the application. Remark: `x` may occur multiple times in the application (e.g., `f x y x`). This is why we check whether it is the first occurrence. */ - dbg!( - self.must_consume(*x), - is_first_occurence(xs, i), - *is_borrow, - !b_live_vars.contains(x) - ); if self.must_consume(*x) && is_first_occurence(xs, i) @@ -461,61 +455,124 @@ impl<'a> Context<'a> { HigherOrderLowLevel { op, closure_layout, .. - } => match op { - roc_module::low_level::LowLevel::ListMap => { - match self.param_map.get_symbol(arguments[1], *closure_layout) { - Some(ps) => { - let b = if ps[0].borrow { - let ps = [BORROWED, BORROWED, BORROWED]; - println!("----------------"); - self.add_dec_after_lowlevel(arguments, &ps, b, b_live_vars) + } => { + macro_rules! create_call { + ($borrows:expr) => { + Expr::Call(crate::ir::Call { + call_type: if $borrows { + call_type } else { - let ps = [OWNED, BORROWED, BORROWED]; - println!("----------------"); - let b = self.add_dec_after_lowlevel(arguments, &ps, b, b_live_vars); - - self.arena.alloc(Stmt::Refcounting( - ModifyRc::DecRef(arguments[0]), - self.arena.alloc(b), - )) - }; - - dbg!(self.must_consume(arguments[2])); - - let call_type = { - if ps[1].borrow { - call_type - } else { - HigherOrderLowLevel { - op: *op, - closure_layout: *closure_layout, - function_owns_closure_data: true, - } + HigherOrderLowLevel { + op: *op, + closure_layout: *closure_layout, + function_owns_closure_data: true, } - }; + }, + arguments, + }) + }; + } - let v = Expr::Call(crate::ir::Call { - call_type, - arguments, - }); - - &*self.arena.alloc(Stmt::Let(z, v, l, b)) + macro_rules! decref_if_owned { + ($borrows:expr, $argument:expr, $stmt:expr) => { + if !$borrows { + self.arena.alloc(Stmt::Refcounting( + ModifyRc::DecRef($argument), + self.arena.alloc($stmt), + )) + } else { + $stmt } - None => unreachable!(), + }; + } + + const FUNCTION: bool = BORROWED; + const CLOSURE_DATA: bool = BORROWED; + + match op { + roc_module::low_level::LowLevel::ListMap => { + match self.param_map.get_symbol(arguments[1], *closure_layout) { + Some(function_ps) => { + let borrows = [function_ps[0].borrow, FUNCTION, CLOSURE_DATA]; + + let b = self.add_dec_after_lowlevel( + arguments, + &borrows, + b, + b_live_vars, + ); + + // if the list is owned, then all elements have been consumed, but not the list itself + let b = decref_if_owned!(function_ps[0].borrow, arguments[0], b); + + let v = create_call!(function_ps[1].borrow); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) + } + None => unreachable!(), + } + } + roc_module::low_level::LowLevel::ListMapWithIndex => { + match self.param_map.get_symbol(arguments[1], *closure_layout) { + Some(function_ps) => { + let borrows = [function_ps[1].borrow, FUNCTION, CLOSURE_DATA]; + + let b = self.add_dec_after_lowlevel( + arguments, + &borrows, + b, + b_live_vars, + ); + + let b = decref_if_owned!(function_ps[1].borrow, arguments[0], b); + + let v = create_call!(function_ps[2].borrow); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) + } + None => unreachable!(), + } + } + roc_module::low_level::LowLevel::ListMap2 => { + match self.param_map.get_symbol(arguments[2], *closure_layout) { + Some(function_ps) => { + let borrows = [ + function_ps[1].borrow, + function_ps[2].borrow, + FUNCTION, + CLOSURE_DATA, + ]; + + let b = self.add_dec_after_lowlevel( + arguments, + &borrows, + b, + b_live_vars, + ); + + let b = decref_if_owned!(function_ps[1].borrow, arguments[0], b); + let b = decref_if_owned!(function_ps[2].borrow, arguments[1], b); + + let v = create_call!(function_ps[2].borrow); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) + } + None => unreachable!(), + } + } + _ => { + let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op); + let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); + + let v = Expr::Call(crate::ir::Call { + call_type, + arguments, + }); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) } } - _ => { - let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op); - let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); - - let v = Expr::Call(crate::ir::Call { - call_type, - arguments, - }); - - &*self.arena.alloc(Stmt::Let(z, v, l, b)) - } - }, + } Foreign { .. } => { let ps = crate::borrow::foreign_borrow_signature(self.arena, arguments.len());