diff --git a/crates/compiler/can/src/module.rs b/crates/compiler/can/src/module.rs index 2b54df0395..ca2d2ed668 100644 --- a/crates/compiler/can/src/module.rs +++ b/crates/compiler/can/src/module.rs @@ -704,7 +704,11 @@ pub fn canonicalize_module_defs<'a>( // def pattern has no default expressions, so skip let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut VecMap::default(), + ); } Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => { let name = declarations.symbols[index].value; @@ -717,35 +721,61 @@ pub fn canonicalize_module_defs<'a>( if function_def.captured_symbols.is_empty() { no_capture_symbols.insert(name); } + let mut closure_captures = VecMap::default(); // patterns can contain default expressions, so must go over them too! for (_, _, loc_pat) in function_def.arguments.iter_mut() { fix_values_captured_in_closure_pattern( &mut loc_pat.value, &mut no_capture_symbols, + &mut closure_captures, ); } - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut no_capture_symbols, + &mut closure_captures, + ); } Destructure(d_index) => { let destruct_def = &mut declarations.destructs[d_index.index()]; let loc_pat = &mut destruct_def.loc_pattern; let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_pattern(&mut loc_pat.value, &mut VecSet::default()); - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = VecMap::default(); + + fix_values_captured_in_closure_pattern( + &mut loc_pat.value, + &mut VecSet::default(), + &mut closure_captures, + ); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } MutualRecursion { .. } => { // the declarations of this group will be treaded individually by later iterations } Expectation => { let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = Default::default(); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } ExpectationFx => { let loc_expr = &mut declarations.expressions[index]; - fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default()); + let mut closure_captures = Default::default(); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + &mut VecSet::default(), + &mut closure_captures, + ); } } } @@ -771,16 +801,26 @@ pub fn canonicalize_module_defs<'a>( fn fix_values_captured_in_closure_def( def: &mut crate::def::Def, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { // patterns can contain default expressions, so much go over them too! - fix_values_captured_in_closure_pattern(&mut def.loc_pattern.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut def.loc_pattern.value, + no_capture_symbols, + closure_captures, + ); - fix_values_captured_in_closure_expr(&mut def.loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut def.loc_expr.value, + no_capture_symbols, + closure_captures, + ); } fn fix_values_captured_in_closure_defs( defs: &mut [crate::def::Def], no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { // recursive defs cannot capture each other for def in defs.iter() { @@ -792,13 +832,14 @@ fn fix_values_captured_in_closure_defs( // TODO mutually recursive functions should both capture the union of both their capture sets for def in defs.iter_mut() { - fix_values_captured_in_closure_def(def, no_capture_symbols); + fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures); } } fn fix_values_captured_in_closure_pattern( pattern: &mut crate::pattern::Pattern, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { use crate::pattern::Pattern::*; @@ -808,24 +849,35 @@ fn fix_values_captured_in_closure_pattern( .. } => { for (_, loc_arg) in loc_args.iter_mut() { - fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } UnwrappedOpaque { argument, .. } => { let (_, loc_arg) = &mut **argument; - fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } RecordDestructure { destructs, .. } => { for loc_destruct in destructs.iter_mut() { use crate::pattern::DestructType::*; match &mut loc_destruct.value.typ { Required => {} - Optional(_, loc_expr) => { - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols) - } + Optional(_, loc_expr) => fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ), Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern( &mut loc_pattern.value, no_capture_symbols, + closure_captures, ), } } @@ -848,19 +900,28 @@ fn fix_values_captured_in_closure_pattern( fn fix_values_captured_in_closure_expr( expr: &mut crate::expr::Expr, no_capture_symbols: &mut VecSet, + closure_captures: &mut VecMap>, ) { use crate::expr::Expr::*; match expr { LetNonRec(def, loc_expr) => { // LetNonRec(Box, Box>, Variable, Aliases), - fix_values_captured_in_closure_def(def, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } LetRec(defs, loc_expr, _) => { // LetRec(Vec, Box>, Variable, Aliases), - fix_values_captured_in_closure_defs(defs, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_defs(defs, no_capture_symbols, closure_captures); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } Expect { @@ -868,8 +929,16 @@ fn fix_values_captured_in_closure_expr( loc_continuation, lookups_in_cond: _, } => { - fix_values_captured_in_closure_expr(&mut loc_condition.value, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_continuation.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_condition.value, + no_capture_symbols, + closure_captures, + ); + fix_values_captured_in_closure_expr( + &mut loc_continuation.value, + no_capture_symbols, + closure_captures, + ); } Closure(ClosureData { @@ -882,16 +951,55 @@ fn fix_values_captured_in_closure_expr( captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s)); captured_symbols.retain(|(s, _)| s != name); + let original_captures_len = captured_symbols.len(); + let mut num_visited = 0; + let mut i = 0; + while num_visited < original_captures_len { + // If we've captured a capturing closure, replace the captured closure symbol with + // the symbols of its captures. That way, we can construct the closure with the + // captures it needs inside our body. + // + // E.g. + // x = "" + // inner = \{} -> x + // outer = \{} -> inner {} + // + // initially `outer` captures [inner], but this is then replaced with just [x]. + let (captured_symbol, _) = captured_symbols[i]; + if let Some(captures) = closure_captures.get(&captured_symbol) { + captured_symbols.remove(i); + captured_symbols.extend(captures); + // Don't advance because the next capture was shifted down. + } else { + i += 1; + } + num_visited += 1; + } + if captured_symbols.len() > original_captures_len { + // Re-sort, since we've added new captures. + captured_symbols.sort_by_key(|(sym, _)| *sym); + } + if captured_symbols.is_empty() { no_capture_symbols.insert(*name); + } else { + closure_captures.insert(*name, captured_symbols.to_vec()); } // patterns can contain default expressions, so much go over them too! for (_, _, loc_pat) in arguments.iter_mut() { - fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols); + fix_values_captured_in_closure_pattern( + &mut loc_pat.value, + no_capture_symbols, + closure_captures, + ); } - fix_values_captured_in_closure_expr(&mut loc_body.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_body.value, + no_capture_symbols, + closure_captures, + ); } Num(..) @@ -909,28 +1017,45 @@ fn fix_values_captured_in_closure_expr( List { loc_elems, .. } => { for elem in loc_elems.iter_mut() { - fix_values_captured_in_closure_expr(&mut elem.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut elem.value, + no_capture_symbols, + closure_captures, + ); } } When { loc_cond, branches, .. } => { - fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_cond.value, + no_capture_symbols, + closure_captures, + ); for branch in branches.iter_mut() { - fix_values_captured_in_closure_expr(&mut branch.value.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut branch.value.value, + no_capture_symbols, + closure_captures, + ); // patterns can contain default expressions, so much go over them too! for loc_pat in branch.patterns.iter_mut() { fix_values_captured_in_closure_pattern( &mut loc_pat.pattern.value, no_capture_symbols, + closure_captures, ); } if let Some(guard) = &mut branch.guard { - fix_values_captured_in_closure_expr(&mut guard.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut guard.value, + no_capture_symbols, + closure_captures, + ); } } } @@ -941,23 +1066,43 @@ fn fix_values_captured_in_closure_expr( .. } => { for (loc_cond, loc_then) in branches.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols); - fix_values_captured_in_closure_expr(&mut loc_then.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_cond.value, + no_capture_symbols, + closure_captures, + ); + fix_values_captured_in_closure_expr( + &mut loc_then.value, + no_capture_symbols, + closure_captures, + ); } - fix_values_captured_in_closure_expr(&mut final_else.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut final_else.value, + no_capture_symbols, + closure_captures, + ); } Call(function, arguments, _) => { - fix_values_captured_in_closure_expr(&mut function.1.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut function.1.value, + no_capture_symbols, + closure_captures, + ); for (_, loc_arg) in arguments.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } RunLowLevel { args, .. } | ForeignCall { args, .. } => { for (_, arg) in args.iter_mut() { - fix_values_captured_in_closure_expr(arg, no_capture_symbols); + fix_values_captured_in_closure_expr(arg, no_capture_symbols, closure_captures); } } @@ -966,22 +1111,38 @@ fn fix_values_captured_in_closure_expr( updates: fields, .. } => { for (_, field) in fields.iter_mut() { - fix_values_captured_in_closure_expr(&mut field.loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut field.loc_expr.value, + no_capture_symbols, + closure_captures, + ); } } Access { loc_expr, .. } => { - fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_expr.value, + no_capture_symbols, + closure_captures, + ); } Tag { arguments, .. } => { for (_, loc_arg) in arguments.iter_mut() { - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } } OpaqueRef { argument, .. } => { let (_, loc_arg) = &mut **argument; - fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); + fix_values_captured_in_closure_expr( + &mut loc_arg.value, + no_capture_symbols, + closure_captures, + ); } OpaqueWrapFunction(_) => {} } diff --git a/crates/compiler/solve/tests/solve_expr.rs b/crates/compiler/solve/tests/solve_expr.rs index 1502110e06..53562694e1 100644 --- a/crates/compiler/solve/tests/solve_expr.rs +++ b/crates/compiler/solve/tests/solve_expr.rs @@ -7691,4 +7691,42 @@ mod solve_expr { "### ); } + + #[test] + fn transient_captures() { + infer_queries!( + indoc!( + r#" + x = "abc" + + getX = \{} -> x + + h = \{} -> (getX {}) + #^{-1} + + h {} + "# + ), + @"h : {}* -[[h(3) Str]]-> Str" + ); + } + + #[test] + fn transient_captures_after_def_ordering() { + infer_queries!( + indoc!( + r#" + h = \{} -> (getX {}) + #^{-1} + + getX = \{} -> x + + x = "abc" + + h {} + "# + ), + @"h : {}* -[[h(1) Str]]-> Str" + ); + } }