Correctly choose specialized shapes for anonymous closures

This commit is contained in:
Ayaz Hafiz 2022-12-27 09:18:41 -06:00
parent 593344f5c5
commit 1e847efbfe
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
4 changed files with 205 additions and 39 deletions

View file

@ -1089,12 +1089,8 @@ impl<'a> Procs<'a> {
.raw_from_var(env.arena, annotation, env.subs)
.unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err));
let top_level = ProcLayout::from_raw(
env.arena,
&layout_cache.interner,
raw_layout,
name.captures_niche(),
);
let top_level =
ProcLayout::from_raw_named(env.arena, name, raw_layout, name.captures_niche());
// anonymous functions cannot reference themselves, therefore cannot be tail-recursive
// EXCEPT when the closure conversion makes it tail-recursive.
@ -1169,16 +1165,8 @@ impl<'a> Procs<'a> {
let partial_proc_id = if let Some(partial_proc_id) =
self.partial_procs.symbol_to_id(name.name())
{
let existing = self.partial_procs.get_id(partial_proc_id);
// if we're adding the same partial proc twice, they must be the actual same!
//
// NOTE we can't skip extra work! we still need to make the specialization for this
// invocation. The content of the `annotation` can be different, even if the variable
// number is the same
debug_assert_eq!(annotation, existing.annotation);
debug_assert_eq!(captured_symbols, existing.captured_symbols);
debug_assert_eq!(is_self_recursive, existing.is_self_recursive);
// NOTE we can't skip extra work! We still need to make the specialization for this
// invocation.
partial_proc_id
} else {
let pattern_symbols = pattern_symbols.into_bump_slice();
@ -1204,29 +1192,17 @@ impl<'a> Procs<'a> {
&[],
partial_proc_id,
) {
Ok((proc, _ignore_layout)) => {
// the `layout` is a function pointer, while `_ignore_layout` can be a
// closure. We only specialize functions, storing this value with a closure
// layout will give trouble.
let arguments = Vec::from_iter_in(
proc.args.iter().map(|(l, _)| *l),
Ok((proc, layout)) => {
let proc_name = proc.name;
let function_layout = ProcLayout::from_raw_named(
env.arena,
)
.into_bump_slice();
let proper_layout = ProcLayout {
arguments,
result: proc.ret_layout,
captures_niche: proc.name.captures_niche(),
};
// NOTE: some functions are specialized to have a closure, but don't actually
// need any closure argument. Here is where we correct this sort of thing,
// by trusting the layout of the Proc, not of what we specialize for
self.specialized.remove_specialized(name.name(), &layout);
proc_name,
layout,
proc_name.captures_niche(),
);
self.specialized.insert_specialized(
name.name(),
proper_layout,
proc_name.name(),
function_layout,
proc,
);
}

View file

@ -0,0 +1,170 @@
procedure Bool.11 (#Attr.2, #Attr.3):
let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
ret Bool.23;
procedure Bool.11 (#Attr.2, #Attr.3):
let Bool.24 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
ret Bool.24;
procedure List.26 (List.152, List.153, List.154):
let List.493 : [C U64, C U64] = CallByName List.90 List.152 List.153 List.154;
let List.496 : U8 = 1i64;
let List.497 : U8 = GetTagId List.493;
let List.498 : Int1 = lowlevel Eq List.496 List.497;
if List.498 then
let List.155 : U64 = UnionAtIndex (Id 1) (Index 0) List.493;
ret List.155;
else
let List.156 : U64 = UnionAtIndex (Id 0) (Index 0) List.493;
ret List.156;
procedure List.26 (List.152, List.153, List.154):
let List.515 : [C I64, C I64] = CallByName List.90 List.152 List.153 List.154;
let List.518 : U8 = 1i64;
let List.519 : U8 = GetTagId List.515;
let List.520 : Int1 = lowlevel Eq List.518 List.519;
if List.520 then
let List.155 : I64 = UnionAtIndex (Id 1) (Index 0) List.515;
ret List.155;
else
let List.156 : I64 = UnionAtIndex (Id 0) (Index 0) List.515;
ret List.156;
procedure List.29 (List.294, List.295):
let List.492 : U64 = CallByName List.6 List.294;
let List.296 : U64 = CallByName Num.77 List.492 List.295;
let List.478 : List U8 = CallByName List.43 List.294 List.296;
ret List.478;
procedure List.43 (List.292, List.293):
let List.490 : U64 = CallByName List.6 List.292;
let List.489 : U64 = CallByName Num.77 List.490 List.293;
let List.480 : {U64, U64} = Struct {List.293, List.489};
let List.479 : List U8 = CallByName List.49 List.292 List.480;
ret List.479;
procedure List.49 (List.366, List.367):
let List.487 : U64 = StructAtIndex 0 List.367;
let List.488 : U64 = 0i64;
let List.485 : Int1 = CallByName Bool.11 List.487 List.488;
if List.485 then
dec List.366;
let List.486 : List U8 = Array [];
ret List.486;
else
let List.482 : U64 = StructAtIndex 1 List.367;
let List.483 : U64 = StructAtIndex 0 List.367;
let List.481 : List U8 = CallByName List.72 List.366 List.482 List.483;
ret List.481;
procedure List.6 (#Attr.2):
let List.491 : U64 = lowlevel ListLen #Attr.2;
ret List.491;
procedure List.66 (#Attr.2, #Attr.3):
let List.514 : U8 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.514;
procedure List.72 (#Attr.2, #Attr.3, #Attr.4):
let List.484 : List U8 = lowlevel ListSublist #Attr.2 #Attr.3 #Attr.4;
ret List.484;
procedure List.90 (List.426, List.427, List.428):
let List.500 : U64 = 0i64;
let List.501 : U64 = CallByName List.6 List.426;
let List.499 : [C U64, C U64] = CallByName List.91 List.426 List.427 List.428 List.500 List.501;
ret List.499;
procedure List.90 (List.426, List.427, List.428):
let List.522 : U64 = 0i64;
let List.523 : U64 = CallByName List.6 List.426;
let List.521 : [C I64, C I64] = CallByName List.91 List.426 List.427 List.428 List.522 List.523;
ret List.521;
procedure List.91 (List.549, List.550, List.551, List.552, List.553):
joinpoint List.502 List.429 List.430 List.431 List.432 List.433:
let List.504 : Int1 = CallByName Num.22 List.432 List.433;
if List.504 then
let List.513 : U8 = CallByName List.66 List.429 List.432;
let List.505 : [C U64, C U64] = CallByName Test.4 List.430 List.513;
let List.510 : U8 = 1i64;
let List.511 : U8 = GetTagId List.505;
let List.512 : Int1 = lowlevel Eq List.510 List.511;
if List.512 then
let List.434 : U64 = UnionAtIndex (Id 1) (Index 0) List.505;
let List.508 : U64 = 1i64;
let List.507 : U64 = CallByName Num.19 List.432 List.508;
jump List.502 List.429 List.434 List.431 List.507 List.433;
else
let List.435 : U64 = UnionAtIndex (Id 0) (Index 0) List.505;
let List.509 : [C U64, C U64] = TagId(0) List.435;
ret List.509;
else
let List.503 : [C U64, C U64] = TagId(1) List.430;
ret List.503;
in
jump List.502 List.549 List.550 List.551 List.552 List.553;
procedure List.91 (List.562, List.563, List.564, List.565, List.566):
joinpoint List.524 List.429 List.430 List.431 List.432 List.433:
let List.526 : Int1 = CallByName Num.22 List.432 List.433;
if List.526 then
let List.535 : U8 = CallByName List.66 List.429 List.432;
let List.527 : [C I64, C I64] = CallByName Test.4 List.430 List.535;
let List.532 : U8 = 1i64;
let List.533 : U8 = GetTagId List.527;
let List.534 : Int1 = lowlevel Eq List.532 List.533;
if List.534 then
let List.434 : I64 = UnionAtIndex (Id 1) (Index 0) List.527;
let List.530 : U64 = 1i64;
let List.529 : U64 = CallByName Num.19 List.432 List.530;
jump List.524 List.429 List.434 List.431 List.529 List.433;
else
let List.435 : I64 = UnionAtIndex (Id 0) (Index 0) List.527;
let List.531 : [C I64, C I64] = TagId(0) List.435;
ret List.531;
else
let List.525 : [C I64, C I64] = TagId(1) List.430;
ret List.525;
in
jump List.524 List.562 List.563 List.564 List.565 List.566;
procedure Num.19 (#Attr.2, #Attr.3):
let Num.259 : U64 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.259;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.261 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;
ret Num.261;
procedure Num.77 (#Attr.2, #Attr.3):
let Num.257 : U64 = lowlevel NumSubSaturated #Attr.2 #Attr.3;
ret Num.257;
procedure Test.1 (Test.2):
let Test.18 : I64 = 0i64;
let Test.19 : {} = Struct {};
let Test.12 : I64 = CallByName List.26 Test.2 Test.18 Test.19;
let Test.14 : U64 = 0i64;
let Test.15 : {} = Struct {};
let Test.3 : U64 = CallByName List.26 Test.2 Test.14 Test.15;
let Test.13 : I64 = 0i64;
let Test.10 : Int1 = CallByName Bool.11 Test.12 Test.13;
if Test.10 then
ret Test.2;
else
let Test.9 : List U8 = CallByName List.29 Test.2 Test.3;
ret Test.9;
procedure Test.4 (Test.5, Test.16):
let Test.17 : [C U64, C U64] = TagId(0) Test.5;
ret Test.17;
procedure Test.4 (Test.5, Test.16):
let Test.21 : [C U64, C U64] = TagId(0) Test.5;
ret Test.21;
procedure Test.0 ():
let Test.8 : List U8 = Array [1i64, 2i64, 3i64];
let Test.7 : List U8 = CallByName Test.1 Test.8;
ret Test.7;

View file

@ -62,7 +62,7 @@ procedure Decode.26 (Decode.100, Decode.101):
let Decode.116 : [C [C List U8, C ], C Str] = TagId(0) Decode.117;
ret Decode.116;
procedure Json.139 (Json.450, Json.451):
procedure Json.139 (Json.452, Json.453):
joinpoint Json.421 Json.418 Json.138:
let Json.141 : List U8 = StructAtIndex 0 Json.418;
inc Json.141;
@ -91,7 +91,7 @@ procedure Json.139 (Json.450, Json.451):
let Json.435 : {List U8, List U8} = Struct {Json.141, Json.140};
ret Json.435;
in
jump Json.421 Json.450 Json.451;
jump Json.421 Json.452 Json.453;
procedure Json.143 (Json.432):
let Json.433 : List U8 = StructAtIndex 1 Json.432;

View file

@ -2220,3 +2220,23 @@ fn lambda_set_with_imported_toplevels_issue_4733() {
"###
)
}
#[mono_test]
fn issue_4717() {
indoc!(
r###"
app "test" provides [main] to "platform"
chompWhile : (List U8) -> (List U8)
chompWhile = \input ->
index = List.walkUntil input 0 \i, _ -> Break i
if index == 0 then
input
else
List.drop input index
main = chompWhile [1u8, 2u8, 3u8]
"###
)
}