diff --git a/compiler/load/src/file.rs b/compiler/load/src/file.rs index a385eee464..ea760f8c28 100644 --- a/compiler/load/src/file.rs +++ b/compiler/load/src/file.rs @@ -751,6 +751,7 @@ enum Msg<'a> { layout_cache: LayoutCache<'a>, external_specializations_requested: MutMap, procedures: MutMap<(Symbol, Layout<'a>), Proc<'a>>, + passed_by_pointer: MutMap<(Symbol, Layout<'a>), Symbol>, problems: Vec, module_timing: ModuleTiming, subs: Subs, @@ -781,6 +782,7 @@ struct State<'a> { pub module_cache: ModuleCache<'a>, pub dependencies: Dependencies<'a>, pub procedures: MutMap<(Symbol, Layout<'a>), Proc<'a>>, + pub passed_by_pointer: MutMap<(Symbol, Layout<'a>), Symbol>, pub exposed_to_host: MutMap, /// This is the "final" list of IdentIds, after canonicalization and constraint gen @@ -1403,6 +1405,7 @@ where module_cache: ModuleCache::default(), dependencies: Dependencies::default(), procedures: MutMap::default(), + passed_by_pointer: MutMap::default(), exposed_to_host: MutMap::default(), exposed_types, headers_parsed, @@ -1931,6 +1934,7 @@ fn update<'a>( mut ident_ids, subs, procedures, + passed_by_pointer, external_specializations_requested, problems, module_timing, @@ -1945,12 +1949,17 @@ fn update<'a>( .notify(module_id, Phase::MakeSpecializations); state.procedures.extend(procedures); + state.passed_by_pointer.extend(passed_by_pointer); state.timings.insert(module_id, module_timing); if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { debug_assert!(work.is_empty(), "still work remaining {:?}", &work); - Proc::insert_refcount_operations(arena, &mut state.procedures); + Proc::insert_refcount_operations( + arena, + &mut state.procedures, + &state.passed_by_pointer, + ); Proc::optimize_refcount_operations( arena, @@ -3621,7 +3630,7 @@ fn make_specializations<'a>( ); let external_specializations_requested = procs.externals_we_need.clone(); - let procedures = procs.get_specialized_procs_without_rc(mono_env.arena); + let (procedures, passed_by_pointer) = procs.get_specialized_procs_without_rc(mono_env.arena); let make_specializations_end = SystemTime::now(); module_timing.make_specializations = make_specializations_end @@ -3633,6 +3642,7 @@ fn make_specializations<'a>( ident_ids, layout_cache, procedures, + passed_by_pointer, problems: mono_problems, subs, external_specializations_requested, diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 14f6687ed0..4399d8aa18 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -9,11 +9,21 @@ use roc_module::symbol::Symbol; pub fn infer_borrow<'a>( arena: &'a Bump, procs: &MutMap<(Symbol, Layout<'a>), Proc<'a>>, + passed_by_pointer: &MutMap<(Symbol, Layout<'a>), Symbol>, ) -> ParamMap<'a> { let mut param_map = ParamMap { items: MutMap::default(), }; + for (key, other) in passed_by_pointer { + if let Some(proc) = procs.get(key) { + let mut proc: Proc = proc.clone(); + proc.name = *other; + + param_map.visit_proc_always_owned(arena, &proc); + } + } + for proc in procs.values() { param_map.visit_proc(arena, proc); } @@ -125,6 +135,21 @@ impl<'a> ParamMap<'a> { .into_bump_slice() } + fn init_borrow_args_always_owned( + arena: &'a Bump, + ps: &'a [(Layout<'a>, Symbol)], + ) -> &'a [Param<'a>] { + Vec::from_iter_in( + ps.iter().map(|(layout, symbol)| Param { + borrow: false, + layout: layout.clone(), + symbol: *symbol, + }), + arena, + ) + .into_bump_slice() + } + fn visit_proc(&mut self, arena: &'a Bump, proc: &Proc<'a>) { self.items.insert( Key::Declaration(proc.name), @@ -134,6 +159,15 @@ impl<'a> ParamMap<'a> { self.visit_stmt(arena, proc.name, &proc.body); } + fn visit_proc_always_owned(&mut self, arena: &'a Bump, proc: &Proc<'a>) { + self.items.insert( + Key::Declaration(proc.name), + Self::init_borrow_args_always_owned(arena, proc.args), + ); + + self.visit_stmt(arena, proc.name, &proc.body); + } + fn visit_stmt(&mut self, arena: &'a Bump, _fnid: Symbol, stmt: &Stmt<'a>) { use Stmt::*; diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 7b68fc593b..698e1332d0 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -172,8 +172,20 @@ impl<'a> Proc<'a> { pub fn insert_refcount_operations( arena: &'a Bump, procs: &mut MutMap<(Symbol, Layout<'a>), Proc<'a>>, + passed_by_pointer: &MutMap<(Symbol, Layout<'a>), Symbol>, ) { - let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, procs)); + let borrow_params = + arena.alloc(crate::borrow::infer_borrow(arena, procs, passed_by_pointer)); + + for (key, other) in passed_by_pointer { + if let Some(proc) = procs.get(key) { + let mut proc: Proc = proc.clone(); + proc.name = *other; + + let layout = key.1.clone(); + procs.insert((*other, layout), proc); + } + } for (_, proc) in procs.iter_mut() { crate::inc_dec::visit_proc(arena, borrow_params, proc); @@ -255,6 +267,7 @@ pub struct Procs<'a> { pub runtime_errors: MutMap, pub externals_others_need: ExternalSpecializations, pub externals_we_need: MutMap, + pub passed_by_pointer: MutMap<(Symbol, Layout<'a>), Symbol>, } impl<'a> Default for Procs<'a> { @@ -267,6 +280,7 @@ impl<'a> Default for Procs<'a> { runtime_errors: MutMap::default(), externals_we_need: MutMap::default(), externals_others_need: ExternalSpecializations::default(), + passed_by_pointer: MutMap::default(), } } } @@ -314,7 +328,10 @@ impl<'a> Procs<'a> { pub fn get_specialized_procs_without_rc( self, arena: &'a Bump, - ) -> MutMap<(Symbol, Layout<'a>), Proc<'a>> { + ) -> ( + MutMap<(Symbol, Layout<'a>), Proc<'a>>, + MutMap<(Symbol, Layout<'a>), Symbol>, + ) { let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher()); for (key, in_prog_proc) in self.specialized.into_iter() { @@ -337,7 +354,7 @@ impl<'a> Procs<'a> { } } - result + (result, self.passed_by_pointer) } // TODO investigate make this an iterator? @@ -366,7 +383,11 @@ impl<'a> Procs<'a> { } } - let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result)); + let borrow_params = arena.alloc(crate::borrow::infer_borrow( + arena, + &result, + &self.passed_by_pointer, + )); for (_, proc) in result.iter_mut() { crate::inc_dec::visit_proc(arena, borrow_params, proc); @@ -406,7 +427,11 @@ impl<'a> Procs<'a> { } } - let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result)); + let borrow_params = arena.alloc(crate::borrow::infer_borrow( + arena, + &result, + &self.passed_by_pointer, + )); for (_, proc) in result.iter_mut() { crate::inc_dec::visit_proc(arena, borrow_params, proc); @@ -2435,7 +2460,8 @@ fn specialize_naked_symbol<'a>( match hole { Stmt::Jump(_, _) => todo!("not sure what to do in this case yet"), _ => { - let expr = Expr::FunctionPointer(symbol, layout.clone()); + let expr = + call_by_pointer(env, procs, layout_cache, symbol, layout.clone()); let new_symbol = env.unique_symbol(); return Stmt::Let( new_symbol, @@ -3523,7 +3549,7 @@ pub fn with_hole<'a>( // TODO should the let have layout Pointer? Stmt::Let( assigned, - Expr::FunctionPointer(name, layout.clone()), + call_by_pointer(env, procs, layout_cache, name, layout.clone()), layout, hole, ) @@ -3720,7 +3746,13 @@ pub fn with_hole<'a>( } } - let expr = Expr::FunctionPointer(name, function_ptr_layout.clone()); + let expr = call_by_pointer( + env, + procs, + layout_cache, + name, + function_ptr_layout.clone(), + ); stmt = Stmt::Let( function_pointer, @@ -3746,7 +3778,7 @@ pub fn with_hole<'a>( // TODO should the let have layout Pointer? Stmt::Let( assigned, - Expr::FunctionPointer(name, layout.clone()), + call_by_pointer(env, procs, layout_cache, name, layout.clone()), layout, hole, ) @@ -5327,7 +5359,7 @@ fn handle_variable_aliasing<'a>( .from_var(env.arena, variable, env.subs) .unwrap(); - let expr = Expr::FunctionPointer(right, layout.clone()); + let expr = call_by_pointer(env, procs, layout_cache, right, layout.clone()); Stmt::Let(left, expr, layout, env.arena.alloc(result)) } else { substitute_in_exprs(env.arena, &mut result, left, right); @@ -5375,7 +5407,7 @@ fn reuse_function_symbol<'a>( // an imported symbol is always a function pointer: // either it's a function, or a top-level 0-argument thunk - let expr = Expr::FunctionPointer(original, layout.clone()); + let expr = call_by_pointer(env, procs, layout_cache, original, layout.clone()); return Stmt::Let(symbol, expr, layout, env.arena.alloc(result)); } _ => { @@ -5572,6 +5604,20 @@ where result } +fn call_by_pointer<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + symbol: Symbol, + layout: Layout<'a>, +) -> Expr<'a> { + procs + .passed_by_pointer + .insert((symbol, layout.clone()), symbol); + + Expr::FunctionPointer(symbol, layout) +} + fn add_needed_external<'a>( procs: &mut Procs<'a>, env: &mut Env<'a, '_>,