Fixup transient closure captures during canonicalization

Closure captures can be transient (that is, transitive: a captured closure
may itself capture other values), but previously we did not handle that
correctly. For example, in

```
x = ""
inner = \{} -> x
outer = \{} -> inner {}
```

`outer` captures `inner`, but `inner` captures `x`, and in the body of
`outer`, we would not construct the closure data for `inner` correctly
before calling it.

There are a couple of ways around this.

1. Update mono to do something when we are passed the captured
   environment of a closure, rather than attempting to construct a
   call-by-name's captured environment before calling it.
2. Fix-up closures during canonicalization to remove captured closures
   that themselves capture, and replace them with their captures.

This patch does (2), since (1) is much more involved and is not likely
to bring a lot of wins. In general I think it's reasonable to expect
captured environments, even if transient, to be fairly shallow, so I
don't think this will produce very large closure environments.

Closes #2894
This commit is contained in:
Ayaz Hafiz 2022-08-10 14:31:11 -07:00
parent 565ffacb9a
commit e97ce32b88
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 235 additions and 36 deletions

View file

@ -704,7 +704,11 @@ pub fn canonicalize_module_defs<'a>(
// def pattern has no default expressions, so skip // def pattern has no default expressions, so skip
let loc_expr = &mut declarations.expressions[index]; let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut VecSet::default(),
&mut VecMap::default(),
);
} }
Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => { Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => {
let name = declarations.symbols[index].value; let name = declarations.symbols[index].value;
@ -717,35 +721,61 @@ pub fn canonicalize_module_defs<'a>(
if function_def.captured_symbols.is_empty() { if function_def.captured_symbols.is_empty() {
no_capture_symbols.insert(name); no_capture_symbols.insert(name);
} }
let mut closure_captures = VecMap::default();
// patterns can contain default expressions, so must go over them too! // patterns can contain default expressions, so must go over them too!
for (_, _, loc_pat) in function_def.arguments.iter_mut() { for (_, _, loc_pat) in function_def.arguments.iter_mut() {
fix_values_captured_in_closure_pattern( fix_values_captured_in_closure_pattern(
&mut loc_pat.value, &mut loc_pat.value,
&mut no_capture_symbols, &mut no_capture_symbols,
&mut closure_captures,
); );
} }
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut no_capture_symbols,
&mut closure_captures,
);
} }
Destructure(d_index) => { Destructure(d_index) => {
let destruct_def = &mut declarations.destructs[d_index.index()]; let destruct_def = &mut declarations.destructs[d_index.index()];
let loc_pat = &mut destruct_def.loc_pattern; let loc_pat = &mut destruct_def.loc_pattern;
let loc_expr = &mut declarations.expressions[index]; let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_pattern(&mut loc_pat.value, &mut VecSet::default()); let mut closure_captures = VecMap::default();
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
fix_values_captured_in_closure_pattern(
&mut loc_pat.value,
&mut VecSet::default(),
&mut closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut VecSet::default(),
&mut closure_captures,
);
} }
MutualRecursion { .. } => { MutualRecursion { .. } => {
// the declarations of this group will be treaded individually by later iterations // the declarations of this group will be treaded individually by later iterations
} }
Expectation => { Expectation => {
let loc_expr = &mut declarations.expressions[index]; let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); let mut closure_captures = Default::default();
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut VecSet::default(),
&mut closure_captures,
);
} }
ExpectationFx => { ExpectationFx => {
let loc_expr = &mut declarations.expressions[index]; let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); let mut closure_captures = Default::default();
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut VecSet::default(),
&mut closure_captures,
);
} }
} }
} }
@ -771,16 +801,26 @@ pub fn canonicalize_module_defs<'a>(
fn fix_values_captured_in_closure_def( fn fix_values_captured_in_closure_def(
def: &mut crate::def::Def, def: &mut crate::def::Def,
no_capture_symbols: &mut VecSet<Symbol>, no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) { ) {
// patterns can contain default expressions, so much go over them too! // patterns can contain default expressions, so much go over them too!
fix_values_captured_in_closure_pattern(&mut def.loc_pattern.value, no_capture_symbols); fix_values_captured_in_closure_pattern(
&mut def.loc_pattern.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(&mut def.loc_expr.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut def.loc_expr.value,
no_capture_symbols,
closure_captures,
);
} }
fn fix_values_captured_in_closure_defs( fn fix_values_captured_in_closure_defs(
defs: &mut [crate::def::Def], defs: &mut [crate::def::Def],
no_capture_symbols: &mut VecSet<Symbol>, no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) { ) {
// recursive defs cannot capture each other // recursive defs cannot capture each other
for def in defs.iter() { for def in defs.iter() {
@ -792,13 +832,14 @@ fn fix_values_captured_in_closure_defs(
// TODO mutually recursive functions should both capture the union of both their capture sets // TODO mutually recursive functions should both capture the union of both their capture sets
for def in defs.iter_mut() { for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols); fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
} }
} }
fn fix_values_captured_in_closure_pattern( fn fix_values_captured_in_closure_pattern(
pattern: &mut crate::pattern::Pattern, pattern: &mut crate::pattern::Pattern,
no_capture_symbols: &mut VecSet<Symbol>, no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) { ) {
use crate::pattern::Pattern::*; use crate::pattern::Pattern::*;
@ -808,24 +849,35 @@ fn fix_values_captured_in_closure_pattern(
.. ..
} => { } => {
for (_, loc_arg) in loc_args.iter_mut() { for (_, loc_arg) in loc_args.iter_mut() {
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); fix_values_captured_in_closure_pattern(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
} }
} }
UnwrappedOpaque { argument, .. } => { UnwrappedOpaque { argument, .. } => {
let (_, loc_arg) = &mut **argument; let (_, loc_arg) = &mut **argument;
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); fix_values_captured_in_closure_pattern(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
} }
RecordDestructure { destructs, .. } => { RecordDestructure { destructs, .. } => {
for loc_destruct in destructs.iter_mut() { for loc_destruct in destructs.iter_mut() {
use crate::pattern::DestructType::*; use crate::pattern::DestructType::*;
match &mut loc_destruct.value.typ { match &mut loc_destruct.value.typ {
Required => {} Required => {}
Optional(_, loc_expr) => { Optional(_, loc_expr) => fix_values_captured_in_closure_expr(
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols) &mut loc_expr.value,
} no_capture_symbols,
closure_captures,
),
Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern( Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern(
&mut loc_pattern.value, &mut loc_pattern.value,
no_capture_symbols, no_capture_symbols,
closure_captures,
), ),
} }
} }
@ -848,19 +900,28 @@ fn fix_values_captured_in_closure_pattern(
fn fix_values_captured_in_closure_expr( fn fix_values_captured_in_closure_expr(
expr: &mut crate::expr::Expr, expr: &mut crate::expr::Expr,
no_capture_symbols: &mut VecSet<Symbol>, no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) { ) {
use crate::expr::Expr::*; use crate::expr::Expr::*;
match expr { match expr {
LetNonRec(def, loc_expr) => { LetNonRec(def, loc_expr) => {
// LetNonRec(Box<Def>, Box<Located<Expr>>, Variable, Aliases), // LetNonRec(Box<Def>, Box<Located<Expr>>, Variable, Aliases),
fix_values_captured_in_closure_def(def, no_capture_symbols); fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
} }
LetRec(defs, loc_expr, _) => { LetRec(defs, loc_expr, _) => {
// LetRec(Vec<Def>, Box<Located<Expr>>, Variable, Aliases), // LetRec(Vec<Def>, Box<Located<Expr>>, Variable, Aliases),
fix_values_captured_in_closure_defs(defs, no_capture_symbols); fix_values_captured_in_closure_defs(defs, no_capture_symbols, closure_captures);
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
} }
Expect { Expect {
@ -868,8 +929,16 @@ fn fix_values_captured_in_closure_expr(
loc_continuation, loc_continuation,
lookups_in_cond: _, lookups_in_cond: _,
} => { } => {
fix_values_captured_in_closure_expr(&mut loc_condition.value, no_capture_symbols); fix_values_captured_in_closure_expr(
fix_values_captured_in_closure_expr(&mut loc_continuation.value, no_capture_symbols); &mut loc_condition.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_continuation.value,
no_capture_symbols,
closure_captures,
);
} }
Closure(ClosureData { Closure(ClosureData {
@ -882,16 +951,55 @@ fn fix_values_captured_in_closure_expr(
captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s)); captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s));
captured_symbols.retain(|(s, _)| s != name); captured_symbols.retain(|(s, _)| s != name);
let original_captures_len = captured_symbols.len();
let mut num_visited = 0;
let mut i = 0;
while num_visited < original_captures_len {
// If we've captured a capturing closure, replace the captured closure symbol with
// the symbols of its captures. That way, we can construct the closure with the
// captures it needs inside our body.
//
// E.g.
// x = ""
// inner = \{} -> x
// outer = \{} -> inner {}
//
// initially `outer` captures [inner], but this is then replaced with just [x].
let (captured_symbol, _) = captured_symbols[i];
if let Some(captures) = closure_captures.get(&captured_symbol) {
captured_symbols.remove(i);
captured_symbols.extend(captures);
// Don't advance because the next capture was shifted down.
} else {
i += 1;
}
num_visited += 1;
}
if captured_symbols.len() > original_captures_len {
// Re-sort, since we've added new captures.
captured_symbols.sort_by_key(|(sym, _)| *sym);
}
if captured_symbols.is_empty() { if captured_symbols.is_empty() {
no_capture_symbols.insert(*name); no_capture_symbols.insert(*name);
} else {
closure_captures.insert(*name, captured_symbols.to_vec());
} }
// patterns can contain default expressions, so much go over them too! // patterns can contain default expressions, so much go over them too!
for (_, _, loc_pat) in arguments.iter_mut() { for (_, _, loc_pat) in arguments.iter_mut() {
fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols); fix_values_captured_in_closure_pattern(
&mut loc_pat.value,
no_capture_symbols,
closure_captures,
);
} }
fix_values_captured_in_closure_expr(&mut loc_body.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_body.value,
no_capture_symbols,
closure_captures,
);
} }
Num(..) Num(..)
@ -909,28 +1017,45 @@ fn fix_values_captured_in_closure_expr(
List { loc_elems, .. } => { List { loc_elems, .. } => {
for elem in loc_elems.iter_mut() { for elem in loc_elems.iter_mut() {
fix_values_captured_in_closure_expr(&mut elem.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut elem.value,
no_capture_symbols,
closure_captures,
);
} }
} }
When { When {
loc_cond, branches, .. loc_cond, branches, ..
} => { } => {
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_cond.value,
no_capture_symbols,
closure_captures,
);
for branch in branches.iter_mut() { for branch in branches.iter_mut() {
fix_values_captured_in_closure_expr(&mut branch.value.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut branch.value.value,
no_capture_symbols,
closure_captures,
);
// patterns can contain default expressions, so much go over them too! // patterns can contain default expressions, so much go over them too!
for loc_pat in branch.patterns.iter_mut() { for loc_pat in branch.patterns.iter_mut() {
fix_values_captured_in_closure_pattern( fix_values_captured_in_closure_pattern(
&mut loc_pat.pattern.value, &mut loc_pat.pattern.value,
no_capture_symbols, no_capture_symbols,
closure_captures,
); );
} }
if let Some(guard) = &mut branch.guard { if let Some(guard) = &mut branch.guard {
fix_values_captured_in_closure_expr(&mut guard.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut guard.value,
no_capture_symbols,
closure_captures,
);
} }
} }
} }
@ -941,23 +1066,43 @@ fn fix_values_captured_in_closure_expr(
.. ..
} => { } => {
for (loc_cond, loc_then) in branches.iter_mut() { for (loc_cond, loc_then) in branches.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); fix_values_captured_in_closure_expr(
fix_values_captured_in_closure_expr(&mut loc_then.value, no_capture_symbols); &mut loc_cond.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_then.value,
no_capture_symbols,
closure_captures,
);
} }
fix_values_captured_in_closure_expr(&mut final_else.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut final_else.value,
no_capture_symbols,
closure_captures,
);
} }
Call(function, arguments, _) => { Call(function, arguments, _) => {
fix_values_captured_in_closure_expr(&mut function.1.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut function.1.value,
no_capture_symbols,
closure_captures,
);
for (_, loc_arg) in arguments.iter_mut() { for (_, loc_arg) in arguments.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
} }
} }
RunLowLevel { args, .. } | ForeignCall { args, .. } => { RunLowLevel { args, .. } | ForeignCall { args, .. } => {
for (_, arg) in args.iter_mut() { for (_, arg) in args.iter_mut() {
fix_values_captured_in_closure_expr(arg, no_capture_symbols); fix_values_captured_in_closure_expr(arg, no_capture_symbols, closure_captures);
} }
} }
@ -966,22 +1111,38 @@ fn fix_values_captured_in_closure_expr(
updates: fields, .. updates: fields, ..
} => { } => {
for (_, field) in fields.iter_mut() { for (_, field) in fields.iter_mut() {
fix_values_captured_in_closure_expr(&mut field.loc_expr.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut field.loc_expr.value,
no_capture_symbols,
closure_captures,
);
} }
} }
Access { loc_expr, .. } => { Access { loc_expr, .. } => {
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
} }
Tag { arguments, .. } => { Tag { arguments, .. } => {
for (_, loc_arg) in arguments.iter_mut() { for (_, loc_arg) in arguments.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
} }
} }
OpaqueRef { argument, .. } => { OpaqueRef { argument, .. } => {
let (_, loc_arg) = &mut **argument; let (_, loc_arg) = &mut **argument;
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
} }
OpaqueWrapFunction(_) => {} OpaqueWrapFunction(_) => {}
} }

View file

@ -7691,4 +7691,42 @@ mod solve_expr {
"### "###
); );
} }
#[test]
fn transient_captures() {
infer_queries!(
indoc!(
r#"
x = "abc"
getX = \{} -> x
h = \{} -> (getX {})
#^{-1}
h {}
"#
),
@"h : {}* -[[h(3) Str]]-> Str"
);
}
#[test]
fn transient_captures_after_def_ordering() {
infer_queries!(
indoc!(
r#"
h = \{} -> (getX {})
#^{-1}
getX = \{} -> x
x = "abc"
h {}
"#
),
@"h : {}* -[[h(1) Str]]-> Str"
);
}
} }