Ensure that closures inside recursive closures capture correctly

With a code like

```
thenDo = \x, callback ->
    callback x

f = \{} ->
    code = 10u16

    bf = \{} ->
        thenDo code \_ -> bf {}

    bf {}
```

The lambda `\_ -> bf {}` must capture `bf`. Previously, this would not
happen correctly, because we assumed that mutually recursive functions
(including singleton recursive functions, like `bf` here) cannot capture
themselves.

Of course, that premise does not hold in general. Instead, we should have
mutually recursive functions capture the closure (haha, get it) of
values captured by all functions constituting the mutual recursion.
Then, any nested closures can capture outer recursive closures' values
appropriately.
This commit is contained in:
Ayaz Hafiz 2023-03-20 17:40:13 -04:00
parent ceacc1792d
commit e8a29d2df4
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 69 additions and 18 deletions

View file

@ -837,21 +837,10 @@ fn fix_values_captured_in_closure_defs(
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
// recursive defs cannot capture each other
for def in defs.iter() {
no_capture_symbols.extend(
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value),
);
}
for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
}
// Mutually recursive functions should both capture the union of all their capture sets
//
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
let mut total_capture_set = Vec::default();
let mut total_capture_set = VecMap::default();
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
@ -860,8 +849,16 @@ fn fix_values_captured_in_closure_defs(
total_capture_set.extend(captured_symbols.iter().copied());
}
}
for def in defs.iter() {
for symbol in
crate::traverse::symbols_introduced_from_pattern(&def.loc_pattern).map(|ls| ls.value)
{
total_capture_set.remove(&symbol);
}
}
let mut total_capture_set: Vec<_> = total_capture_set.into_iter().collect();
total_capture_set.sort_by_key(|(sym, _)| *sym);
total_capture_set.dedup_by_key(|(sym, _)| *sym);
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
@ -870,6 +867,10 @@ fn fix_values_captured_in_closure_defs(
*captured_symbols = total_capture_set.clone();
}
}
for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
}
}
fn fix_values_captured_in_closure_pattern(
@ -1032,9 +1033,9 @@ fn fix_values_captured_in_closure_expr(
captured_symbols.retain(|(s, _)| s != name);
let original_captures_len = captured_symbols.len();
let mut num_visited = 0;
let mut i = 0;
while num_visited < original_captures_len {
let mut added_captures = false;
while i < original_captures_len {
// If we've captured a capturing closure, replace the captured closure symbol with
// the symbols of its captures. That way, we can construct the closure with the
// captures it needs inside our body.
@ -1048,19 +1049,21 @@ fn fix_values_captured_in_closure_expr(
let (captured_symbol, _) = captured_symbols[i];
if let Some(captures) = closure_captures.get(&captured_symbol) {
debug_assert!(!captures.is_empty());
captured_symbols.swap_remove(i);
captured_symbols.extend(captures);
captured_symbols.swap_remove(i);
// Jump two, because the next element is now one of the newly-added captures,
// which we don't need to check.
i += 2;
added_captures = true;
} else {
i += 1;
}
num_visited += 1;
}
if captured_symbols.len() > original_captures_len {
if added_captures {
// Re-sort, since we've added new captures.
captured_symbols.sort_by_key(|(sym, _)| *sym);
captured_symbols.dedup_by_key(|(sym, _)| *sym);
}
if captured_symbols.is_empty() {