From c65f90b8c5a6bfabce5cae8849278a439ecefe9b Mon Sep 17 00:00:00 2001
From: Folkert
Date: Wed, 27 Apr 2022 16:22:00 +0200
Subject: [PATCH] refactor closure canonicalization

---
Reviewer note (not part of the commit; text between `---` and the diff is ignored by `git am`):
the body of the `ast::Expr::Closure` arm of `canonicalize_expr` moves into a new public function,
`canonicalize_closure`, which returns `(ClosureData, Output)`. The arm now calls it and wraps the
returned data in `Closure(..)`. Apart from re-indentation and building `ClosureData` in a `let`
before returning it, the moved statements are the same as the removed ones.

 compiler/can/src/expr.rs | 218 +++++++++++++++++++++------------------
 1 file changed, 115 insertions(+), 103 deletions(-)

diff --git a/compiler/can/src/expr.rs b/compiler/can/src/expr.rs
index 836d35467f..5a0c404ed0 100644
--- a/compiler/can/src/expr.rs
+++ b/compiler/can/src/expr.rs
@@ -668,110 +668,10 @@ pub fn canonicalize_expr<'a>(
             unreachable!("Backpassing should have been desugared by now")
         }
         ast::Expr::Closure(loc_arg_patterns, loc_body_expr) => {
-            // The globally unique symbol that will refer to this closure once it gets converted
-            // into a top-level procedure for code gen.
-            //
-            // In the Foo module, this will look something like Foo.$1 or Foo.$2.
-            let symbol = env
-                .closure_name_symbol
-                .unwrap_or_else(|| env.gen_unique_symbol());
-            env.closure_name_symbol = None;
+            let (closure_data, output) =
+                canonicalize_closure(env, var_store, scope, loc_arg_patterns, loc_body_expr);
 
-            // The body expression gets a new scope for canonicalization.
-            // Shadow `scope` to make sure we don't accidentally use the original one for the
-            // rest of this block, but keep the original around for later diffing.
-            let original_scope = scope;
-            let mut scope = original_scope.clone();
-
-            let mut can_args = Vec::with_capacity(loc_arg_patterns.len());
-            let mut output = Output::default();
-
-            for loc_pattern in loc_arg_patterns.iter() {
-                let can_argument_pattern = canonicalize_pattern(
-                    env,
-                    var_store,
-                    &mut scope,
-                    &mut output,
-                    FunctionArg,
-                    &loc_pattern.value,
-                    loc_pattern.region,
-                );
-
-                can_args.push((var_store.fresh(), can_argument_pattern));
-            }
-
-            let bound_by_argument_patterns = bindings_from_patterns(can_args.iter().map(|x| &x.1));
-
-            let (loc_body_expr, new_output) = canonicalize_expr(
-                env,
-                var_store,
-                &mut scope,
-                loc_body_expr.region,
-                &loc_body_expr.value,
-            );
-
-            let mut captured_symbols: Vec<_> = new_output
-                .references
-                .value_lookups()
-                .copied()
-                // filter out the closure's name itself
-                .filter(|s| *s != symbol)
-                // symbols bound either in this pattern or deeper down are not captured!
-                .filter(|s| !new_output.references.bound_symbols().any(|x| x == s))
-                .filter(|s| bound_by_argument_patterns.iter().all(|(k, _)| s != k))
-                // filter out top-level symbols those will be globally available, and don't need to be captured
-                .filter(|s| !env.top_level_symbols.contains(s))
-                // filter out imported symbols those will be globally available, and don't need to be captured
-                .filter(|s| s.module_id() == env.home)
-                // filter out functions that don't close over anything
-                .filter(|s| !new_output.non_closures.contains(s))
-                .filter(|s| !output.non_closures.contains(s))
-                .map(|s| (s, var_store.fresh()))
-                .collect();
-
-            output.union(new_output);
-
-            // Now that we've collected all the references, check to see if any of the args we defined
-            // went unreferenced. If any did, report them as unused arguments.
-            for (sub_symbol, region) in bound_by_argument_patterns {
-                if !output.references.has_value_lookup(sub_symbol) {
-                    // The body never referenced this argument we declared. It's an unused argument!
-                    env.problem(Problem::UnusedArgument(symbol, sub_symbol, region));
-                } else {
-                    // We shouldn't ultimately count arguments as referenced locals. Otherwise,
-                    // we end up with weird conclusions like the expression (\x -> x + 1)
-                    // references the (nonexistent) local variable x!
-                    output.references.remove_value_lookup(&sub_symbol);
-                }
-            }
-
-            // store the references of this function in the Env. This information is used
-            // when we canonicalize a surrounding def (if it exists)
-            env.closures.insert(symbol, output.references.clone());
-
-            // sort symbols, so we know the order in which they're stored in the closure record
-            captured_symbols.sort();
-
-            // store that this function doesn't capture anything. It will be promoted to a
-            // top-level function, and does not need to be captured by other surrounding functions.
-            if captured_symbols.is_empty() {
-                output.non_closures.insert(symbol);
-            }
-
-            (
-                Closure(ClosureData {
-                    function_type: var_store.fresh(),
-                    closure_type: var_store.fresh(),
-                    closure_ext_var: var_store.fresh(),
-                    return_type: var_store.fresh(),
-                    name: symbol,
-                    captured_symbols,
-                    recursive: Recursive::NotRecursive,
-                    arguments: can_args,
-                    loc_body: Box::new(loc_body_expr),
-                }),
-                output,
-            )
+            (Closure(closure_data), output)
         }
         ast::Expr::When(loc_cond, branches) => {
             // Infer the condition expression's type.
@@ -943,6 +943,118 @@ pub fn canonicalize_expr<'a>(
     )
 }
 
+pub fn canonicalize_closure<'a>(
+    env: &mut Env<'a>,
+    var_store: &mut VarStore,
+    scope: &mut Scope,
+    loc_arg_patterns: &'a [Loc<ast::Pattern<'a>>],
+    loc_body_expr: &'a Loc<ast::Expr<'a>>,
+) -> (ClosureData, Output) {
+    // The globally unique symbol that will refer to this closure once it gets converted
+    // into a top-level procedure for code gen.
+    //
+    // In the Foo module, this will look something like Foo.$1 or Foo.$2.
+    let symbol = env
+        .closure_name_symbol
+        .unwrap_or_else(|| env.gen_unique_symbol());
+    env.closure_name_symbol = None;
+
+    // The body expression gets a new scope for canonicalization.
+    // Shadow `scope` to make sure we don't accidentally use the original one for the
+    // rest of this block, but keep the original around for later diffing.
+    let original_scope = scope;
+    let mut scope = original_scope.clone();
+
+    let mut can_args = Vec::with_capacity(loc_arg_patterns.len());
+    let mut output = Output::default();
+
+    for loc_pattern in loc_arg_patterns.iter() {
+        let can_argument_pattern = canonicalize_pattern(
+            env,
+            var_store,
+            &mut scope,
+            &mut output,
+            FunctionArg,
+            &loc_pattern.value,
+            loc_pattern.region,
+        );
+
+        can_args.push((var_store.fresh(), can_argument_pattern));
+    }
+
+    let bound_by_argument_patterns = bindings_from_patterns(can_args.iter().map(|x| &x.1));
+
+    let (loc_body_expr, new_output) = canonicalize_expr(
+        env,
+        var_store,
+        &mut scope,
+        loc_body_expr.region,
+        &loc_body_expr.value,
+    );
+
+    let mut captured_symbols: Vec<_> = new_output
+        .references
+        .value_lookups()
+        .copied()
+        // filter out the closure's name itself
+        .filter(|s| *s != symbol)
+        // symbols bound either in this pattern or deeper down are not captured!
+        .filter(|s| !new_output.references.bound_symbols().any(|x| x == s))
+        .filter(|s| bound_by_argument_patterns.iter().all(|(k, _)| s != k))
+        // filter out top-level symbols those will be globally available, and don't need to be captured
+        .filter(|s| !env.top_level_symbols.contains(s))
+        // filter out imported symbols those will be globally available, and don't need to be captured
+        .filter(|s| s.module_id() == env.home)
+        // filter out functions that don't close over anything
+        .filter(|s| !new_output.non_closures.contains(s))
+        .filter(|s| !output.non_closures.contains(s))
+        .map(|s| (s, var_store.fresh()))
+        .collect();
+
+    output.union(new_output);
+
+    // Now that we've collected all the references, check to see if any of the args we defined
+    // went unreferenced. If any did, report them as unused arguments.
+    for (sub_symbol, region) in bound_by_argument_patterns {
+        if !output.references.has_value_lookup(sub_symbol) {
+            // The body never referenced this argument we declared. It's an unused argument!
+            env.problem(Problem::UnusedArgument(symbol, sub_symbol, region));
+        } else {
+            // We shouldn't ultimately count arguments as referenced locals. Otherwise,
+            // we end up with weird conclusions like the expression (\x -> x + 1)
+            // references the (nonexistent) local variable x!
+            output.references.remove_value_lookup(&sub_symbol);
+        }
+    }
+
+    // store the references of this function in the Env. This information is used
+    // when we canonicalize a surrounding def (if it exists)
+    env.closures.insert(symbol, output.references.clone());
+
+    // sort symbols, so we know the order in which they're stored in the closure record
+    captured_symbols.sort();
+
+    // store that this function doesn't capture anything. It will be promoted to a
+    // top-level function, and does not need to be captured by other surrounding functions.
+    if captured_symbols.is_empty() {
+        output.non_closures.insert(symbol);
+    }
+
+    let closure_data = ClosureData {
+        function_type: var_store.fresh(),
+        closure_type: var_store.fresh(),
+        closure_ext_var: var_store.fresh(),
+        return_type: var_store.fresh(),
+        name: symbol,
+        captured_symbols,
+        recursive: Recursive::NotRecursive,
+        arguments: can_args,
+        loc_body: Box::new(loc_body_expr),
+    };
+
+    (closure_data, output)
+}
+
 #[inline(always)]
 fn canonicalize_when_branch<'a>(
     env: &mut Env<'a>,