mirror of
https://github.com/roc-lang/roc.git
synced 2025-07-24 06:55:15 +00:00
Fixup transient closure captures during canonicalization
Closure captures can be transient, but previously, we did not handle that correctly. For example, in ``` x = "" inner = \{} -> x outer = \{} -> inner {} ``` `outer` captures `inner`, but `inner` captures `x`, and in the body of `outer`, we would not construct the closure data for `inner` correctly before calling it. There are a couple ways around this. 1. Update mono to do something when we are passed the captured environment of a closure, rather than attempting to construct a call-by-name's captured environment before callign it. 2. Fix-up closures during canonicalization to remove captured closures that themselves capture, and replace them with their captures. This patch does (2), since (1) is much more involved and is not likely to bring a lot of wins. In general I think it's reasonable to expect captured environments, even if transient, to be fairly shallow, so I don't think this will produce very large closure environments. Closes #2894
This commit is contained in:
parent
565ffacb9a
commit
e97ce32b88
2 changed files with 235 additions and 36 deletions
|
@ -704,7 +704,11 @@ pub fn canonicalize_module_defs<'a>(
|
|||
// def pattern has no default expressions, so skip
|
||||
let loc_expr = &mut declarations.expressions[index];
|
||||
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
&mut VecSet::default(),
|
||||
&mut VecMap::default(),
|
||||
);
|
||||
}
|
||||
Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => {
|
||||
let name = declarations.symbols[index].value;
|
||||
|
@ -717,35 +721,61 @@ pub fn canonicalize_module_defs<'a>(
|
|||
if function_def.captured_symbols.is_empty() {
|
||||
no_capture_symbols.insert(name);
|
||||
}
|
||||
let mut closure_captures = VecMap::default();
|
||||
|
||||
// patterns can contain default expressions, so must go over them too!
|
||||
for (_, _, loc_pat) in function_def.arguments.iter_mut() {
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_pat.value,
|
||||
&mut no_capture_symbols,
|
||||
&mut closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
&mut no_capture_symbols,
|
||||
&mut closure_captures,
|
||||
);
|
||||
}
|
||||
Destructure(d_index) => {
|
||||
let destruct_def = &mut declarations.destructs[d_index.index()];
|
||||
let loc_pat = &mut destruct_def.loc_pattern;
|
||||
let loc_expr = &mut declarations.expressions[index];
|
||||
|
||||
fix_values_captured_in_closure_pattern(&mut loc_pat.value, &mut VecSet::default());
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
|
||||
let mut closure_captures = VecMap::default();
|
||||
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_pat.value,
|
||||
&mut VecSet::default(),
|
||||
&mut closure_captures,
|
||||
);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
&mut VecSet::default(),
|
||||
&mut closure_captures,
|
||||
);
|
||||
}
|
||||
MutualRecursion { .. } => {
|
||||
// the declarations of this group will be treaded individually by later iterations
|
||||
}
|
||||
Expectation => {
|
||||
let loc_expr = &mut declarations.expressions[index];
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
|
||||
let mut closure_captures = Default::default();
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
&mut VecSet::default(),
|
||||
&mut closure_captures,
|
||||
);
|
||||
}
|
||||
ExpectationFx => {
|
||||
let loc_expr = &mut declarations.expressions[index];
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
|
||||
let mut closure_captures = Default::default();
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
&mut VecSet::default(),
|
||||
&mut closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -771,16 +801,26 @@ pub fn canonicalize_module_defs<'a>(
|
|||
fn fix_values_captured_in_closure_def(
|
||||
def: &mut crate::def::Def,
|
||||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
// patterns can contain default expressions, so much go over them too!
|
||||
fix_values_captured_in_closure_pattern(&mut def.loc_pattern.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut def.loc_pattern.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
|
||||
fix_values_captured_in_closure_expr(&mut def.loc_expr.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut def.loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
fn fix_values_captured_in_closure_defs(
|
||||
defs: &mut [crate::def::Def],
|
||||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
// recursive defs cannot capture each other
|
||||
for def in defs.iter() {
|
||||
|
@ -792,13 +832,14 @@ fn fix_values_captured_in_closure_defs(
|
|||
// TODO mutually recursive functions should both capture the union of both their capture sets
|
||||
|
||||
for def in defs.iter_mut() {
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols);
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_values_captured_in_closure_pattern(
|
||||
pattern: &mut crate::pattern::Pattern,
|
||||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
use crate::pattern::Pattern::*;
|
||||
|
||||
|
@ -808,24 +849,35 @@ fn fix_values_captured_in_closure_pattern(
|
|||
..
|
||||
} => {
|
||||
for (_, loc_arg) in loc_args.iter_mut() {
|
||||
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_arg.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
UnwrappedOpaque { argument, .. } => {
|
||||
let (_, loc_arg) = &mut **argument;
|
||||
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_arg.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
RecordDestructure { destructs, .. } => {
|
||||
for loc_destruct in destructs.iter_mut() {
|
||||
use crate::pattern::DestructType::*;
|
||||
match &mut loc_destruct.value.typ {
|
||||
Required => {}
|
||||
Optional(_, loc_expr) => {
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols)
|
||||
}
|
||||
Optional(_, loc_expr) => fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
),
|
||||
Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern(
|
||||
&mut loc_pattern.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -848,19 +900,28 @@ fn fix_values_captured_in_closure_pattern(
|
|||
fn fix_values_captured_in_closure_expr(
|
||||
expr: &mut crate::expr::Expr,
|
||||
no_capture_symbols: &mut VecSet<Symbol>,
|
||||
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
|
||||
) {
|
||||
use crate::expr::Expr::*;
|
||||
|
||||
match expr {
|
||||
LetNonRec(def, loc_expr) => {
|
||||
// LetNonRec(Box<Def>, Box<Located<Expr>>, Variable, Aliases),
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
LetRec(defs, loc_expr, _) => {
|
||||
// LetRec(Vec<Def>, Box<Located<Expr>>, Variable, Aliases),
|
||||
fix_values_captured_in_closure_defs(defs, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_defs(defs, no_capture_symbols, closure_captures);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
Expect {
|
||||
|
@ -868,8 +929,16 @@ fn fix_values_captured_in_closure_expr(
|
|||
loc_continuation,
|
||||
lookups_in_cond: _,
|
||||
} => {
|
||||
fix_values_captured_in_closure_expr(&mut loc_condition.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(&mut loc_continuation.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_condition.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_continuation.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
Closure(ClosureData {
|
||||
|
@ -882,16 +951,55 @@ fn fix_values_captured_in_closure_expr(
|
|||
captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s));
|
||||
captured_symbols.retain(|(s, _)| s != name);
|
||||
|
||||
let original_captures_len = captured_symbols.len();
|
||||
let mut num_visited = 0;
|
||||
let mut i = 0;
|
||||
while num_visited < original_captures_len {
|
||||
// If we've captured a capturing closure, replace the captured closure symbol with
|
||||
// the symbols of its captures. That way, we can construct the closure with the
|
||||
// captures it needs inside our body.
|
||||
//
|
||||
// E.g.
|
||||
// x = ""
|
||||
// inner = \{} -> x
|
||||
// outer = \{} -> inner {}
|
||||
//
|
||||
// initially `outer` captures [inner], but this is then replaced with just [x].
|
||||
let (captured_symbol, _) = captured_symbols[i];
|
||||
if let Some(captures) = closure_captures.get(&captured_symbol) {
|
||||
captured_symbols.remove(i);
|
||||
captured_symbols.extend(captures);
|
||||
// Don't advance because the next capture was shifted down.
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
num_visited += 1;
|
||||
}
|
||||
if captured_symbols.len() > original_captures_len {
|
||||
// Re-sort, since we've added new captures.
|
||||
captured_symbols.sort_by_key(|(sym, _)| *sym);
|
||||
}
|
||||
|
||||
if captured_symbols.is_empty() {
|
||||
no_capture_symbols.insert(*name);
|
||||
} else {
|
||||
closure_captures.insert(*name, captured_symbols.to_vec());
|
||||
}
|
||||
|
||||
// patterns can contain default expressions, so much go over them too!
|
||||
for (_, _, loc_pat) in arguments.iter_mut() {
|
||||
fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_pat.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
fix_values_captured_in_closure_expr(&mut loc_body.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_body.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
Num(..)
|
||||
|
@ -909,28 +1017,45 @@ fn fix_values_captured_in_closure_expr(
|
|||
|
||||
List { loc_elems, .. } => {
|
||||
for elem in loc_elems.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut elem.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut elem.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
When {
|
||||
loc_cond, branches, ..
|
||||
} => {
|
||||
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_cond.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
|
||||
for branch in branches.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut branch.value.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut branch.value.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
|
||||
// patterns can contain default expressions, so much go over them too!
|
||||
for loc_pat in branch.patterns.iter_mut() {
|
||||
fix_values_captured_in_closure_pattern(
|
||||
&mut loc_pat.pattern.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(guard) = &mut branch.guard {
|
||||
fix_values_captured_in_closure_expr(&mut guard.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut guard.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -941,23 +1066,43 @@ fn fix_values_captured_in_closure_expr(
|
|||
..
|
||||
} => {
|
||||
for (loc_cond, loc_then) in branches.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(&mut loc_then.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_cond.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_then.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
fix_values_captured_in_closure_expr(&mut final_else.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut final_else.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
Call(function, arguments, _) => {
|
||||
fix_values_captured_in_closure_expr(&mut function.1.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut function.1.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
|
||||
for (_, loc_arg) in arguments.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_arg.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
RunLowLevel { args, .. } | ForeignCall { args, .. } => {
|
||||
for (_, arg) in args.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(arg, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(arg, no_capture_symbols, closure_captures);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -966,22 +1111,38 @@ fn fix_values_captured_in_closure_expr(
|
|||
updates: fields, ..
|
||||
} => {
|
||||
for (_, field) in fields.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut field.loc_expr.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut field.loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Access { loc_expr, .. } => {
|
||||
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_expr.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
|
||||
Tag { arguments, .. } => {
|
||||
for (_, loc_arg) in arguments.iter_mut() {
|
||||
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_arg.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
}
|
||||
OpaqueRef { argument, .. } => {
|
||||
let (_, loc_arg) = &mut **argument;
|
||||
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
|
||||
fix_values_captured_in_closure_expr(
|
||||
&mut loc_arg.value,
|
||||
no_capture_symbols,
|
||||
closure_captures,
|
||||
);
|
||||
}
|
||||
OpaqueWrapFunction(_) => {}
|
||||
}
|
||||
|
|
|
@ -7691,4 +7691,42 @@ mod solve_expr {
|
|||
"###
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transient_captures() {
|
||||
infer_queries!(
|
||||
indoc!(
|
||||
r#"
|
||||
x = "abc"
|
||||
|
||||
getX = \{} -> x
|
||||
|
||||
h = \{} -> (getX {})
|
||||
#^{-1}
|
||||
|
||||
h {}
|
||||
"#
|
||||
),
|
||||
@"h : {}* -[[h(3) Str]]-> Str"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transient_captures_after_def_ordering() {
|
||||
infer_queries!(
|
||||
indoc!(
|
||||
r#"
|
||||
h = \{} -> (getX {})
|
||||
#^{-1}
|
||||
|
||||
getX = \{} -> x
|
||||
|
||||
x = "abc"
|
||||
|
||||
h {}
|
||||
"#
|
||||
),
|
||||
@"h : {}* -[[h(1) Str]]-> Str"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue