diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index d3099dd145..e68f788040 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -546,20 +546,13 @@ fn modify_refcount_layout_help<'a, 'ctx, 'env>( } Closure(_, lambda_set, _) => { if lambda_set.contains_refcounted() { - let wrapper_struct = value.into_struct_value(); - - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, 1, "modify_rc_closure_data") - .unwrap(); - modify_refcount_layout_help( env, parent, layout_ids, mode, when_recursive, - field_ptr, + value, &lambda_set.runtime_representation(), ) } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index c81039d5d2..e686e83183 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -3950,9 +3950,6 @@ pub fn with_hole<'a>( if let Layout::Closure(_, lambda_set, _) = full_layout { let closure_data_symbol = function_symbol; - // layout of the closure record - let closure_data_layout = lambda_set.runtime_representation(); - result = match_on_lambda_set( env, lambda_set, @@ -3988,9 +3985,6 @@ pub fn with_hole<'a>( Layout::Closure(_, lambda_set, _) => { let closure_data_symbol = env.unique_symbol(); - // layout of the closure record - let closure_data_layout = lambda_set.runtime_representation(); - result = match_on_lambda_set( env, lambda_set, @@ -4110,7 +4104,28 @@ pub fn with_hole<'a>( let list_symbol = arg_symbols[0]; let closure_data_symbol = arg_symbols[1]; - panic!() + let closure_data_layout = return_on_layout_error!( + env, + layout_cache.from_var(env.arena, args[1].0, env.subs) + ); + + let arena = env.arena; + + match closure_data_layout { + Layout::Closure(_, lambda_set, _) => lowlevel_match_on_lambda_set( + env, + lambda_set, + closure_data_symbol, + |top_level_function, closure_data| self::Call { + call_type: CallType::LowLevel { op }, + arguments: arena.alloc([list_symbol, top_level_function]), + }, + layout, + assigned, + hole, + ), + _ => unreachable!(), + } } _ => { let call = self::Call { @@ -5923,7 +5938,6 @@ fn call_by_name_help<'a>( assigned: Symbol, hole: &'a Stmt<'a>, ) -> Stmt<'a> { - dbg!(proc_name, &loc_args); let original_fn_var = fn_var; let arena = env.arena; @@ -7378,18 +7392,78 @@ pub fn num_argument_to_int_or_float( } } -fn lambda_set_lowlevel<'a, F>( +/// Use the lambda set to figure out how to make a lowlevel call +fn lowlevel_match_on_lambda_set<'a, F>( env: &mut Env<'a, '_>, - lambda_set: &'a [(TagName, &'a [Layout<'a>])], - to_lowlevel_call: F, - closure_tag_id_symbol: Symbol, + lambda_set: LambdaSet<'a>, closure_data_symbol: Symbol, + to_lowlevel_call: F, return_layout: Layout<'a>, assigned: Symbol, hole: &'a Stmt<'a>, ) -> Stmt<'a> where - F: Fn(Symbol) -> Stmt<'a>, + F: Fn(Symbol, Symbol) -> Call<'a> + Copy, +{ + match lambda_set.runtime_representation() { + Layout::Union(_) => { + let closure_tag_id_symbol = env.unique_symbol(); + + let result = lowlevel_union_lambda_set_to_switch( + env, + lambda_set.set, + closure_tag_id_symbol, + Layout::Builtin(crate::layout::TAG_SIZE), + closure_data_symbol, + to_lowlevel_call, + return_layout, + assigned, + hole, + ); + + // extract & assign the closure_tag_id_symbol + let expr = Expr::AccessAtIndex { + index: 0, + field_layouts: env.arena.alloc([Layout::Builtin(Builtin::Int64)]), + structure: closure_data_symbol, + wrapped: Wrapped::MultiTagUnion, + }; + + Stmt::Let( + closure_tag_id_symbol, + expr, + Layout::Builtin(Builtin::Int64), + env.arena.alloc(result), + ) + } + Layout::Struct(_) => { + let function_symbol = lambda_set.set[0].0; + + // build the call + Stmt::Let( + assigned, + Expr::Call(to_lowlevel_call(function_symbol, closure_data_symbol)), + return_layout, + env.arena.alloc(hole), + ) + } + other => todo!("{:?}", other), + } +} + +fn lowlevel_union_lambda_set_to_switch<'a, F>( + env: &mut Env<'a, '_>, + lambda_set: &'a [(Symbol, &'a [Layout<'a>])], + closure_tag_id_symbol: Symbol, + closure_tag_id_layout: Layout<'a>, + closure_data_symbol: Symbol, + to_lowlevel_call: F, + return_layout: Layout<'a>, + assigned: Symbol, + hole: &'a Stmt<'a>, +) -> Stmt<'a> +where + F: Fn(Symbol, Symbol) -> Call<'a> + Copy, { debug_assert!(!lambda_set.is_empty()); @@ -7397,13 +7471,19 @@ where let mut branches = Vec::with_capacity_in(lambda_set.len(), env.arena); - for (i, (tag_name, _)) in lambda_set.iter().enumerate() { - if let TagName::Closure(function_symbol) = tag_name { - let stmt = to_lowlevel_call(*function_symbol); - branches.push((i as u64, BranchInfo::None, stmt)); - } else { - unreachable!("non-closure tag in lambda set") - } + for (i, (function_symbol, _)) in lambda_set.iter().enumerate() { + let assigned = env.unique_symbol(); + + let hole = Stmt::Jump(join_point_id, env.arena.alloc([assigned])); + + // build the call + let stmt = Stmt::Let( + assigned, + Expr::Call(to_lowlevel_call(*function_symbol, closure_data_symbol)), + return_layout, + env.arena.alloc(hole), + ); + branches.push((i as u64, BranchInfo::None, stmt)); } let default_branch = { @@ -7414,7 +7494,7 @@ where let switch = Stmt::Switch { cond_symbol: closure_tag_id_symbol, - cond_layout: Layout::Builtin(Builtin::Int64), + cond_layout: closure_tag_id_layout, branches: branches.into_bump_slice(), default_branch, ret_layout: return_layout, @@ -7434,6 +7514,31 @@ where } } +fn lowlevel_union_lambda_set_branch<'a, F>( + env: &mut Env<'a, '_>, + join_point_id: JoinPointId, + function_symbol: Symbol, + closure_data_symbol: Symbol, + closure_data_layout: Layout<'a>, + to_lowlevel_call: F, + return_layout: Layout<'a>, +) -> Stmt<'a> +where + F: Fn(Symbol, Symbol) -> Call<'a> + Copy, +{ + let assigned = env.unique_symbol(); + + let hole = Stmt::Jump(join_point_id, env.arena.alloc([assigned])); + + // build the call + Stmt::Let( + assigned, + Expr::Call(to_lowlevel_call(function_symbol, closure_data_symbol)), + return_layout, + env.arena.alloc(hole), + ) +} + /// Use the lambda set to figure out how to make a call-by-name fn match_on_lambda_set<'a>( env: &mut Env<'a, '_>, @@ -7452,6 +7557,7 @@ fn match_on_lambda_set<'a>( let result = union_lambda_set_to_switch( env, lambda_set.set, + lambda_set.runtime_representation(), closure_tag_id_symbol, Layout::Builtin(crate::layout::TAG_SIZE), closure_data_symbol, @@ -7480,7 +7586,7 @@ fn match_on_lambda_set<'a>( Layout::Struct(fields) => { let function_symbol = lambda_set.set[0].0; - lambda_set_to_switch_make_branch_help( + union_lambda_set_branch_help( env, function_symbol, closure_data_symbol, @@ -7531,6 +7637,7 @@ fn match_on_lambda_set<'a>( fn union_lambda_set_to_switch<'a>( env: &mut Env<'a, '_>, lambda_set: &'a [(Symbol, &'a [Layout<'a>])], + closure_layout: Layout<'a>, closure_tag_id_symbol: Symbol, closure_tag_id_layout: Layout<'a>, closure_data_symbol: Symbol, @@ -7546,16 +7653,8 @@ fn union_lambda_set_to_switch<'a>( let mut branches = Vec::with_capacity_in(lambda_set.len(), env.arena); - let closure_layout = { - let mut temporary = Vec::with_capacity_in(lambda_set.len(), env.arena); - - temporary.extend(lambda_set.iter().map(|x| x.1)); - - Layout::Union(UnionLayout::NonRecursive(temporary.into_bump_slice())) - }; - for (i, (function_symbol, _)) in lambda_set.iter().enumerate() { - let stmt = lambda_set_to_switch_make_branch( + let stmt = union_lambda_set_branch( env, join_point_id, *function_symbol, @@ -7596,7 +7695,7 @@ fn union_lambda_set_to_switch<'a>( } } -fn lambda_set_to_switch_make_branch<'a>( +fn union_lambda_set_branch<'a>( env: &mut Env<'a, '_>, join_point_id: JoinPointId, function_symbol: Symbol, @@ -7610,7 +7709,7 @@ fn lambda_set_to_switch_make_branch<'a>( let hole = Stmt::Jump(join_point_id, env.arena.alloc([result_symbol])); - lambda_set_to_switch_make_branch_help( + union_lambda_set_branch_help( env, function_symbol, closure_data_symbol, @@ -7623,7 +7722,7 @@ fn lambda_set_to_switch_make_branch<'a>( ) } -fn lambda_set_to_switch_make_branch_help<'a>( +fn union_lambda_set_branch_help<'a>( env: &mut Env<'a, '_>, function_symbol: Symbol, closure_data_symbol: Symbol, diff --git a/compiler/test_gen/src/gen_primitives.rs b/compiler/test_gen/src/gen_primitives.rs index cae3c5c862..aff0853438 100644 --- a/compiler/test_gen/src/gen_primitives.rs +++ b/compiler/test_gen/src/gen_primitives.rs @@ -992,18 +992,19 @@ fn specialize_closure() { foo = \{} -> x = 41 - y = 1 + y = [1] f = \{} -> x - g = \{} -> x + y + g = \{} -> x + List.len y [ f, g ] + apply = \f -> f {} + main = items = foo {} - # List.len items - List.map items (\f -> f {}) + List.map items apply "# ), RocList::from_slice(&[41, 42]),