Merge remote-tracking branch 'origin/main' into expect-fx-codegen

This commit is contained in:
Folkert 2022-08-23 16:28:21 +02:00
commit a22e04361c
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
222 changed files with 10039 additions and 1945 deletions

View file

@ -696,15 +696,26 @@ pub fn canonicalize_module_defs<'a>(
referenced_values.extend(env.qualified_value_lookups.iter().copied());
referenced_types.extend(env.qualified_type_lookups.iter().copied());
let mut fix_closures_no_capture_symbols = VecSet::default();
let mut fix_closures_closure_captures = VecMap::default();
for index in 0..declarations.len() {
use crate::expr::DeclarationTag::*;
// For each declaration, we need to fixup the closures inside its def.
// Reuse the fixup buffer allocations from the previous iteration.
fix_closures_no_capture_symbols.clear();
fix_closures_closure_captures.clear();
match declarations.declarations[index] {
Value => {
// def pattern has no default expressions, so skip
let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
Function(f_index) | Recursive(f_index) | TailRecursive(f_index) => {
let name = declarations.symbols[index].value;
@ -722,30 +733,51 @@ pub fn canonicalize_module_defs<'a>(
for (_, _, loc_pat) in function_def.arguments.iter_mut() {
fix_values_captured_in_closure_pattern(
&mut loc_pat.value,
&mut no_capture_symbols,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
Destructure(d_index) => {
let destruct_def = &mut declarations.destructs[d_index.index()];
let loc_pat = &mut destruct_def.loc_pattern;
let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_pattern(&mut loc_pat.value, &mut VecSet::default());
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
fix_values_captured_in_closure_pattern(
&mut loc_pat.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
MutualRecursion { .. } => {
// the declarations of this group will be treaded individually by later iterations
}
Expectation => {
let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
ExpectationFx => {
let loc_expr = &mut declarations.expressions[index];
fix_values_captured_in_closure_expr(&mut loc_expr.value, &mut VecSet::default());
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
&mut fix_closures_no_capture_symbols,
&mut fix_closures_closure_captures,
);
}
}
}
@ -771,16 +803,26 @@ pub fn canonicalize_module_defs<'a>(
fn fix_values_captured_in_closure_def(
def: &mut crate::def::Def,
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
// patterns can contain default expressions, so much go over them too!
fix_values_captured_in_closure_pattern(&mut def.loc_pattern.value, no_capture_symbols);
fix_values_captured_in_closure_pattern(
&mut def.loc_pattern.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(&mut def.loc_expr.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut def.loc_expr.value,
no_capture_symbols,
closure_captures,
);
}
fn fix_values_captured_in_closure_defs(
defs: &mut [crate::def::Def],
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
// recursive defs cannot capture each other
for def in defs.iter() {
@ -789,16 +831,38 @@ fn fix_values_captured_in_closure_defs(
);
}
// TODO mutually recursive functions should both capture the union of both their capture sets
for def in defs.iter_mut() {
fix_values_captured_in_closure_def(def, no_capture_symbols);
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
}
// Mutually recursive functions should both capture the union of all their capture sets
//
// Really unfortunate we make a lot of clones here, can this be done more efficiently?
let mut total_capture_set = Vec::default();
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
}) = &def.loc_expr.value
{
total_capture_set.extend(captured_symbols.iter().copied());
}
}
total_capture_set.sort_by_key(|(sym, _)| *sym);
total_capture_set.dedup_by_key(|(sym, _)| *sym);
for def in defs.iter_mut() {
if let Expr::Closure(ClosureData {
captured_symbols, ..
}) = &mut def.loc_expr.value
{
*captured_symbols = total_capture_set.clone();
}
}
}
fn fix_values_captured_in_closure_pattern(
pattern: &mut crate::pattern::Pattern,
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
use crate::pattern::Pattern::*;
@ -808,24 +872,35 @@ fn fix_values_captured_in_closure_pattern(
..
} => {
for (_, loc_arg) in loc_args.iter_mut() {
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols);
fix_values_captured_in_closure_pattern(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
}
}
UnwrappedOpaque { argument, .. } => {
let (_, loc_arg) = &mut **argument;
fix_values_captured_in_closure_pattern(&mut loc_arg.value, no_capture_symbols);
fix_values_captured_in_closure_pattern(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
}
RecordDestructure { destructs, .. } => {
for loc_destruct in destructs.iter_mut() {
use crate::pattern::DestructType::*;
match &mut loc_destruct.value.typ {
Required => {}
Optional(_, loc_expr) => {
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols)
}
Optional(_, loc_expr) => fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
),
Guard(_, loc_pattern) => fix_values_captured_in_closure_pattern(
&mut loc_pattern.value,
no_capture_symbols,
closure_captures,
),
}
}
@ -848,19 +923,28 @@ fn fix_values_captured_in_closure_pattern(
fn fix_values_captured_in_closure_expr(
expr: &mut crate::expr::Expr,
no_capture_symbols: &mut VecSet<Symbol>,
closure_captures: &mut VecMap<Symbol, Vec<(Symbol, Variable)>>,
) {
use crate::expr::Expr::*;
match expr {
LetNonRec(def, loc_expr) => {
// LetNonRec(Box<Def>, Box<Located<Expr>>, Variable, Aliases),
fix_values_captured_in_closure_def(def, no_capture_symbols);
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
fix_values_captured_in_closure_def(def, no_capture_symbols, closure_captures);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
}
LetRec(defs, loc_expr, _) => {
// LetRec(Vec<Def>, Box<Located<Expr>>, Variable, Aliases),
fix_values_captured_in_closure_defs(defs, no_capture_symbols);
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
fix_values_captured_in_closure_defs(defs, no_capture_symbols, closure_captures);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
}
Expect {
@ -868,8 +952,16 @@ fn fix_values_captured_in_closure_expr(
loc_continuation,
lookups_in_cond: _,
} => {
fix_values_captured_in_closure_expr(&mut loc_condition.value, no_capture_symbols);
fix_values_captured_in_closure_expr(&mut loc_continuation.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_condition.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_continuation.value,
no_capture_symbols,
closure_captures,
);
}
ExpectFx {
@ -891,16 +983,58 @@ fn fix_values_captured_in_closure_expr(
captured_symbols.retain(|(s, _)| !no_capture_symbols.contains(s));
captured_symbols.retain(|(s, _)| s != name);
let original_captures_len = captured_symbols.len();
let mut num_visited = 0;
let mut i = 0;
while num_visited < original_captures_len {
// If we've captured a capturing closure, replace the captured closure symbol with
// the symbols of its captures. That way, we can construct the closure with the
// captures it needs inside our body.
//
// E.g.
// x = ""
// inner = \{} -> x
// outer = \{} -> inner {}
//
// initially `outer` captures [inner], but this is then replaced with just [x].
let (captured_symbol, _) = captured_symbols[i];
if let Some(captures) = closure_captures.get(&captured_symbol) {
debug_assert!(!captures.is_empty());
captured_symbols.swap_remove(i);
captured_symbols.extend(captures);
// Jump two, because the next element is now one of the newly-added captures,
// which we don't need to check.
i += 2;
} else {
i += 1;
}
num_visited += 1;
}
if captured_symbols.len() > original_captures_len {
// Re-sort, since we've added new captures.
captured_symbols.sort_by_key(|(sym, _)| *sym);
}
if captured_symbols.is_empty() {
no_capture_symbols.insert(*name);
} else {
closure_captures.insert(*name, captured_symbols.to_vec());
}
// patterns can contain default expressions, so much go over them too!
for (_, _, loc_pat) in arguments.iter_mut() {
fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols);
fix_values_captured_in_closure_pattern(
&mut loc_pat.value,
no_capture_symbols,
closure_captures,
);
}
fix_values_captured_in_closure_expr(&mut loc_body.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_body.value,
no_capture_symbols,
closure_captures,
);
}
Num(..)
@ -918,28 +1052,45 @@ fn fix_values_captured_in_closure_expr(
List { loc_elems, .. } => {
for elem in loc_elems.iter_mut() {
fix_values_captured_in_closure_expr(&mut elem.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut elem.value,
no_capture_symbols,
closure_captures,
);
}
}
When {
loc_cond, branches, ..
} => {
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_cond.value,
no_capture_symbols,
closure_captures,
);
for branch in branches.iter_mut() {
fix_values_captured_in_closure_expr(&mut branch.value.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut branch.value.value,
no_capture_symbols,
closure_captures,
);
// patterns can contain default expressions, so much go over them too!
for loc_pat in branch.patterns.iter_mut() {
fix_values_captured_in_closure_pattern(
&mut loc_pat.pattern.value,
no_capture_symbols,
closure_captures,
);
}
if let Some(guard) = &mut branch.guard {
fix_values_captured_in_closure_expr(&mut guard.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut guard.value,
no_capture_symbols,
closure_captures,
);
}
}
}
@ -950,23 +1101,43 @@ fn fix_values_captured_in_closure_expr(
..
} => {
for (loc_cond, loc_then) in branches.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_cond.value, no_capture_symbols);
fix_values_captured_in_closure_expr(&mut loc_then.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_cond.value,
no_capture_symbols,
closure_captures,
);
fix_values_captured_in_closure_expr(
&mut loc_then.value,
no_capture_symbols,
closure_captures,
);
}
fix_values_captured_in_closure_expr(&mut final_else.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut final_else.value,
no_capture_symbols,
closure_captures,
);
}
Call(function, arguments, _) => {
fix_values_captured_in_closure_expr(&mut function.1.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut function.1.value,
no_capture_symbols,
closure_captures,
);
for (_, loc_arg) in arguments.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
}
}
RunLowLevel { args, .. } | ForeignCall { args, .. } => {
for (_, arg) in args.iter_mut() {
fix_values_captured_in_closure_expr(arg, no_capture_symbols);
fix_values_captured_in_closure_expr(arg, no_capture_symbols, closure_captures);
}
}
@ -975,22 +1146,38 @@ fn fix_values_captured_in_closure_expr(
updates: fields, ..
} => {
for (_, field) in fields.iter_mut() {
fix_values_captured_in_closure_expr(&mut field.loc_expr.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut field.loc_expr.value,
no_capture_symbols,
closure_captures,
);
}
}
Access { loc_expr, .. } => {
fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_expr.value,
no_capture_symbols,
closure_captures,
);
}
Tag { arguments, .. } => {
for (_, loc_arg) in arguments.iter_mut() {
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
}
}
OpaqueRef { argument, .. } => {
let (_, loc_arg) = &mut **argument;
fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols);
fix_values_captured_in_closure_expr(
&mut loc_arg.value,
no_capture_symbols,
closure_captures,
);
}
OpaqueWrapFunction(_) => {}
}