From e97ce32b88d17406f89c1ebb560ac36331e22dfd Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Wed, 10 Aug 2022 14:31:11 -0700 Subject: [PATCH] Fixup transient closure captures during canonicalization Closure captures can be transient, but previously, we did not handle that correctly. For example, in ``` x = "" inner = \{} -> x outer = \{} -> inner {} ``` `outer` captures `inner`, but `inner` captures `x`, and in the body of `outer`, we would not construct the closure data for `inner` correctly before calling it. There are a couple ways around this. 1. Update mono to do something when we are passed the captured environment of a closure, rather than attempting to construct a call-by-name's captured environment before calling it. 2. Fix-up closures during canonicalization to remove captured closures that themselves capture, and replace them with their captures. This patch does (2), since (1) is much more involved and is not likely to bring a lot of wins. In general I think it's reasonable to expect captured environments, even if transient, to be fairly shallow, so I don't think this will produce very large closure environments. 
Closes #2894 --- crates/compiler/can/src/module.rs | 233 ++++++++++++++++++---- crates/compiler/solve/tests/solve_expr.rs | 38 ++++ 2 files changed, 235 insertions(+), 36 deletions(-) diff --git a/crates/compiler/can/src/module.rs b/crates/compiler/can/src/module.rs index 2b54df0395..ca2d2ed668 100644 --- a/crates/compiler/can/src/module.rs +++ b/crates/compiler/can/src/module.rs @@ -704,7 +704,11 @@ pub fn canonicalize_module_defs<'a>( // def pattern has no default expressions, so skip let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut VecMap::default(), + ); } Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => { let name = declarations.symbols[index].value; @@ -717,35 +721,61 @@ pub fn canonicalize_module_defs<'a>( if function_def.captured_symbols.is_empty() { no_capture_symbols.insert(name); } + let mut closure_captures = VecMap::default(); // patterns can contain default expressions, so must go over them too! 
for (_, _, loc_pat) in function_def.arguments.iter_mut() { fix_values_captured_in_closure_pattern( &mut loc_pat.value, &mut no_capture_symbols, + &mut closure_captures, ); } - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut no_capture_symbols, + &mut closure_captures, + ); } Destructure(d_index) => { let destruct_def = &mut declarations.destructs[d_index.index()]; let loc_pat = &mut destruct_def.loc_pattern; let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_pattern(&mut loc_pat.value, &mut VecSet::default()); - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = VecMap::default(); + + fix_values_captured_in_closure_pattern( + &mut loc_pat.value, + &mut VecSet::default(), + &mut closure_captures, + ); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } MutualRecursion { .. 
} => { // the declarations of this group will be treaded individually by later iterations } Expectation => { let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = Default::default(); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } ExpectationFx => { let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = Default::default(); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } } } @@ -771,16 +801,26 @@ pub fn canonicalize_module_defs<'a>( fn fix_values_captured_in_closure_def( def: &mut crate::def::Def, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { // patterns can contain default expressions, so much go over them too! 
- fix_values_captured_in_closure_pattern(&mut def.loc_pattern.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut def.loc_pattern.value, + no_capture_symbols, + closure_captures, + ); - fix_values_captured_in_closure_expr(&mut def.loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut def.loc_expr.value, + no_capture_symbols, + closure_captures, + ); } fn fix_values_captured_in_closure_defs( defs: &mut [crate::def::Def], no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { // recursive defs cannot capture each other for def in defs.iter() { @@ -792,13 +832,14 @@ fn fix_values_captured_in_closure_defs( // TODO mutually recursive functions should both capture the union of both their capture sets for def in defs.iter_mut() { - fix_values_captured_in_closure_def(def, no_capture_symbols); + fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures); } } fn fix_values_captured_in_closure_pattern( pattern: &mut crate::pattern::Pattern, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { use crate::pattern::Pattern::*; @@ -808,24 +849,35 @@ fn fix_values_captured_in_closure_pattern( .. } => { for (_, loc_arg) in loc_args.iter_mut() { - fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } UnwrappedOpaque { argument, .. } => { let (_, loc_arg) = &mut **argument; - fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } RecordDestructure { destructs, .. 
} => { for loc_destruct in destructs.iter_mut() { use crate::pattern::DestructType::*; match &mut loc_destruct.value.typ { Required => {} - Optional(_, loc_expr) => { - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols) - } + Optional(_, loc_expr) => fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ), Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern( &mut loc_pattern.value, no_capture_symbols, + closure_captures, ), } } @@ -848,19 +900,28 @@ fn fix_values_captured_in_closure_pattern( fn fix_values_captured_in_closure_expr( expr: &mut crate::expr::Expr, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { use crate::expr::Expr::*; match expr { LetNonRec(def, loc_expr) => { // LetNonRec(Box, Box>, Variable, Aliases), - fix_values_captured_in_closure_def(def, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } LetRec(defs, loc_expr, _) => { // LetRec(Vec, Box>, Variable, Aliases), - fix_values_captured_in_closure_defs(defs, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_defs(defs, no_capture_symbols, closure_captures); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } Expect { @@ -868,8 +929,16 @@ fn fix_values_captured_in_closure_expr( loc_continuation, lookups_in_cond: _, } => { - fix_values_captured_in_closure_expr(&mut loc_condition.value, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_continuation.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_condition.value, + no_capture_symbols, + closure_captures, + ); + 
fix_values_captured_in_closure_expr( + &mut loc_continuation.value, + no_capture_symbols, + closure_captures, + ); } Closure(ClosureData { @@ -882,16 +951,55 @@ fn fix_values_captured_in_closure_expr( captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s)); captured_symbols.retain(|(s, _)| s != name); + let original_captures_len = captured_symbols.len(); + let mut num_visited = 0; + let mut i = 0; + while num_visited < original_captures_len { + // If we've captured a capturing closure, replace the captured closure symbol with + // the symbols of its captures. That way, we can construct the closure with the + // captures it needs inside our body. + // + // E.g. + // x = "" + // inner = \{} -> x + // outer = \{} -> inner {} + // + // initially `outer` captures [inner], but this is then replaced with just [x]. + let (captured_symbol, _) = captured_symbols[i]; + if let Some(captures) = closure_captures.get(&captured_symbol) { + captured_symbols.remove(i); + captured_symbols.extend(captures); + // Don't advance because the next capture was shifted down. + } else { + i += 1; + } + num_visited += 1; + } + if captured_symbols.len() > original_captures_len { + // Re-sort, since we've added new captures. + captured_symbols.sort_by_key(|(sym, _)| *sym); + } + if captured_symbols.is_empty() { no_capture_symbols.insert(*name); + } else { + closure_captures.insert(*name, captured_symbols.to_vec()); } // patterns can contain default expressions, so much go over them too! for (_, _, loc_pat) in arguments.iter_mut() { - fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_pat.value, + no_capture_symbols, + closure_captures, + ); } - fix_values_captured_in_closure_expr(&mut loc_body.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_body.value, + no_capture_symbols, + closure_captures, + ); } Num(..) 
@@ -909,28 +1017,45 @@ fn fix_values_captured_in_closure_expr( List { loc_elems, .. } => { for elem in loc_elems.iter_mut() { - fix_values_captured_in_closure_expr(&mut elem.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut elem.value, + no_capture_symbols, + closure_captures, + ); } } When { loc_cond, branches, .. } => { - fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_cond.value, + no_capture_symbols, + closure_captures, + ); for branch in branches.iter_mut() { - fix_values_captured_in_closure_expr(&mut branch.value.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut branch.value.value, + no_capture_symbols, + closure_captures, + ); // patterns can contain default expressions, so much go over them too! for loc_pat in branch.patterns.iter_mut() { fix_values_captured_in_closure_pattern( &mut loc_pat.pattern.value, no_capture_symbols, + closure_captures, ); } if let Some(guard) = &mut branch.guard { - fix_values_captured_in_closure_expr(&mut guard.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut guard.value, + no_capture_symbols, + closure_captures, + ); } } } @@ -941,23 +1066,43 @@ fn fix_values_captured_in_closure_expr( .. 
} => { for (loc_cond, loc_then) in branches.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_then.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_cond.value, + no_capture_symbols, + closure_captures, + ); + fix_values_captured_in_closure_expr( + &mut loc_then.value, + no_capture_symbols, + closure_captures, + ); } - fix_values_captured_in_closure_expr(&mut final_else.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut final_else.value, + no_capture_symbols, + closure_captures, + ); } Call(function, arguments, _) => { - fix_values_captured_in_closure_expr(&mut function.1.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut function.1.value, + no_capture_symbols, + closure_captures, + ); for (_, loc_arg) in arguments.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } RunLowLevel { args, .. } | ForeignCall { args, .. } => { for (_, arg) in args.iter_mut() { - fix_values_captured_in_closure_expr(arg, no_capture_symbols); + fix_values_captured_in_closure_expr(arg, no_capture_symbols, closure_captures); } } @@ -966,22 +1111,38 @@ fn fix_values_captured_in_closure_expr( updates: fields, .. } => { for (_, field) in fields.iter_mut() { - fix_values_captured_in_closure_expr(&mut field.loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut field.loc_expr.value, + no_capture_symbols, + closure_captures, + ); } } Access { loc_expr, .. } => { - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } Tag { arguments, .. 
} => { for (_, loc_arg) in arguments.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } OpaqueRef { argument, .. } => { let (_, loc_arg) = &mut **argument; - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } OpaqueWrapFunction(_) => {} } diff --git a/crates/compiler/solve/tests/solve_expr.rs b/crates/compiler/solve/tests/solve_expr.rs index 1502110e06..53562694e1 100644 --- a/crates/compiler/solve/tests/solve_expr.rs +++ b/crates/compiler/solve/tests/solve_expr.rs @@ -7691,4 +7691,42 @@ mod solve_expr { "### ); } + + #[test] + fn transient_captures() { + infer_queries!( + indoc!( + r#" + x = "abc" + + getX = \{} -> x + + h = \{} -> (getX {}) + #^{-1} + + h {} + "# + ), + @"h : {}* -[[h(3) Str]]-> Str" + ); + } + + #[test] + fn transient_captures_after_def_ordering() { + infer_queries!( + indoc!( + r#" + h = \{} -> (getX {}) + #^{-1} + + getX = \{} -> x + + x = "abc" + + h {} + "# + ), + @"h : {}* -[[h(1) Str]]-> Str" + ); + } }