diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 0619e3c3d6..e5822a8f89 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -111,6 +111,11 @@ impl Symbol { pub const fn to_ne_bytes(self) -> [u8; 8] { self.0.to_ne_bytes() } + + #[cfg(debug_assertions)] + pub fn contains(self, needle: &str) -> bool { + format!("{:?}", self).contains(needle) + } } /// Rather than displaying as this: diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 4bd609b72e..31fae15711 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -4172,7 +4172,7 @@ pub fn with_hole<'a>( layout_cache, lambda_set, name, - symbols, + symbols.iter().copied(), assigned, hole, ) @@ -4744,17 +4744,25 @@ fn get_specialization<'a>( } #[allow(clippy::too_many_arguments)] -fn construct_closure_data<'a>( +fn construct_closure_data<'a, I>( env: &mut Env<'a, '_>, procs: &mut Procs<'a>, layout_cache: &mut LayoutCache<'a>, lambda_set: LambdaSet<'a>, name: Symbol, - symbols: &'a [&(Symbol, Variable)], + symbols: I, assigned: Symbol, hole: &'a Stmt<'a>, -) -> Stmt<'a> { +) -> Stmt<'a> +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ let lambda_set_layout = Layout::LambdaSet(lambda_set); + let symbols = symbols.into_iter(); + + // arguments with a polymorphic type that we have to deal with + let mut polymorphic_arguments = Vec::new_in(env.arena); let mut result = match lambda_set.layout_for_member(name) { ClosureRepresentation::Union { @@ -4765,10 +4773,14 @@ fn construct_closure_data<'a>( } => { // captured variables are in symbol-alphabetic order, but now we want // them ordered by their alignment requirements - let mut combined = Vec::from_iter_in( - symbols.iter().map(|&&(s, _)| s).zip(field_layouts.iter()), - env.arena, - ); + let mut combined = Vec::with_capacity_in(symbols.len(), env.arena); + for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) { + if procs.partial_exprs.contains(*symbol) { + polymorphic_arguments.push((*symbol, *variable)); + } + + combined.push((*symbol, layout)) + } let ptr_bytes = env.target_info; @@ -4796,10 +4808,14 @@ fn construct_closure_data<'a>( // captured variables are in symbol-alphabetic order, but now we want // them ordered by their alignment requirements - let mut combined = Vec::from_iter_in( - symbols.iter().map(|&(s, _)| s).zip(field_layouts.iter()), - env.arena, - ); + let mut combined = Vec::with_capacity_in(symbols.len(), env.arena); + for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) { + if procs.partial_exprs.contains(*symbol) { + polymorphic_arguments.push((*symbol, *variable)); + } + + combined.push((*symbol, layout)) + } let ptr_bytes = env.target_info; @@ -4811,7 +4827,7 @@ fn construct_closure_data<'a>( }); let symbols = - Vec::from_iter_in(combined.iter().map(|(a, _)| **a), env.arena).into_bump_slice(); + Vec::from_iter_in(combined.iter().map(|(a, _)| *a), env.arena).into_bump_slice(); let field_layouts = Vec::from_iter_in(combined.iter().map(|(_, b)| **b), env.arena).into_bump_slice(); @@ -4852,10 +4868,8 @@ fn construct_closure_data<'a>( // TODO: this is not quite right. What we should actually be doing is removing references to // polymorphic expressions from the captured symbols, and allowing the specializations of those // symbols to be inlined when specializing the closure body elsewhere. - for &&(symbol, var) in symbols { - if procs.partial_exprs.contains(symbol) { - result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol); - } + for (symbol, var) in polymorphic_arguments { + result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol); } result @@ -6947,7 +6961,7 @@ fn specialize_symbol<'a>( layout_cache, lambda_set, original, - symbols, + symbols.iter().copied(), closure_data, env.arena.alloc(result), ) @@ -7407,13 +7421,20 @@ fn call_by_name_help<'a>( proc_name, ); + let has_closure = argument_layouts.len() != top_level_layout.arguments.len(); + let closure_argument = env.unique_symbol(); + + if has_closure { + field_symbols.push(closure_argument); + } + let field_symbols = field_symbols.into_bump_slice(); let call = self::Call { call_type: CallType::ByName { name: proc_name, ret_layout, - arg_layouts: argument_layouts, + arg_layouts: top_level_layout.arguments, specialization_id: env.next_call_specialization_id(), }, arguments: field_symbols, @@ -7421,8 +7442,33 @@ fn call_by_name_help<'a>( let result = build_call(env, call, assigned, *ret_layout, hole); - let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev()); - assign_to_symbols(env, procs, layout_cache, iter, result) + // NOTE: the zip omits the closure symbol, if it exists, + // because loc_args then is shorter than field_symbols + debug_assert!([0, 1].contains(&(field_symbols.len() - loc_args.len()))); + let iter = loc_args.into_iter().zip(field_symbols.iter()).rev(); + let result = assign_to_symbols(env, procs, layout_cache, iter, result); + + if has_closure { + let partial_proc = procs.partial_procs.get_symbol(proc_name).unwrap(); + + let captured = match partial_proc.captured_symbols { + CapturedSymbols::None => &[], + CapturedSymbols::Captured(slice) => slice, + }; + + construct_closure_data( + env, + procs, + layout_cache, + lambda_set, + proc_name, + captured.iter(), + closure_argument, + env.arena.alloc(result), + ) + } else { + result + } } PendingSpecializations::Making => { let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name); @@ -7706,7 +7752,7 @@ fn call_specialized_proc<'a>( layout_cache, lambda_set, proc_name, - symbols, + symbols.iter().copied(), closure_data_symbol, env.arena.alloc(new_hole), ); diff --git a/compiler/test_gen/src/gen_primitives.rs b/compiler/test_gen/src/gen_primitives.rs index 593096087b..8e7e86eca1 100644 --- a/compiler/test_gen/src/gen_primitives.rs +++ b/compiler/test_gen/src/gen_primitives.rs @@ -3344,3 +3344,55 @@ fn box_and_unbox_tag_union() { (u8, u8) ) } + +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn closure_called_in_its_defining_scope() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + main : Str + main = + g : Str + g = "hello world" + + getG : {} -> Str + getG = \{} -> g + + getG {} + "# + ), + RocStr::from("hello world"), + RocStr + ) +} + +#[test] +#[ignore] +#[cfg(any(feature = "gen-llvm"))] +fn issue_2894() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + main : U32 + main = + g : { x : U32 } + g = { x: 1u32 } + + getG : {} -> { x : U32 } + getG = \{} -> g + + h : {} -> U32 + h = \{} -> (getG {}).x + + h {} + "# + ), + 1u32, + u32 + ) +}