mirror of
https://github.com/roc-lang/roc.git
synced 2025-08-02 11:22:19 +00:00
Merge pull request #5167 from roc-lang/fix-closure-captures-recursive
Ensure that closures inside recursive closures capture correctly
This commit is contained in:
commit
6b3f3ba1a1
4 changed files with 87 additions and 18 deletions
|
@ -837,21 +837,10 @@ fn fix_values_captured_in_closure_defs(
|
|||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
// recursive defs cannot capture each other
|
||||
for def in defs.iter() {
|
||||
no_capture_symbols.extend(
|
||||
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value),
|
||||
);
|
||||
}
|
||||
|
||||
for def in defs.iter_mut() {
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
}
|
||||
|
||||
// Mutually recursive functions should both capture the union of all their capture sets
|
||||
//
|
||||
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
|
||||
let mut total_capture_set = Vec::default();
|
||||
let mut total_capture_set = VecMap::default();
|
||||
for def in defs.iter_mut() {
|
||||
if let Expr::Closure(ClosureData {
|
||||
captured_symbols, ..
|
||||
|
@ -860,8 +849,16 @@ fn fix_values_captured_in_closure_defs(
|
|||
total_capture_set.extend(captured_symbols.iter().copied());
|
||||
}
|
||||
}
|
||||
for def in defs.iter() {
|
||||
for symbol in
|
||||
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value)
|
||||
{
|
||||
total_capture_set.remove(&symbol);
|
||||
}
|
||||
}
|
||||
|
||||
let mut total_capture_set: Vec<_> = total_capture_set.into_iter().collect();
|
||||
total_capture_set.sort_by_key(|(sym, _)| *sym);
|
||||
total_capture_set.dedup_by_key(|(sym, _)| *sym);
|
||||
for def in defs.iter_mut() {
|
||||
if let Expr::Closure(ClosureData {
|
||||
captured_symbols, ..
|
||||
|
@ -870,6 +867,10 @@ fn fix_values_captured_in_closure_defs(
|
|||
*captured_symbols = total_capture_set.clone();
|
||||
}
|
||||
}
|
||||
|
||||
for def in defs.iter_mut() {
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_values_captured_in_closure_pattern(
|
||||
|
@ -1032,9 +1033,9 @@ fn fix_values_captured_in_closure_expr(
|
|||
captured_symbols.retain(|(s, _)| s != name);
|
||||
|
||||
let original_captures_len = captured_symbols.len();
|
||||
let mut num_visited = 0;
|
||||
let mut i = 0;
|
||||
while num_visited < original_captures_len {
|
||||
let mut added_captures = false;
|
||||
while i < original_captures_len {
|
||||
// If we've captured a capturing closure, replace the captured closure symbol with
|
||||
// the symbols of its captures. That way, we can construct the closure with the
|
||||
// captures it needs inside our body.
|
||||
|
@ -1048,19 +1049,21 @@ fn fix_values_captured_in_closure_expr(
|
|||
let (captured_symbol, _) = captured_symbols[i];
|
||||
if let Some(captures) = closure_captures.get(&captured_symbol) {
|
||||
debug_assert!(!captures.is_empty());
|
||||
captured_symbols.swap_remove(i);
|
||||
captured_symbols.extend(captures);
|
||||
captured_symbols.swap_remove(i);
|
||||
// Jump two, because the next element is now one of the newly-added captures,
|
||||
// which we don't need to check.
|
||||
i += 2;
|
||||
|
||||
added_captures = true;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
num_visited += 1;
|
||||
}
|
||||
if captured_symbols.len() > original_captures_len {
|
||||
if added_captures {
|
||||
// Re-sort, since we've added new captures.
|
||||
captured_symbols.sort_by_key(|(sym, _)| *sym);
|
||||
captured_symbols.dedup_by_key(|(sym, _)| *sym);
|
||||
}
|
||||
|
||||
if captured_symbols.is_empty() {
|
||||
|
|
|
@ -8778,4 +8778,32 @@ mod solve_expr {
|
|||
@"main : List w_a"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recursive_closure_with_transiently_used_capture() {
|
||||
infer_queries!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [f] to "./platform"
|
||||
|
||||
thenDo = \x, callback ->
|
||||
callback x
|
||||
|
||||
f = \{} ->
|
||||
code = 10u16
|
||||
|
||||
bf = \{} ->
|
||||
#^^{-1}
|
||||
thenDo code \_ -> bf {}
|
||||
# ^^^^^^^^^^^
|
||||
|
||||
bf {}
|
||||
"#
|
||||
),
|
||||
@r###"
|
||||
bf : {} -[[bf(5) U16]]-> *
|
||||
\_ -> bf {} : U16 -[[6 U16]]-> *
|
||||
"###
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
procedure Test.1 (Test.2, Test.3):
|
||||
let Test.14 : [] = CallByName Test.6 Test.2 Test.3;
|
||||
ret Test.14;
|
||||
|
||||
procedure Test.5 (Test.8, Test.4):
|
||||
let Test.12 : [] = CallByName Test.1 Test.4 Test.4;
|
||||
ret Test.12;
|
||||
|
||||
procedure Test.6 (Test.15, Test.4):
|
||||
let Test.18 : {} = Struct {};
|
||||
let Test.17 : [] = CallByName Test.5 Test.18 Test.4;
|
||||
ret Test.17;
|
||||
|
||||
procedure Test.0 (Test.7):
|
||||
let Test.4 : U16 = 10i64;
|
||||
let Test.10 : {} = Struct {};
|
||||
let Test.9 : [] = CallByName Test.5 Test.10 Test.4;
|
||||
ret Test.9;
|
|
@ -2769,3 +2769,23 @@ fn inline_return_joinpoints_in_union_lambda_set() {
|
|||
"#
|
||||
)
|
||||
}
|
||||
|
||||
#[mono_test]
|
||||
fn recursive_closure_with_transiently_used_capture() {
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [f] to "./platform"
|
||||
|
||||
thenDo = \x, callback ->
|
||||
callback x
|
||||
|
||||
f = \{} ->
|
||||
code = 10u16
|
||||
|
||||
bf = \{} ->
|
||||
thenDo code \_ -> bf {}
|
||||
|
||||
bf {}
|
||||
"#
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue