Handle recursive calls to capturing function

This commit is contained in:
Ayaz Hafiz 2022-07-03 16:32:33 -04:00
parent d07c273542
commit b868e0e469
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 97 additions and 23 deletions

View file

@ -7635,23 +7635,37 @@ fn call_by_name_help<'a>(
"see call_by_name for background (scroll down a bit), function is {:?}",
proc_name,
);
call_specialized_proc(
env,
procs,
proc_name,
lambda_set,
RawFunctionLayout::Function(argument_layouts, lambda_set, ret_layout),
top_level_layout,
field_symbols.into_bump_slice(),
loc_args,
layout_cache,
assigned,
hole,
)
let field_symbols = field_symbols.into_bump_slice();
// let field_symbols = field_symbols.into_bump_slice();
let call = self::Call {
call_type: CallType::ByName {
name: proc_name,
ret_layout,
arg_layouts: argument_layouts,
specialization_id: env.next_call_specialization_id(),
},
arguments: field_symbols,
};
// dbg!((proc_name, argument_layouts));
// let call = self::Call {
// call_type: CallType::ByName {
// name: proc_name,
// ret_layout,
// arg_layouts: argument_layouts,
// specialization_id: env.next_call_specialization_id(),
// },
// arguments: field_symbols,
// };
let result = build_call(env, call, assigned, *ret_layout, hole);
// let result = build_call(env, call, assigned, *ret_layout, hole);
let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev());
assign_to_symbols(env, procs, layout_cache, iter, result)
// let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev());
// assign_to_symbols(env, procs, layout_cache, iter, result)
} else if env.is_imported_symbol(proc_name.name()) {
add_needed_external(procs, env, original_fn_var, proc_name);
@ -7799,13 +7813,25 @@ fn call_by_name_help<'a>(
partial_proc,
) {
Ok((proc, layout)) => {
let function_layout = ProcLayout::from_raw(
env.arena,
layout,
proc.name.captures_niche(),
);
procs.specialized.insert_specialized(
proc_name.name(),
function_layout,
proc,
);
// now we just call our freshly-specialized function
call_specialized_proc(
env,
procs,
proc,
proc_name,
lambda_set,
layout,
function_layout,
field_symbols,
loc_args,
layout_cache,
@ -7820,12 +7846,24 @@ fn call_by_name_help<'a>(
attempted_layout,
);
let function_layout = ProcLayout::from_raw(
env.arena,
attempted_layout,
proc_name.captures_niche(),
);
procs.specialized.insert_specialized(
proc_name.name(),
function_layout,
proc,
);
call_specialized_proc(
env,
procs,
proc,
proc_name,
lambda_set,
attempted_layout,
function_layout,
field_symbols,
loc_args,
layout_cache,
@ -7974,22 +8012,16 @@ fn call_by_name_module_thunk<'a>(
fn call_specialized_proc<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
proc: Proc<'a>,
proc_name: LambdaName<'a>,
lambda_set: LambdaSet<'a>,
layout: RawFunctionLayout<'a>,
function_layout: ProcLayout<'a>,
field_symbols: &'a [Symbol],
loc_args: std::vec::Vec<(Variable, Loc<roc_can::expr::Expr>)>,
layout_cache: &mut LayoutCache<'a>,
assigned: Symbol,
hole: &'a Stmt<'a>,
) -> Stmt<'a> {
let proc_name = proc.name;
let function_layout = ProcLayout::from_raw(env.arena, layout, proc_name.captures_niche());
procs
.specialized
.insert_specialized(proc_name.name(), function_layout, proc);
if field_symbols.is_empty() {
debug_assert!(loc_args.is_empty());

View file

@ -0,0 +1,27 @@
procedure Num.19 (#Attr.2, #Attr.3):
let Num.188 : U32 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.188;
procedure Test.1 (Test.2):
let Test.9 : U32 = 0i64;
let Test.16 : {U32} = Struct {Test.2};
let Test.8 : U32 = CallByName Test.3 Test.9 Test.16;
ret Test.8;
procedure Test.3 (Test.18, Test.19):
joinpoint Test.10 Test.4 #Attr.12:
let Test.2 : U32 = StructAtIndex 0 #Attr.12;
let Test.14 : Int1 = true;
if Test.14 then
ret Test.4;
else
let Test.12 : U32 = CallByName Num.19 Test.4 Test.2;
let Test.13 : {U32} = Struct {Test.2};
jump Test.10 Test.12 Test.13;
in
jump Test.10 Test.18 Test.19;
procedure Test.0 ():
let Test.7 : U32 = 6i64;
let Test.6 : U32 = CallByName Test.1 Test.7;
ret Test.6;

View file

@ -1682,3 +1682,18 @@ fn choose_u128_layout() {
"#
)
}
#[mono_test]
fn recursive_call_capturing_function() {
indoc!(
r#"
a = \b ->
c : U32 -> U32
c = \d ->
if True then d else c (d+b)
c 0
a 6
"#
)
}