mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 13:29:12 +00:00
Merge pull request #5167 from roc-lang/fix-closure-captures-recursive
Ensure that closures inside recursive closures capture correctly
This commit is contained in:
commit
6b3f3ba1a1
4 changed files with 87 additions and 18 deletions
|
@ -837,21 +837,10 @@ fn fix_values_captured_in_closure_defs(
|
||||||
no_capture_symbols: &mut VecSet<Symbol>,
|
no_capture_symbols: &mut VecSet<Symbol>,
|
||||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||||
) {
|
) {
|
||||||
// recursive defs cannot capture each other
|
|
||||||
for def in defs.iter() {
|
|
||||||
no_capture_symbols.extend(
|
|
||||||
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
for def in defs.iter_mut() {
|
|
||||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mutually recursive functions should both capture the union of all their capture sets
|
// Mutually recursive functions should both capture the union of all their capture sets
|
||||||
//
|
//
|
||||||
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
|
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
|
||||||
let mut total_capture_set = Vec::default();
|
let mut total_capture_set = VecMap::default();
|
||||||
for def in defs.iter_mut() {
|
for def in defs.iter_mut() {
|
||||||
if let Expr::Closure(ClosureData {
|
if let Expr::Closure(ClosureData {
|
||||||
captured_symbols, ..
|
captured_symbols, ..
|
||||||
|
@ -860,8 +849,16 @@ fn fix_values_captured_in_closure_defs(
|
||||||
total_capture_set.extend(captured_symbols.iter().copied());
|
total_capture_set.extend(captured_symbols.iter().copied());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for def in defs.iter() {
|
||||||
|
for symbol in
|
||||||
|
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value)
|
||||||
|
{
|
||||||
|
total_capture_set.remove(&symbol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut total_capture_set: Vec<_> = total_capture_set.into_iter().collect();
|
||||||
total_capture_set.sort_by_key(|(sym, _)| *sym);
|
total_capture_set.sort_by_key(|(sym, _)| *sym);
|
||||||
total_capture_set.dedup_by_key(|(sym, _)| *sym);
|
|
||||||
for def in defs.iter_mut() {
|
for def in defs.iter_mut() {
|
||||||
if let Expr::Closure(ClosureData {
|
if let Expr::Closure(ClosureData {
|
||||||
captured_symbols, ..
|
captured_symbols, ..
|
||||||
|
@ -870,6 +867,10 @@ fn fix_values_captured_in_closure_defs(
|
||||||
*captured_symbols = total_capture_set.clone();
|
*captured_symbols = total_capture_set.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for def in defs.iter_mut() {
|
||||||
|
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fix_values_captured_in_closure_pattern(
|
fn fix_values_captured_in_closure_pattern(
|
||||||
|
@ -1032,9 +1033,9 @@ fn fix_values_captured_in_closure_expr(
|
||||||
captured_symbols.retain(|(s, _)| s != name);
|
captured_symbols.retain(|(s, _)| s != name);
|
||||||
|
|
||||||
let original_captures_len = captured_symbols.len();
|
let original_captures_len = captured_symbols.len();
|
||||||
let mut num_visited = 0;
|
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
while num_visited < original_captures_len {
|
let mut added_captures = false;
|
||||||
|
while i < original_captures_len {
|
||||||
// If we've captured a capturing closure, replace the captured closure symbol with
|
// If we've captured a capturing closure, replace the captured closure symbol with
|
||||||
// the symbols of its captures. That way, we can construct the closure with the
|
// the symbols of its captures. That way, we can construct the closure with the
|
||||||
// captures it needs inside our body.
|
// captures it needs inside our body.
|
||||||
|
@ -1048,19 +1049,21 @@ fn fix_values_captured_in_closure_expr(
|
||||||
let (captured_symbol, _) = captured_symbols[i];
|
let (captured_symbol, _) = captured_symbols[i];
|
||||||
if let Some(captures) = closure_captures.get(&captured_symbol) {
|
if let Some(captures) = closure_captures.get(&captured_symbol) {
|
||||||
debug_assert!(!captures.is_empty());
|
debug_assert!(!captures.is_empty());
|
||||||
captured_symbols.swap_remove(i);
|
|
||||||
captured_symbols.extend(captures);
|
captured_symbols.extend(captures);
|
||||||
|
captured_symbols.swap_remove(i);
|
||||||
// Jump two, because the next element is now one of the newly-added captures,
|
// Jump two, because the next element is now one of the newly-added captures,
|
||||||
// which we don't need to check.
|
// which we don't need to check.
|
||||||
i += 2;
|
i += 2;
|
||||||
|
|
||||||
|
added_captures = true;
|
||||||
} else {
|
} else {
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
num_visited += 1;
|
|
||||||
}
|
}
|
||||||
if captured_symbols.len() > original_captures_len {
|
if added_captures {
|
||||||
// Re-sort, since we've added new captures.
|
// Re-sort, since we've added new captures.
|
||||||
captured_symbols.sort_by_key(|(sym, _)| *sym);
|
captured_symbols.sort_by_key(|(sym, _)| *sym);
|
||||||
|
captured_symbols.dedup_by_key(|(sym, _)| *sym);
|
||||||
}
|
}
|
||||||
|
|
||||||
if captured_symbols.is_empty() {
|
if captured_symbols.is_empty() {
|
||||||
|
|
|
@ -8778,4 +8778,32 @@ mod solve_expr {
|
||||||
@"main : List w_a"
|
@"main : List w_a"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn recursive_closure_with_transiently_used_capture() {
|
||||||
|
infer_queries!(
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
app "test" provides [f] to "./platform"
|
||||||
|
|
||||||
|
thenDo = \x, callback ->
|
||||||
|
callback x
|
||||||
|
|
||||||
|
f = \{} ->
|
||||||
|
code = 10u16
|
||||||
|
|
||||||
|
bf = \{} ->
|
||||||
|
#^^{-1}
|
||||||
|
thenDo code \_ -> bf {}
|
||||||
|
# ^^^^^^^^^^^
|
||||||
|
|
||||||
|
bf {}
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
@r###"
|
||||||
|
bf : {} -[[bf(5) U16]]-> *
|
||||||
|
\_ -> bf {} : U16 -[[6 U16]]-> *
|
||||||
|
"###
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
procedure Test.1 (Test.2, Test.3):
|
||||||
|
let Test.14 : [] = CallByName Test.6 Test.2 Test.3;
|
||||||
|
ret Test.14;
|
||||||
|
|
||||||
|
procedure Test.5 (Test.8, Test.4):
|
||||||
|
let Test.12 : [] = CallByName Test.1 Test.4 Test.4;
|
||||||
|
ret Test.12;
|
||||||
|
|
||||||
|
procedure Test.6 (Test.15, Test.4):
|
||||||
|
let Test.18 : {} = Struct {};
|
||||||
|
let Test.17 : [] = CallByName Test.5 Test.18 Test.4;
|
||||||
|
ret Test.17;
|
||||||
|
|
||||||
|
procedure Test.0 (Test.7):
|
||||||
|
let Test.4 : U16 = 10i64;
|
||||||
|
let Test.10 : {} = Struct {};
|
||||||
|
let Test.9 : [] = CallByName Test.5 Test.10 Test.4;
|
||||||
|
ret Test.9;
|
|
@ -2769,3 +2769,23 @@ fn inline_return_joinpoints_in_union_lambda_set() {
|
||||||
"#
|
"#
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[mono_test]
|
||||||
|
fn recursive_closure_with_transiently_used_capture() {
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
app "test" provides [f] to "./platform"
|
||||||
|
|
||||||
|
thenDo = \x, callback ->
|
||||||
|
callback x
|
||||||
|
|
||||||
|
f = \{} ->
|
||||||
|
code = 10u16
|
||||||
|
|
||||||
|
bf = \{} ->
|
||||||
|
thenDo code \_ -> bf {}
|
||||||
|
|
||||||
|
bf {}
|
||||||
|
"#
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue