mirror of
https://github.com/roc-lang/roc.git
synced 2025-08-03 19:58:18 +00:00
Capture whole params record when needed
This commit is contained in:
parent
80770fae11
commit
49a6b1bfba
8 changed files with 62 additions and 30 deletions
|
@ -8,6 +8,7 @@ use roc_module::ident::{Ident, ModuleName};
|
|||
use roc_module::symbol::{IdentIdsByModule, ModuleId, PQModuleName, PackageModuleIds, Symbol};
|
||||
use roc_problem::can::{Problem, RuntimeError};
|
||||
use roc_region::all::{Loc, Region};
|
||||
use roc_types::subs::Variable;
|
||||
|
||||
/// The canonicalization environment for a particular module.
|
||||
pub struct Env<'a> {
|
||||
|
@ -38,7 +39,7 @@ pub struct Env<'a> {
|
|||
|
||||
pub top_level_symbols: VecSet<Symbol>,
|
||||
|
||||
pub home_param_symbols: VecSet<Symbol>,
|
||||
pub home_params_record: Option<(Symbol, Variable)>,
|
||||
|
||||
pub arena: &'a Bump,
|
||||
|
||||
|
@ -66,7 +67,7 @@ impl<'a> Env<'a> {
|
|||
qualified_type_lookups: VecSet::default(),
|
||||
tailcallable_symbol: None,
|
||||
top_level_symbols: VecSet::default(),
|
||||
home_param_symbols: VecSet::default(),
|
||||
home_params_record: None,
|
||||
opt_shorthand,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1553,6 +1553,8 @@ fn canonicalize_closure_body<'a>(
|
|||
&loc_body_expr.value,
|
||||
);
|
||||
|
||||
let mut references_top_level = false;
|
||||
|
||||
let mut captured_symbols: Vec<_> = new_output
|
||||
.references
|
||||
.value_lookups()
|
||||
|
@ -1563,18 +1565,28 @@ fn canonicalize_closure_body<'a>(
|
|||
.filter(|s| !new_output.references.bound_symbols().any(|x| x == s))
|
||||
.filter(|s| bound_by_argument_patterns.iter().all(|(k, _)| s != k))
|
||||
// filter out top-level symbols those will be globally available, and don't need to be captured
|
||||
.filter(|s| !env.top_level_symbols.contains(s))
|
||||
.filter(|s| {
|
||||
let is_top_level = env.top_level_symbols.contains(s);
|
||||
references_top_level = references_top_level || is_top_level;
|
||||
!is_top_level
|
||||
})
|
||||
// filter out imported symbols those will be globally available, and don't need to be captured
|
||||
.filter(|s| s.module_id() == env.home)
|
||||
// filter out functions that don't close over anything
|
||||
.filter(|s| !new_output.non_closures.contains(s))
|
||||
.filter(|s| !output.non_closures.contains(s))
|
||||
// module params are not captured by top-level defs, because they are passed in as arguments
|
||||
// nested defs, however, do capture them
|
||||
.filter(|s| scope.depth > 1 || !env.home_param_symbols.contains(s))
|
||||
.map(|s| (s, var_store.fresh()))
|
||||
.collect();
|
||||
|
||||
if references_top_level {
|
||||
if let Some(params_record) = env.home_params_record {
|
||||
// If this module has params and the closure references top-level symbols,
|
||||
// we need to capture the whole record so we can pass it.
|
||||
// The lower_params pass will take care of removing the captures for top-level fns.
|
||||
captured_symbols.push(params_record);
|
||||
}
|
||||
}
|
||||
|
||||
output.union(new_output);
|
||||
|
||||
// Now that we've collected all the references, check to see if any of the args we defined
|
||||
|
|
|
@ -433,18 +433,16 @@ pub fn canonicalize_module_defs<'a>(
|
|||
PermitShadows(false),
|
||||
);
|
||||
|
||||
env.home_param_symbols.reserve(destructs.len());
|
||||
|
||||
for destruct in destructs.iter() {
|
||||
env.home_param_symbols.insert(destruct.value.symbol);
|
||||
}
|
||||
|
||||
let whole_symbol = scope.gen_unique_symbol();
|
||||
env.top_level_symbols.insert(whole_symbol);
|
||||
|
||||
let whole_var = var_store.fresh();
|
||||
|
||||
env.home_params_record = Some((whole_symbol, whole_var));
|
||||
|
||||
ModuleParams {
|
||||
region: pattern.region,
|
||||
whole_var: var_store.fresh(),
|
||||
whole_var,
|
||||
whole_symbol,
|
||||
record_var: var_store.fresh(),
|
||||
record_ext_var: var_store.fresh(),
|
||||
|
|
|
@ -48,9 +48,6 @@ pub struct Scope {
|
|||
/// Ignored variables (variables that start with an underscore).
|
||||
/// We won't intern them because they're only used during canonicalization for error reporting.
|
||||
ignored_locals: VecMap<String, Region>,
|
||||
|
||||
/// How many nested scopes deep we are.
|
||||
pub depth: usize,
|
||||
}
|
||||
|
||||
impl Scope {
|
||||
|
@ -76,7 +73,6 @@ impl Scope {
|
|||
modules: ScopeModules::new(home, module_name),
|
||||
imported_symbols: default_imports,
|
||||
ignored_locals: VecMap::default(),
|
||||
depth: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -450,11 +446,9 @@ impl Scope {
|
|||
let locals_snapshot = self.locals.in_scope.len();
|
||||
let imported_symbols_snapshot = self.imported_symbols.len();
|
||||
let imported_modules_snapshot = self.modules.len();
|
||||
self.depth += 1;
|
||||
|
||||
let result = f(self);
|
||||
|
||||
self.depth -= 1;
|
||||
self.aliases.truncate(aliases_count);
|
||||
self.ignored_locals.truncate(ignored_locals_count);
|
||||
self.imported_symbols.truncate(imported_symbols_snapshot);
|
||||
|
|
|
@ -15,6 +15,7 @@ use roc_types::{
|
|||
subs::{VarStore, Variable},
|
||||
types::Type,
|
||||
};
|
||||
use std::iter::once;
|
||||
|
||||
struct LowerParams<'a> {
|
||||
home_id: ModuleId,
|
||||
|
@ -53,6 +54,16 @@ pub fn lower(
|
|||
|
||||
impl<'a> LowerParams<'a> {
|
||||
fn lower_decls(&mut self, decls: &mut Declarations) {
|
||||
let home_param_symbols = match self.home_params {
|
||||
Some(params) => params
|
||||
.destructs
|
||||
.iter()
|
||||
.map(|destruct| destruct.value.symbol)
|
||||
.chain(once(params.whole_symbol))
|
||||
.collect::<Vec<Symbol>>(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
for index in 0..decls.len() {
|
||||
let tag = decls.declarations[index];
|
||||
|
||||
|
@ -72,10 +83,13 @@ impl<'a> LowerParams<'a> {
|
|||
// This module has params, and this is a top-level function,
|
||||
// so we need to extend its definition to take them.
|
||||
|
||||
decls.function_bodies[fn_def_index.index()]
|
||||
.value
|
||||
.arguments
|
||||
.push((var, mark, pattern));
|
||||
let function_body = &mut decls.function_bodies[fn_def_index.index()].value;
|
||||
function_body.arguments.push((var, mark, pattern));
|
||||
|
||||
// Remove home params from the captured symbols, only nested lambdas need them.
|
||||
function_body
|
||||
.captured_symbols
|
||||
.retain(|(sym, _)| !home_param_symbols.contains(sym));
|
||||
|
||||
if let Some(ann) = &mut decls.annotations[index] {
|
||||
if let Type::Function(args, _, _) = &mut ann.signature {
|
||||
|
@ -409,13 +423,12 @@ impl<'a> LowerParams<'a> {
|
|||
|
||||
fn params_extended_home_symbol(&self, symbol: &Symbol) -> Option<(&ModuleParams, usize)> {
|
||||
if symbol.module_id() == self.home_id {
|
||||
match self.home_params {
|
||||
Some(params) => match params.arity_by_name.get(&symbol.ident_id()) {
|
||||
Some(arity) => Some((params, *arity)),
|
||||
None => None,
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
self.home_params.as_ref().and_then(|params| {
|
||||
params
|
||||
.arity_by_name
|
||||
.get(&symbol.ident_id())
|
||||
.map(|arity| (params, *arity))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue