Capture whole params record when needed

This commit is contained in:
Agus Zubiaga 2024-08-27 11:20:59 -03:00
parent 80770fae11
commit 49a6b1bfba
No known key found for this signature in database
8 changed files with 62 additions and 30 deletions

View file

@ -8,6 +8,7 @@ use roc_module::ident::{Ident, ModuleName};
use roc_module::symbol::{IdentIdsByModule, ModuleId, PQModuleName, PackageModuleIds, Symbol};
use roc_problem::can::{Problem, RuntimeError};
use roc_region::all::{Loc, Region};
use roc_types::subs::Variable;
/// The canonicalization environment for a particular module.
pub struct Env<'a> {
@ -38,7 +39,7 @@ pub struct Env<'a> {
pub top_level_symbols: VecSet<Symbol>,
pub home_param_symbols: VecSet<Symbol>,
pub home_params_record: Option<(Symbol, Variable)>,
pub arena: &'a Bump,
@ -66,7 +67,7 @@ impl<'a> Env<'a> {
qualified_type_lookups: VecSet::default(),
tailcallable_symbol: None,
top_level_symbols: VecSet::default(),
home_param_symbols: VecSet::default(),
home_params_record: None,
opt_shorthand,
}
}

View file

@ -1553,6 +1553,8 @@ fn canonicalize_closure_body<'a>(
&loc_body_expr.value,
);
let mut references_top_level = false;
let mut captured_symbols: Vec<_> = new_output
.references
.value_lookups()
@ -1563,18 +1565,28 @@ fn canonicalize_closure_body<'a>(
.filter(|s| !new_output.references.bound_symbols().any(|x| x == s))
.filter(|s| bound_by_argument_patterns.iter().all(|(k, _)| s != k))
// filter out top-level symbols those will be globally available, and don't need to be captured
.filter(|s| !env.top_level_symbols.contains(s))
.filter(|s| {
let is_top_level = env.top_level_symbols.contains(s);
references_top_level = references_top_level || is_top_level;
!is_top_level
})
// filter out imported symbols those will be globally available, and don't need to be captured
.filter(|s| s.module_id() == env.home)
// filter out functions that don't close over anything
.filter(|s| !new_output.non_closures.contains(s))
.filter(|s| !output.non_closures.contains(s))
// module params are not captured by top-level defs, because they are passed in as arguments
// nested defs, however, do capture them
.filter(|s| scope.depth > 1 || !env.home_param_symbols.contains(s))
.map(|s| (s, var_store.fresh()))
.collect();
if references_top_level {
if let Some(params_record) = env.home_params_record {
// If this module has params and the closure references top-level symbols,
// we need to capture the whole record so we can pass it.
// The lower_params pass will take care of removing the captures for top-level fns.
captured_symbols.push(params_record);
}
}
output.union(new_output);
// Now that we've collected all the references, check to see if any of the args we defined

View file

@ -433,18 +433,16 @@ pub fn canonicalize_module_defs<'a>(
PermitShadows(false),
);
env.home_param_symbols.reserve(destructs.len());
for destruct in destructs.iter() {
env.home_param_symbols.insert(destruct.value.symbol);
}
let whole_symbol = scope.gen_unique_symbol();
env.top_level_symbols.insert(whole_symbol);
let whole_var = var_store.fresh();
env.home_params_record = Some((whole_symbol, whole_var));
ModuleParams {
region: pattern.region,
whole_var: var_store.fresh(),
whole_var,
whole_symbol,
record_var: var_store.fresh(),
record_ext_var: var_store.fresh(),

View file

@ -48,9 +48,6 @@ pub struct Scope {
/// Ignored variables (variables that start with an underscore).
/// We won't intern them because they're only used during canonicalization for error reporting.
ignored_locals: VecMap<String, Region>,
/// How many nested scopes deep we are.
pub depth: usize,
}
impl Scope {
@ -76,7 +73,6 @@ impl Scope {
modules: ScopeModules::new(home, module_name),
imported_symbols: default_imports,
ignored_locals: VecMap::default(),
depth: 0,
}
}
@ -450,11 +446,9 @@ impl Scope {
let locals_snapshot = self.locals.in_scope.len();
let imported_symbols_snapshot = self.imported_symbols.len();
let imported_modules_snapshot = self.modules.len();
self.depth += 1;
let result = f(self);
self.depth -= 1;
self.aliases.truncate(aliases_count);
self.ignored_locals.truncate(ignored_locals_count);
self.imported_symbols.truncate(imported_symbols_snapshot);

View file

@ -15,6 +15,7 @@ use roc_types::{
subs::{VarStore, Variable},
types::Type,
};
use std::iter::once;
struct LowerParams<'a> {
home_id: ModuleId,
@ -53,6 +54,16 @@ pub fn lower(
impl<'a> LowerParams<'a> {
fn lower_decls(&mut self, decls: &mut Declarations) {
let home_param_symbols = match self.home_params {
Some(params) => params
.destructs
.iter()
.map(|destruct| destruct.value.symbol)
.chain(once(params.whole_symbol))
.collect::<Vec<Symbol>>(),
None => vec![],
};
for index in 0..decls.len() {
let tag = decls.declarations[index];
@ -72,10 +83,13 @@ impl<'a> LowerParams<'a> {
// This module has params, and this is a top-level function,
// so we need to extend its definition to take them.
decls.function_bodies[fn_def_index.index()]
.value
.arguments
.push((var, mark, pattern));
let function_body = &mut decls.function_bodies[fn_def_index.index()].value;
function_body.arguments.push((var, mark, pattern));
// Remove home params from the captured symbols, only nested lambdas need them.
function_body
.captured_symbols
.retain(|(sym, _)| !home_param_symbols.contains(sym));
if let Some(ann) = &mut decls.annotations[index] {
if let Type::Function(args, _, _) = &mut ann.signature {
@ -409,13 +423,12 @@ impl<'a> LowerParams<'a> {
fn params_extended_home_symbol(&self, symbol: &Symbol) -> Option<(&ModuleParams, usize)> {
if symbol.module_id() == self.home_id {
match self.home_params {
Some(params) => match params.arity_by_name.get(&symbol.ident_id()) {
Some(arity) => Some((params, *arity)),
None => None,
},
None => None,
}
self.home_params.as_ref().and_then(|params| {
params
.arity_by_name
.get(&symbol.ident_id())
.map(|arity| (params, *arity))
})
} else {
None
}