Merge pull request #3389 from rtfeldman/3378

Call recursive function with captures, and consolidate proc calling
This commit is contained in:
Folkert de Vries 2022-07-08 16:14:33 +02:00 committed by GitHub
commit 43f9b0d0fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 235 additions and 189 deletions

View file

@ -7646,23 +7646,19 @@ fn call_by_name_help<'a>(
"see call_by_name for background (scroll down a bit), function is {:?}",
proc_name,
);
let field_symbols = field_symbols.into_bump_slice();
let call = self::Call {
call_type: CallType::ByName {
name: proc_name,
ret_layout,
arg_layouts: argument_layouts,
specialization_id: env.next_call_specialization_id(),
},
arguments: field_symbols,
};
let result = build_call(env, call, assigned, *ret_layout, hole);
let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev());
assign_to_symbols(env, procs, layout_cache, iter, result)
call_specialized_proc(
env,
procs,
proc_name,
lambda_set,
RawFunctionLayout::Function(argument_layouts, lambda_set, ret_layout),
top_level_layout,
field_symbols.into_bump_slice(),
loc_args,
layout_cache,
assigned,
hole,
)
} else if env.is_imported_symbol(proc_name.name()) {
add_needed_external(procs, env, original_fn_var, proc_name);
@ -7739,52 +7735,21 @@ fn call_by_name_help<'a>(
proc_name,
);
let has_captures = argument_layouts.len() != top_level_layout.arguments.len();
let closure_argument = env.unique_symbol();
if has_captures {
field_symbols.push(closure_argument);
}
let field_symbols = field_symbols.into_bump_slice();
let call = self::Call {
call_type: CallType::ByName {
name: proc_name,
ret_layout,
arg_layouts: top_level_layout.arguments,
specialization_id: env.next_call_specialization_id(),
},
arguments: field_symbols,
};
let result = build_call(env, call, assigned, *ret_layout, hole);
// NOTE: the zip omits the closure symbol, if it exists,
// because loc_args then is shorter than field_symbols
debug_assert!([0, 1].contains(&(field_symbols.len() - loc_args.len())));
let iter = loc_args.into_iter().zip(field_symbols.iter()).rev();
let result = assign_to_symbols(env, procs, layout_cache, iter, result);
if has_captures {
let partial_proc = procs.partial_procs.get_symbol(proc_name.name()).unwrap();
let captured = match partial_proc.captured_symbols {
CapturedSymbols::None => &[],
CapturedSymbols::Captured(slice) => slice,
};
construct_closure_data(
env,
lambda_set,
proc_name,
captured.iter(),
closure_argument,
env.arena.alloc(result),
)
} else {
result
}
call_specialized_proc(
env,
procs,
proc_name,
lambda_set,
RawFunctionLayout::Function(argument_layouts, lambda_set, ret_layout),
top_level_layout,
field_symbols,
loc_args,
layout_cache,
assigned,
hole,
)
}
PendingSpecializations::Making => {
let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name.name());
@ -7810,13 +7775,26 @@ fn call_by_name_help<'a>(
partial_proc,
) {
Ok((proc, layout)) => {
let proc_name = proc.name;
let function_layout = ProcLayout::from_raw(
env.arena,
layout,
proc_name.captures_niche(),
);
procs.specialized.insert_specialized(
proc_name.name(),
function_layout,
proc,
);
// now we just call our freshly-specialized function
call_specialized_proc(
env,
procs,
proc,
proc_name,
lambda_set,
layout,
function_layout,
field_symbols,
loc_args,
layout_cache,
@ -7831,12 +7809,25 @@ fn call_by_name_help<'a>(
attempted_layout,
);
let proc_name = proc.name;
let function_layout = ProcLayout::from_raw(
env.arena,
attempted_layout,
proc_name.captures_niche(),
);
procs.specialized.insert_specialized(
proc_name.name(),
function_layout,
proc,
);
call_specialized_proc(
env,
procs,
proc,
proc_name,
lambda_set,
attempted_layout,
function_layout,
field_symbols,
loc_args,
layout_cache,
@ -7985,22 +7976,16 @@ fn call_by_name_module_thunk<'a>(
fn call_specialized_proc<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
proc: Proc<'a>,
proc_name: LambdaName<'a>,
lambda_set: LambdaSet<'a>,
layout: RawFunctionLayout<'a>,
function_layout: ProcLayout<'a>,
field_symbols: &'a [Symbol],
loc_args: std::vec::Vec<(Variable, Loc<roc_can::expr::Expr>)>,
layout_cache: &mut LayoutCache<'a>,
assigned: Symbol,
hole: &'a Stmt<'a>,
) -> Stmt<'a> {
let proc_name = proc.name;
let function_layout = ProcLayout::from_raw(env.arena, layout, proc_name.captures_niche());
procs
.specialized
.insert_specialized(proc_name.name(), function_layout, proc);
if field_symbols.is_empty() {
debug_assert!(loc_args.is_empty());

View file

@ -3624,3 +3624,22 @@ fn lambda_capture_niches_have_captured_function_in_closure() {
(RocStr, RocStr)
)
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn recursive_call_capturing_function() {
assert_evals_to!(
indoc!(
r#"
a = \b ->
c = \d ->
if d == 7 then d else c (d + b)
c 1
a 6
"#
),
7,
i64
)
}

View file

@ -1,6 +1,6 @@
procedure List.2 (List.75, List.76):
let List.285 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.285;
let List.284 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.284;
if List.281 then
let List.283 : {} = CallByName List.60 List.75 List.76;
let List.282 : [C {}, C {}] = TagId(1) List.283;
@ -11,12 +11,12 @@ procedure List.2 (List.75, List.76):
ret List.279;
procedure List.6 (#Attr.2):
let List.288 : U64 = lowlevel ListLen #Attr.2;
ret List.288;
let List.286 : U64 = lowlevel ListLen #Attr.2;
ret List.286;
procedure List.60 (#Attr.2, #Attr.3):
let List.287 : {} = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.287;
let List.285 : {} = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.285;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.188 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -1,22 +1,22 @@
procedure List.2 (List.75, List.76):
let List.294 : U64 = CallByName List.6 List.75;
let List.290 : Int1 = CallByName Num.22 List.76 List.294;
if List.290 then
let List.292 : I64 = CallByName List.60 List.75 List.76;
let List.291 : [C {}, C I64] = TagId(1) List.292;
ret List.291;
let List.293 : U64 = CallByName List.6 List.75;
let List.289 : Int1 = CallByName Num.22 List.76 List.293;
if List.289 then
let List.291 : I64 = CallByName List.60 List.75 List.76;
let List.290 : [C {}, C I64] = TagId(1) List.291;
ret List.290;
else
let List.289 : {} = Struct {};
let List.288 : [C {}, C I64] = TagId(0) List.289;
ret List.288;
let List.288 : {} = Struct {};
let List.287 : [C {}, C I64] = TagId(0) List.288;
ret List.287;
procedure List.6 (#Attr.2):
let List.295 : U64 = lowlevel ListLen #Attr.2;
ret List.295;
let List.294 : U64 = lowlevel ListLen #Attr.2;
ret List.294;
procedure List.60 (#Attr.2, #Attr.3):
let List.293 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.293;
let List.292 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.292;
procedure List.9 (List.201):
let List.286 : U64 = 0i64;

View file

@ -6,22 +6,22 @@ procedure List.3 (List.83, List.84, List.85):
ret List.281;
procedure List.57 (List.80, List.81, List.82):
let List.288 : U64 = CallByName List.6 List.80;
let List.285 : Int1 = CallByName Num.22 List.81 List.288;
if List.285 then
let List.286 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.286;
let List.287 : U64 = CallByName List.6 List.80;
let List.284 : Int1 = CallByName Num.22 List.81 List.287;
if List.284 then
let List.285 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.285;
else
let List.284 : {List I64, I64} = Struct {List.80, List.82};
ret List.284;
let List.283 : {List I64, I64} = Struct {List.80, List.82};
ret List.283;
procedure List.6 (#Attr.2):
let List.280 : U64 = lowlevel ListLen #Attr.2;
ret List.280;
procedure List.61 (#Attr.2, #Attr.3, #Attr.4):
let List.287 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.287;
let List.286 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.286;
procedure Num.19 (#Attr.2, #Attr.3):
let Num.188 : U64 = lowlevel NumAdd #Attr.2 #Attr.3;

View file

@ -1,6 +1,6 @@
procedure List.2 (List.75, List.76):
let List.285 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.285;
let List.284 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.284;
if List.281 then
let List.283 : I64 = CallByName List.60 List.75 List.76;
let List.282 : [C {}, C I64] = TagId(1) List.283;
@ -11,12 +11,12 @@ procedure List.2 (List.75, List.76):
ret List.279;
procedure List.6 (#Attr.2):
let List.288 : U64 = lowlevel ListLen #Attr.2;
ret List.288;
let List.286 : U64 = lowlevel ListLen #Attr.2;
ret List.286;
procedure List.60 (#Attr.2, #Attr.3):
let List.287 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.287;
let List.285 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.285;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.188 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -1,6 +1,6 @@
procedure List.2 (List.75, List.76):
let List.285 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.285;
let List.284 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.284;
if List.281 then
let List.283 : Str = CallByName List.60 List.75 List.76;
let List.282 : [C {}, C Str] = TagId(1) List.283;
@ -11,16 +11,16 @@ procedure List.2 (List.75, List.76):
ret List.279;
procedure List.5 (#Attr.2, #Attr.3):
let List.287 : List Str = lowlevel ListMap { xs: `#Attr.#arg1` } #Attr.2 Test.3 #Attr.3;
ret List.287;
let List.285 : List Str = lowlevel ListMap { xs: `#Attr.#arg1` } #Attr.2 Test.3 #Attr.3;
ret List.285;
procedure List.6 (#Attr.2):
let List.289 : U64 = lowlevel ListLen #Attr.2;
ret List.289;
let List.287 : U64 = lowlevel ListLen #Attr.2;
ret List.287;
procedure List.60 (#Attr.2, #Attr.3):
let List.288 : Str = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.288;
let List.286 : Str = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.286;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.188 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -1,6 +1,6 @@
procedure List.2 (List.75, List.76):
let List.285 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.285;
let List.284 : U64 = CallByName List.6 List.75;
let List.281 : Int1 = CallByName Num.22 List.76 List.284;
if List.281 then
let List.283 : Str = CallByName List.60 List.75 List.76;
let List.282 : [C {}, C Str] = TagId(1) List.283;
@ -12,17 +12,17 @@ procedure List.2 (List.75, List.76):
procedure List.5 (#Attr.2, #Attr.3):
inc #Attr.2;
let List.287 : List Str = lowlevel ListMap { xs: `#Attr.#arg1` } #Attr.2 Test.3 #Attr.3;
let List.285 : List Str = lowlevel ListMap { xs: `#Attr.#arg1` } #Attr.2 Test.3 #Attr.3;
decref #Attr.2;
ret List.287;
ret List.285;
procedure List.6 (#Attr.2):
let List.289 : U64 = lowlevel ListLen #Attr.2;
ret List.289;
let List.287 : U64 = lowlevel ListLen #Attr.2;
ret List.287;
procedure List.60 (#Attr.2, #Attr.3):
let List.288 : Str = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.288;
let List.286 : Str = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.286;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.188 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -6,22 +6,22 @@ procedure List.3 (List.83, List.84, List.85):
ret List.279;
procedure List.57 (List.80, List.81, List.82):
let List.286 : U64 = CallByName List.6 List.80;
let List.283 : Int1 = CallByName Num.22 List.81 List.286;
if List.283 then
let List.284 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.284;
let List.285 : U64 = CallByName List.6 List.80;
let List.282 : Int1 = CallByName Num.22 List.81 List.285;
if List.282 then
let List.283 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.283;
else
let List.282 : {List I64, I64} = Struct {List.80, List.82};
ret List.282;
let List.281 : {List I64, I64} = Struct {List.80, List.82};
ret List.281;
procedure List.6 (#Attr.2):
let List.287 : U64 = lowlevel ListLen #Attr.2;
ret List.287;
let List.286 : U64 = lowlevel ListLen #Attr.2;
ret List.286;
procedure List.61 (#Attr.2, #Attr.3, #Attr.4):
let List.285 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.285;
let List.284 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.284;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.188 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -1,11 +1,11 @@
procedure List.28 (#Attr.2, #Attr.3):
let List.282 : List I64 = lowlevel ListSortWith { xs: `#Attr.#arg1` } #Attr.2 Num.46 #Attr.3;
let List.281 : List I64 = lowlevel ListSortWith { xs: `#Attr.#arg1` } #Attr.2 Num.46 #Attr.3;
let Bool.9 : Int1 = lowlevel ListIsUnique #Attr.2;
if Bool.9 then
ret List.282;
ret List.281;
else
decref #Attr.2;
ret List.282;
ret List.281;
procedure List.54 (List.196):
let List.280 : {} = Struct {};

View file

@ -1,43 +1,43 @@
procedure List.2 (List.75, List.76):
let List.299 : U64 = CallByName List.6 List.75;
let List.295 : Int1 = CallByName Num.22 List.76 List.299;
if List.295 then
let List.297 : I64 = CallByName List.60 List.75 List.76;
let List.296 : [C {}, C I64] = TagId(1) List.297;
ret List.296;
let List.294 : U64 = CallByName List.6 List.75;
let List.291 : Int1 = CallByName Num.22 List.76 List.294;
if List.291 then
let List.293 : I64 = CallByName List.60 List.75 List.76;
let List.292 : [C {}, C I64] = TagId(1) List.293;
ret List.292;
else
let List.294 : {} = Struct {};
let List.293 : [C {}, C I64] = TagId(0) List.294;
ret List.293;
let List.290 : {} = Struct {};
let List.289 : [C {}, C I64] = TagId(0) List.290;
ret List.289;
procedure List.3 (List.83, List.84, List.85):
let List.283 : {List I64, I64} = CallByName List.57 List.83 List.84 List.85;
let List.282 : List I64 = StructAtIndex 0 List.283;
inc List.282;
dec List.283;
ret List.282;
let List.282 : {List I64, I64} = CallByName List.57 List.83 List.84 List.85;
let List.281 : List I64 = StructAtIndex 0 List.282;
inc List.281;
dec List.282;
ret List.281;
procedure List.57 (List.80, List.81, List.82):
let List.305 : U64 = CallByName List.6 List.80;
let List.302 : Int1 = CallByName Num.22 List.81 List.305;
if List.302 then
let List.303 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.303;
let List.299 : U64 = CallByName List.6 List.80;
let List.296 : Int1 = CallByName Num.22 List.81 List.299;
if List.296 then
let List.297 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.297;
else
let List.301 : {List I64, I64} = Struct {List.80, List.82};
ret List.301;
let List.295 : {List I64, I64} = Struct {List.80, List.82};
ret List.295;
procedure List.6 (#Attr.2):
let List.306 : U64 = lowlevel ListLen #Attr.2;
ret List.306;
let List.300 : U64 = lowlevel ListLen #Attr.2;
ret List.300;
procedure List.60 (#Attr.2, #Attr.3):
let List.307 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.307;
let List.301 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.301;
procedure List.61 (#Attr.2, #Attr.3, #Attr.4):
let List.304 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.304;
let List.298 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.298;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.190 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -0,0 +1,27 @@
procedure Num.19 (#Attr.2, #Attr.3):
let Num.188 : U32 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.188;
procedure Test.1 (Test.2):
let Test.9 : U32 = 0i64;
let Test.16 : {U32} = Struct {Test.2};
let Test.8 : U32 = CallByName Test.3 Test.9 Test.16;
ret Test.8;
procedure Test.3 (Test.18, Test.19):
joinpoint Test.10 Test.4 #Attr.12:
let Test.2 : U32 = StructAtIndex 0 #Attr.12;
let Test.14 : Int1 = true;
if Test.14 then
ret Test.4;
else
let Test.12 : U32 = CallByName Num.19 Test.4 Test.2;
let Test.13 : {U32} = Struct {Test.2};
jump Test.10 Test.12 Test.13;
in
jump Test.10 Test.18 Test.19;
procedure Test.0 ():
let Test.7 : U32 = 6i64;
let Test.6 : U32 = CallByName Test.1 Test.7;
ret Test.6;

View file

@ -1,43 +1,43 @@
procedure List.2 (List.75, List.76):
let List.299 : U64 = CallByName List.6 List.75;
let List.295 : Int1 = CallByName Num.22 List.76 List.299;
if List.295 then
let List.297 : I64 = CallByName List.60 List.75 List.76;
let List.296 : [C {}, C I64] = TagId(1) List.297;
ret List.296;
let List.294 : U64 = CallByName List.6 List.75;
let List.291 : Int1 = CallByName Num.22 List.76 List.294;
if List.291 then
let List.293 : I64 = CallByName List.60 List.75 List.76;
let List.292 : [C {}, C I64] = TagId(1) List.293;
ret List.292;
else
let List.294 : {} = Struct {};
let List.293 : [C {}, C I64] = TagId(0) List.294;
ret List.293;
let List.290 : {} = Struct {};
let List.289 : [C {}, C I64] = TagId(0) List.290;
ret List.289;
procedure List.3 (List.83, List.84, List.85):
let List.283 : {List I64, I64} = CallByName List.57 List.83 List.84 List.85;
let List.282 : List I64 = StructAtIndex 0 List.283;
inc List.282;
dec List.283;
ret List.282;
let List.282 : {List I64, I64} = CallByName List.57 List.83 List.84 List.85;
let List.281 : List I64 = StructAtIndex 0 List.282;
inc List.281;
dec List.282;
ret List.281;
procedure List.57 (List.80, List.81, List.82):
let List.305 : U64 = CallByName List.6 List.80;
let List.302 : Int1 = CallByName Num.22 List.81 List.305;
if List.302 then
let List.303 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.303;
let List.299 : U64 = CallByName List.6 List.80;
let List.296 : Int1 = CallByName Num.22 List.81 List.299;
if List.296 then
let List.297 : {List I64, I64} = CallByName List.61 List.80 List.81 List.82;
ret List.297;
else
let List.301 : {List I64, I64} = Struct {List.80, List.82};
ret List.301;
let List.295 : {List I64, I64} = Struct {List.80, List.82};
ret List.295;
procedure List.6 (#Attr.2):
let List.306 : U64 = lowlevel ListLen #Attr.2;
ret List.306;
let List.300 : U64 = lowlevel ListLen #Attr.2;
ret List.300;
procedure List.60 (#Attr.2, #Attr.3):
let List.307 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.307;
let List.301 : I64 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.301;
procedure List.61 (#Attr.2, #Attr.3, #Attr.4):
let List.304 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.304;
let List.298 : {List I64, I64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.298;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.190 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;

View file

@ -1683,6 +1683,21 @@ fn choose_u128_layout() {
)
}
#[mono_test]
fn recursive_call_capturing_function() {
indoc!(
r#"
a = \b ->
c : U32 -> U32
c = \d ->
if True then d else c (d+b)
c 0
a 6
"#
)
}
#[mono_test]
fn call_function_in_empty_list() {
indoc!(