mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 13:29:12 +00:00
Ensure that closures inside recursive closures capture correctly
With a code like ``` thenDo = \x, callback -> callback x f = \{} -> code = 10u16 bf = \{} -> thenDo code \_ -> bf {} bf {} ``` The lambda `\_ -> bf {}` must capture `bf`. Previously, this would not happen correctly, because we assumed that mutually recursive functions (including singleton recursive functions, like `bf` here) cannot capture themselves. Of course, that premise does not hold in general. Instead, we should have mutually recursive functions capture the closure (haha, get it) of values captured by all functions constituting the mutual recursion. Then, any nested closures can capture outer recursive closures' values appropriately.
This commit is contained in:
parent
ceacc1792d
commit
e8a29d2df4
3 changed files with 69 additions and 18 deletions
|
@ -837,21 +837,10 @@ fn fix_values_captured_in_closure_defs(
|
|||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
// recursive defs cannot capture each other
|
||||
for def in defs.iter() {
|
||||
no_capture_symbols.extend(
|
||||
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value),
|
||||
);
|
||||
}
|
||||
|
||||
for def in defs.iter_mut() {
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
}
|
||||
|
||||
// Mutually recursive functions should both capture the union of all their capture sets
|
||||
//
|
||||
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
|
||||
let mut total_capture_set = Vec::default();
|
||||
let mut total_capture_set = VecMap::default();
|
||||
for def in defs.iter_mut() {
|
||||
if let Expr::Closure(ClosureData {
|
||||
captured_symbols, ..
|
||||
|
@ -860,8 +849,16 @@ fn fix_values_captured_in_closure_defs(
|
|||
total_capture_set.extend(captured_symbols.iter().copied());
|
||||
}
|
||||
}
|
||||
for def in defs.iter() {
|
||||
for symbol in
|
||||
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value)
|
||||
{
|
||||
total_capture_set.remove(&symbol);
|
||||
}
|
||||
}
|
||||
|
||||
let mut total_capture_set: Vec<_> = total_capture_set.into_iter().collect();
|
||||
total_capture_set.sort_by_key(|(sym, _)| *sym);
|
||||
total_capture_set.dedup_by_key(|(sym, _)| *sym);
|
||||
for def in defs.iter_mut() {
|
||||
if let Expr::Closure(ClosureData {
|
||||
captured_symbols, ..
|
||||
|
@ -870,6 +867,10 @@ fn fix_values_captured_in_closure_defs(
|
|||
*captured_symbols = total_capture_set.clone();
|
||||
}
|
||||
}
|
||||
|
||||
for def in defs.iter_mut() {
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_values_captured_in_closure_pattern(
|
||||
|
@ -1032,9 +1033,9 @@ fn fix_values_captured_in_closure_expr(
|
|||
captured_symbols.retain(|(s, _)| s != name);
|
||||
|
||||
let original_captures_len = captured_symbols.len();
|
||||
let mut num_visited = 0;
|
||||
let mut i = 0;
|
||||
while num_visited < original_captures_len {
|
||||
let mut added_captures = false;
|
||||
while i < original_captures_len {
|
||||
// If we've captured a capturing closure, replace the captured closure symbol with
|
||||
// the symbols of its captures. That way, we can construct the closure with the
|
||||
// captures it needs inside our body.
|
||||
|
@ -1048,19 +1049,21 @@ fn fix_values_captured_in_closure_expr(
|
|||
let (captured_symbol, _) = captured_symbols[i];
|
||||
if let Some(captures) = closure_captures.get(&captured_symbol) {
|
||||
debug_assert!(!captures.is_empty());
|
||||
captured_symbols.swap_remove(i);
|
||||
captured_symbols.extend(captures);
|
||||
captured_symbols.swap_remove(i);
|
||||
// Jump two, because the next element is now one of the newly-added captures,
|
||||
// which we don't need to check.
|
||||
i += 2;
|
||||
|
||||
added_captures = true;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
num_visited += 1;
|
||||
}
|
||||
if captured_symbols.len() > original_captures_len {
|
||||
if added_captures {
|
||||
// Re-sort, since we've added new captures.
|
||||
captured_symbols.sort_by_key(|(sym, _)| *sym);
|
||||
captured_symbols.dedup_by_key(|(sym, _)| *sym);
|
||||
}
|
||||
|
||||
if captured_symbols.is_empty() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue