diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index d6d2359064..e91d8ad0b6 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -945,6 +945,31 @@ mod gen_primitives { ); } + #[test] + fn closure_in_list() { + assert_evals_to!( + indoc!( + r#" + app Test provides [ main ] imports [] + + foo = \{} -> + x = 41 + + f = \{} -> x + + [ f ] + + main = + items = foo {} + + List.len items + "# + ), + 1, + i64 + ); + } + #[test] #[ignore] fn specialize_closure() { diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 3cef1366f4..0d2cb82902 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -34,6 +34,15 @@ pub enum CapturedSymbols<'a> { Captured(&'a [(Symbol, Variable)]), } +impl<'a> CapturedSymbols<'a> { + fn captures(&self) -> bool { + match self { + CapturedSymbols::None => false, + CapturedSymbols::Captured(_) => true, + } + } +} + #[derive(Clone, Debug, PartialEq)] pub struct PendingSpecialization { solved_type: SolvedType, @@ -1741,6 +1750,7 @@ impl<'a> FunctionLayouts<'a> { result: (*result).clone(), full: layout, }, + Layout::Closure(_, _, _) => todo!(), _ => FunctionLayouts { full: layout.clone(), arguments: &[], @@ -4226,32 +4236,52 @@ fn assign_to_symbol<'a>( ) -> Stmt<'a> { // if this argument is already a symbol, we don't need to re-define it if let roc_can::expr::Expr::Var(original) = loc_arg.value { - if procs.partial_procs.contains_key(&original) { - // this symbol is a function, that is used by-name (e.g. as an argument to another - // function). Register it with the current variable, then create a function pointer - // to it in the IR. - let layout = layout_cache - .from_var(env.arena, arg_var, env.subs) - .expect("creating layout does not fail"); - procs.insert_passed_by_name(env, arg_var, original, layout.clone(), layout_cache); + match procs.partial_procs.get(&original) { + Some(partial_proc) => { + // this symbol is a function, that is used by-name (e.g. as an argument to another + // function). Register it with the current variable, then create a function pointer + // to it in the IR. + let layout = layout_cache + .from_var(env.arena, arg_var, env.subs) + .expect("creating layout does not fail"); - return Stmt::Let( - symbol, - Expr::FunctionPointer(original, layout.clone()), - layout, - env.arena.alloc(result), - ); + // we have three kinds of functions really. Plain functions, closures by capture, + // and closures by unification. Here we record whether this function captures + // anything. + let captures = partial_proc.captured_symbols.captures(); + drop(partial_proc); + + procs.insert_passed_by_name(env, arg_var, original, layout.clone(), layout_cache); + + match layout { + Layout::Closure(_, _, _) if captures => { + // this is a closure by capture, meaning it itself captures local variables. + // we've defined the closure as a (function_ptr, closure_data) pair already + // replace `symbol` with `original` + let mut stmt = result; + substitute_in_exprs(env.arena, &mut stmt, symbol, original); + stmt + } + _ => Stmt::Let( + symbol, + Expr::FunctionPointer(original, layout.clone()), + layout, + env.arena.alloc(result), + ), + } + } + _ => result, } - return result; + } else { + with_hole( + env, + loc_arg.value, + procs, + layout_cache, + symbol, + env.arena.alloc(result), + ) } - with_hole( - env, - loc_arg.value, - procs, - layout_cache, - symbol, - env.arena.alloc(result), - ) } fn assign_to_symbols<'a, I>( diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 1bddbc8953..58087a7c00 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -64,30 +64,33 @@ impl<'a> ClosureLayout<'a> { } } - fn from_wrapped(tags: &'a [(TagName, &'a [Layout<'a>])]) -> Self { + fn from_wrapped(arena: &'a Bump, tags: &'a [(TagName, &'a [Layout<'a>])]) -> Self { + debug_assert!(!tags.is_empty()); // NOTE we fabricate a pointer size here. // That's fine because we don't care about the exact size, just the biggest one let pointer_size = 8; let mut largest_size = 0; - let mut largest = None; + let mut max_size = &[] as &[_]; - for (_, tag_args) in tags.iter() { + let mut tag_arguments = Vec::with_capacity_in(tags.len(), arena); + + for (name, tag_args_with_discr) in tags.iter() { + let tag_args = &tag_args_with_discr[1..]; let size = tag_args.iter().map(|l| l.stack_size(pointer_size)).sum(); // >= because some of our layouts have 0 size, but are still valid layouts if size >= largest_size { largest_size = size; - largest = Some(tag_args); + max_size = tag_args; } + + tag_arguments.push((name.clone(), tag_args)); } - match largest { - None => unreachable!("A tag union layout must always contain 2 or more tags"), - Some(max_size) => ClosureLayout { - captured: tags, - max_size, - }, + ClosureLayout { + captured: tag_arguments.into_bump_slice(), + max_size, } } @@ -105,7 +108,7 @@ impl<'a> ClosureLayout<'a> { use UnionVariant::*; match variant { Never | Unit => { - // a max closure size of 0 means this is a standart top-level function + // a max closure size of 0 means this is a standard top-level function Ok(None) } BoolUnion { .. } => { @@ -126,7 +129,8 @@ impl<'a> ClosureLayout<'a> { } Wrapped(tags) => { // Wrapped(Vec<'a, (TagName, &'a [Layout<'a>])>), - let closure_layout = ClosureLayout::from_wrapped(tags.into_bump_slice()); + let closure_layout = + ClosureLayout::from_wrapped(arena, tags.into_bump_slice()); Ok(Some(closure_layout)) } } @@ -214,7 +218,11 @@ impl<'a> ClosureLayout<'a> { } pub fn as_block_of_memory_layout(&self) -> Layout<'a> { - Layout::Struct(self.max_size) + if self.max_size.len() == 1 { + self.max_size[0].clone() + } else { + Layout::Struct(self.max_size) + } } }