Merge pull request #5167 from roc-lang/fix-closure-captures-recursive

Ensure that closures inside recursive closures capture correctly
This commit is contained in:
Ayaz 2023-03-21 13:53:33 -04:00 committed by GitHub
commit 6b3f3ba1a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 18 deletions

View file

@ -837,21 +837,10 @@ fn fix_values_captured_in_closure_defs(
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
// recursive defs cannot capture each other
for def in defs.iter() {
no_capture_symbols.extend(
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value),
);
}
for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
}
// Mutually recursive functions should both capture the union of all their capture sets
//
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
let mut total_capture_set = Vec::default();
let mut total_capture_set = VecMap::default();
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
@ -860,8 +849,16 @@ fn fix_values_captured_in_closure_defs(
total_capture_set.extend(captured_symbols.iter().copied());
}
}
for def in defs.iter() {
for symbol in
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value)
{
total_capture_set.remove(&symbol);
}
}
let mut total_capture_set: Vec<_> = total_capture_set.into_iter().collect();
total_capture_set.sort_by_key(|(sym, _)| *sym);
total_capture_set.dedup_by_key(|(sym, _)| *sym);
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
@ -870,6 +867,10 @@ fn fix_values_captured_in_closure_defs(
*captured_symbols = total_capture_set.clone();
}
}
for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
}
}
fn fix_values_captured_in_closure_pattern(
@ -1032,9 +1033,9 @@ fn fix_values_captured_in_closure_expr(
captured_symbols.retain(|(s, _)| s != name);
let original_captures_len = captured_symbols.len();
let mut num_visited = 0;
let mut i = 0;
while num_visited < original_captures_len {
let mut added_captures = false;
while i < original_captures_len {
// If we've captured a capturing closure, replace the captured closure symbol with
// the symbols of its captures. That way, we can construct the closure with the
// captures it needs inside our body.
@ -1048,19 +1049,21 @@ fn fix_values_captured_in_closure_expr(
let (captured_symbol, _) = captured_symbols[i];
if let Some(captures) = closure_captures.get(&captured_symbol) {
debug_assert!(!captures.is_empty());
captured_symbols.swap_remove(i);
captured_symbols.extend(captures);
captured_symbols.swap_remove(i);
// Jump two, because the next element is now one of the newly-added captures,
// which we don't need to check.
i += 2;
added_captures = true;
} else {
i += 1;
}
num_visited += 1;
}
if captured_symbols.len() > original_captures_len {
if added_captures {
// Re-sort, since we've added new captures.
captured_symbols.sort_by_key(|(sym, _)| *sym);
captured_symbols.dedup_by_key(|(sym, _)| *sym);
}
if captured_symbols.is_empty() {

View file

@ -8778,4 +8778,32 @@ mod solve_expr {
@"main : List w_a"
);
}
#[test]
fn recursive_closure_with_transiently_used_capture() {
infer_queries!(
indoc!(
r#"
app "test" provides [f] to "./platform"
thenDo = \x, callback ->
callback x
f = \{} ->
code = 10u16
bf = \{} ->
#^^{-1}
thenDo code \_ -> bf {}
# ^^^^^^^^^^^
bf {}
"#
),
@r###"
bf : {} -[[bf(5) U16]]-> *
\_ -> bf {} : U16 -[[6 U16]]-> *
"###
);
}
}

View file

@ -0,0 +1,18 @@
procedure Test.1 (Test.2, Test.3):
let Test.14 : [] = CallByName Test.6 Test.2 Test.3;
ret Test.14;
procedure Test.5 (Test.8, Test.4):
let Test.12 : [] = CallByName Test.1 Test.4 Test.4;
ret Test.12;
procedure Test.6 (Test.15, Test.4):
let Test.18 : {} = Struct {};
let Test.17 : [] = CallByName Test.5 Test.18 Test.4;
ret Test.17;
procedure Test.0 (Test.7):
let Test.4 : U16 = 10i64;
let Test.10 : {} = Struct {};
let Test.9 : [] = CallByName Test.5 Test.10 Test.4;
ret Test.9;

View file

@ -2769,3 +2769,23 @@ fn inline_return_joinpoints_in_union_lambda_set() {
"#
)
}
#[mono_test]
fn recursive_closure_with_transiently_used_capture() {
indoc!(
r#"
app "test" provides [f] to "./platform"
thenDo = \x, callback ->
callback x
f = \{} ->
code = 10u16
bf = \{} ->
thenDo code \_ -> bf {}
bf {}
"#
)
}