mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 23:04:49 +00:00
cross module specialization WIP
This commit is contained in:
parent
5c558a9a87
commit
79d3b0ac01
6 changed files with 335 additions and 43 deletions
|
@ -9,6 +9,7 @@ use roc_module::low_level::LowLevel;
|
|||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||
use roc_problem::can::RuntimeError;
|
||||
use roc_region::all::{Located, Region};
|
||||
use roc_types::solved_types::SolvedType;
|
||||
use roc_types::subs::{Content, FlatType, Subs, Variable};
|
||||
use std::collections::HashMap;
|
||||
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
|
||||
|
@ -89,6 +90,8 @@ pub struct Procs<'a> {
|
|||
Option<MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization<'a>>>>,
|
||||
pub specialized: MutMap<(Symbol, Layout<'a>), InProgressProc<'a>>,
|
||||
pub runtime_errors: MutMap<Symbol, &'a str>,
|
||||
pub externals_others_need: MutMap<Symbol, SolvedType>,
|
||||
pub externals_we_need: MutMap<ModuleId, MutMap<Symbol, SolvedType>>,
|
||||
}
|
||||
|
||||
impl<'a> Default for Procs<'a> {
|
||||
|
@ -99,6 +102,8 @@ impl<'a> Default for Procs<'a> {
|
|||
pending_specializations: Some(MutMap::default()),
|
||||
specialized: MutMap::default(),
|
||||
runtime_errors: MutMap::default(),
|
||||
externals_we_need: MutMap::default(),
|
||||
externals_others_need: MutMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1205,6 +1210,41 @@ pub fn specialize_all<'a>(
|
|||
) -> Procs<'a> {
|
||||
let mut pending_specializations = procs.pending_specializations.unwrap_or_default();
|
||||
|
||||
// add the specializations that other modules require of us
|
||||
use roc_constrain::module::{to_type, FreeVars};
|
||||
use roc_solve::solve::insert_type_into_subs;
|
||||
for (name, solved_type) in procs.externals_others_need.drain() {
|
||||
let mut free_vars = FreeVars::default();
|
||||
let mut var_store = ();
|
||||
let normal_type = to_type(solved_type, &mut free_vars, &mut var_store);
|
||||
let fn_var = insert_type_into_subs(env.subs, &normal_type);
|
||||
|
||||
let layout = layout_cache
|
||||
.from_var(&env.arena, fn_var, env.subs)
|
||||
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
|
||||
|
||||
let partial_proc = match procs.partial_procs.get(&name) {
|
||||
Some(v) => v.clone(),
|
||||
None => {
|
||||
unreachable!("now this is an error");
|
||||
}
|
||||
};
|
||||
|
||||
match specialize_external(env, &mut procs, name, layout_cache, fn_var, partial_proc) {
|
||||
Ok(proc) => {
|
||||
procs.specialized.insert((name, layout), Done(proc));
|
||||
}
|
||||
Err(error) => {
|
||||
let error_msg = env.arena.alloc(format!(
|
||||
"TODO generate a RuntimeError message for {:?}",
|
||||
error
|
||||
));
|
||||
|
||||
procs.runtime_errors.insert(name, error_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When calling from_can, pending_specializations should be unavailable.
|
||||
// This must be a single pass, and we must not add any more entries to it!
|
||||
procs.pending_specializations = None;
|
||||
|
@ -1219,11 +1259,14 @@ pub fn specialize_all<'a>(
|
|||
#[allow(clippy::map_entry)]
|
||||
if !procs.specialized.contains_key(&(name, layout.clone())) {
|
||||
// TODO should pending_procs hold a Rc<Proc>?
|
||||
let partial_proc = procs
|
||||
.partial_procs
|
||||
.get(&name)
|
||||
.unwrap_or_else(|| panic!("Could not find partial_proc for {:?}", name))
|
||||
.clone();
|
||||
let partial_proc = match procs.partial_procs.get(&name) {
|
||||
Some(v) => v.clone(),
|
||||
None => {
|
||||
// TODO this assumes the specialization is done by another module
|
||||
// make sure this does not become a problem down the road!
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Mark this proc as in-progress, so if we're dealing with
|
||||
// mutually recursive functions, we don't loop forever.
|
||||
|
@ -1250,6 +1293,110 @@ pub fn specialize_all<'a>(
|
|||
procs
|
||||
}
|
||||
|
||||
fn specialize_external<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
proc_name: Symbol,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
fn_var: Variable,
|
||||
partial_proc: PartialProc<'a>,
|
||||
) -> Result<Proc<'a>, LayoutProblem> {
|
||||
let PartialProc {
|
||||
annotation,
|
||||
pattern_symbols,
|
||||
body,
|
||||
is_self_recursive,
|
||||
} = partial_proc;
|
||||
|
||||
// unify the called function with the specialized signature, then specialize the function body
|
||||
let snapshot = env.subs.snapshot();
|
||||
let unified = roc_unify::unify::unify(env.subs, annotation, fn_var);
|
||||
|
||||
debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_)));
|
||||
|
||||
let specialized_body = from_can(env, body, procs, layout_cache);
|
||||
|
||||
let (proc_args, ret_layout) =
|
||||
build_specialized_proc_from_var(env, layout_cache, pattern_symbols, fn_var)?;
|
||||
|
||||
// reset subs, so we don't get type errors when specializing for a different signature
|
||||
env.subs.rollback_to(snapshot);
|
||||
|
||||
// TODO WRONG
|
||||
let closes_over_layout = Layout::Struct(&[]);
|
||||
|
||||
let recursivity = if is_self_recursive {
|
||||
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
|
||||
} else {
|
||||
SelfRecursive::NotSelfRecursive
|
||||
};
|
||||
|
||||
let proc = Proc {
|
||||
name: proc_name,
|
||||
args: proc_args,
|
||||
body: specialized_body,
|
||||
closes_over: closes_over_layout,
|
||||
ret_layout,
|
||||
is_self_recursive: recursivity,
|
||||
};
|
||||
|
||||
Ok(proc)
|
||||
}
|
||||
|
||||
fn build_specialized_proc_from_var<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
pattern_symbols: &[Symbol],
|
||||
fn_var: Variable,
|
||||
) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>), LayoutProblem> {
|
||||
match env.subs.get_without_compacting(fn_var).content {
|
||||
Content::Structure(FlatType::Func(pattern_vars, _closure_var, ret_var)) => {
|
||||
build_specialized_proc(env, layout_cache, pattern_symbols, &pattern_vars, ret_var)
|
||||
}
|
||||
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => {
|
||||
build_specialized_proc_from_var(env, layout_cache, pattern_symbols, args[1])
|
||||
}
|
||||
Content::Alias(_, _, actual) => {
|
||||
build_specialized_proc_from_var(env, layout_cache, pattern_symbols, actual)
|
||||
}
|
||||
_ => {
|
||||
// a top-level constant 0-argument thunk
|
||||
|
||||
build_specialized_proc(env, layout_cache, pattern_symbols, &[], fn_var)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_specialized_proc<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
pattern_symbols: &[Symbol],
|
||||
pattern_vars: &[Variable],
|
||||
ret_var: Variable,
|
||||
) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>), LayoutProblem> {
|
||||
let mut proc_args = Vec::with_capacity_in(pattern_vars.len(), &env.arena);
|
||||
|
||||
debug_assert_eq!(
|
||||
&pattern_vars.len(),
|
||||
&pattern_symbols.len(),
|
||||
"Tried to zip two vecs with different lengths!"
|
||||
);
|
||||
|
||||
for (arg_var, arg_name) in pattern_vars.iter().zip(pattern_symbols.iter()) {
|
||||
let layout = layout_cache.from_var(&env.arena, *arg_var, env.subs)?;
|
||||
|
||||
proc_args.push((layout, *arg_name));
|
||||
}
|
||||
|
||||
let proc_args = proc_args.into_bump_slice();
|
||||
|
||||
let ret_layout = layout_cache
|
||||
.from_var(&env.arena, ret_var, env.subs)
|
||||
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
|
||||
|
||||
Ok((proc_args, ret_layout))
|
||||
}
|
||||
|
||||
fn specialize<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
|
@ -2202,21 +2349,23 @@ pub fn with_hole<'a>(
|
|||
// So we check the function name against the list of partial procedures,
|
||||
// the procedures that we have lifted to the top-level and can call by name
|
||||
// if it's in there, it's a call by name, otherwise it's a call by pointer
|
||||
let known_functions = &procs.partial_procs;
|
||||
let is_known = |key| {
|
||||
// a proc in this module, or an imported symbol
|
||||
procs.partial_procs.contains_key(key) || key.module_id() != assigned.module_id()
|
||||
};
|
||||
|
||||
match loc_expr.value {
|
||||
roc_can::expr::Expr::Var(proc_name) if known_functions.contains_key(&proc_name) => {
|
||||
call_by_name(
|
||||
env,
|
||||
procs,
|
||||
fn_var,
|
||||
ret_var,
|
||||
proc_name,
|
||||
loc_args,
|
||||
layout_cache,
|
||||
assigned,
|
||||
hole,
|
||||
)
|
||||
}
|
||||
roc_can::expr::Expr::Var(proc_name) if is_known(&proc_name) => call_by_name(
|
||||
env,
|
||||
procs,
|
||||
fn_var,
|
||||
ret_var,
|
||||
proc_name,
|
||||
loc_args,
|
||||
layout_cache,
|
||||
assigned,
|
||||
hole,
|
||||
),
|
||||
_ => {
|
||||
// Call by pointer - the closure was anonymous, e.g.
|
||||
//
|
||||
|
@ -3546,6 +3695,36 @@ fn call_by_name<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
None if assigned.module_id() != proc_name.module_id() => {
|
||||
// call of a function that is not not in this module
|
||||
use std::collections::hash_map::Entry::{Occupied, Vacant};
|
||||
|
||||
let existing =
|
||||
match procs.externals_we_need.entry(proc_name.module_id()) {
|
||||
Vacant(entry) => entry.insert(MutMap::default()),
|
||||
Occupied(entry) => entry.into_mut(),
|
||||
};
|
||||
|
||||
existing.insert(
|
||||
proc_name,
|
||||
SolvedType::from_var(env.subs, pending.fn_var),
|
||||
);
|
||||
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: ret_layout.clone(),
|
||||
full_layout: full_layout.clone(),
|
||||
arg_layouts,
|
||||
args: field_symbols,
|
||||
};
|
||||
|
||||
let iter =
|
||||
loc_args.into_iter().rev().zip(field_symbols.iter().rev());
|
||||
|
||||
let result = Stmt::Let(assigned, call, ret_layout.clone(), hole);
|
||||
assign_to_symbols(env, procs, layout_cache, iter, result)
|
||||
}
|
||||
|
||||
None => {
|
||||
// This must have been a runtime error.
|
||||
match procs.runtime_errors.get(&proc_name) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue