cross module specialization WIP

This commit is contained in:
Folkert 2020-10-11 01:23:52 +02:00
parent 5c558a9a87
commit 79d3b0ac01
6 changed files with 335 additions and 43 deletions

View file

@ -22,8 +22,10 @@ use roc_region::all::{Located, Region};
use roc_solve::module::SolvedModule;
use roc_solve::solve;
use roc_types::solved_types::Solved;
use roc_types::solved_types::SolvedType;
use roc_types::subs::{Subs, VarStore, Variable};
use roc_types::types::Alias;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io;
@ -182,6 +184,7 @@ struct ModuleCache<'a> {
constrained: MutMap<ModuleId, ConstrainedModule<'a>>,
typechecked: MutMap<ModuleId, TypeCheckedModule<'a>>,
found_specializations: MutMap<ModuleId, FoundSpecializationsModule<'a>>,
external_specializations_requested: MutMap<ModuleId, MutMap<Symbol, SolvedType>>,
}
fn start_phase<'a>(module_id: ModuleId, phase: Phase, state: &mut State<'a>) -> BuildTask<'a> {
@ -303,6 +306,7 @@ fn start_phase<'a>(module_id: ModuleId, phase: Phase, state: &mut State<'a>) ->
decls,
finished_info,
ident_ids,
pending_specializations: state.all_pending_specializations.clone(),
}
}
Phase::MakeSpecializations => {
@ -312,6 +316,12 @@ fn start_phase<'a>(module_id: ModuleId, phase: Phase, state: &mut State<'a>) ->
.remove(&module_id)
.unwrap();
let specializations_we_must_make = state
.module_cache
.external_specializations_requested
.remove(&module_id)
.unwrap_or(MutMap::default());
let FoundSpecializationsModule {
module_id,
ident_ids,
@ -327,6 +337,7 @@ fn start_phase<'a>(module_id: ModuleId, phase: Phase, state: &mut State<'a>) ->
subs,
procs,
layout_cache,
specializations_we_must_make,
finished_info,
}
}
@ -406,7 +417,7 @@ pub struct MonomorphizedModule<'a> {
pub type_problems: Vec<solve::TypeError>,
pub mono_problems: Vec<roc_mono::ir::MonoProblem>,
pub procedures: MutMap<(Symbol, Layout<'a>), Proc<'a>>,
pub exposed_vars_by_symbol: Vec<(Symbol, Variable)>,
pub exposed_to_host: MutSet<Symbol>,
pub src: Box<str>,
pub timings: MutMap<ModuleId, ModuleTiming>,
}
@ -453,6 +464,7 @@ enum Msg<'a> {
module_id: ModuleId,
ident_ids: IdentIds,
layout_cache: LayoutCache<'a>,
external_specializations_requested: MutMap<ModuleId, MutMap<Symbol, SolvedType>>,
procedures: MutMap<(Symbol, Layout<'a>), Proc<'a>>,
problems: Vec<roc_mono::ir::MonoProblem>,
subs: Subs,
@ -464,7 +476,7 @@ enum Msg<'a> {
FinishedAllSpecialization {
subs: Subs,
problems: Vec<MonoProblem>,
exposed_vars_by_symbol: Vec<(Symbol, Variable)>,
exposed_to_host: MutSet<Symbol>,
src: &'a str,
},
}
@ -491,6 +503,7 @@ struct State<'a> {
pub module_cache: ModuleCache<'a>,
pub dependencies: Dependencies,
pub procedures: MutMap<(Symbol, Layout<'a>), Proc<'a>>,
pub exposed_to_host: MutSet<Symbol>,
/// This is the "final" list of IdentIds, after canonicalization and constraint gen
/// have completed for a given module.
@ -518,7 +531,7 @@ struct State<'a> {
/// pending specializations in the same thread.
pub needs_specialization: MutSet<ModuleId>,
pub all_pending_specializations: MutMap<(Symbol, Layout<'a>), PendingSpecialization<'a>>,
pub all_pending_specializations: MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization<'a>>>,
pub specializations_in_flight: u32,
@ -632,6 +645,8 @@ enum BuildTask<'a> {
ident_ids: IdentIds,
decls: Vec<Declaration>,
finished_info: FinishedInfo<'a>,
// TODO remove?
pending_specializations: MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization<'a>>>,
},
MakeSpecializations {
module_id: ModuleId,
@ -640,6 +655,7 @@ enum BuildTask<'a> {
procs: Procs<'a>,
layout_cache: LayoutCache<'a>,
finished_info: FinishedInfo<'a>,
specializations_we_must_make: MutMap<Symbol, SolvedType>,
},
}
@ -932,6 +948,7 @@ where
module_cache: ModuleCache::default(),
dependencies: Dependencies::default(),
procedures: MutMap::default(),
exposed_to_host: MutSet::default(),
exposed_types,
headers_parsed,
loading_started,
@ -993,7 +1010,7 @@ where
Msg::FinishedAllSpecialization {
subs,
problems,
exposed_vars_by_symbol,
exposed_to_host,
src,
} => {
// We're done! There should be no more messages pending.
@ -1010,7 +1027,7 @@ where
state,
subs,
problems,
exposed_vars_by_symbol,
exposed_to_host,
src,
)));
}
@ -1140,6 +1157,12 @@ fn update<'a>(
let work = state.dependencies.notify(module_id, Phase::SolveTypes);
if module_id == state.root_id {
state
.exposed_to_host
.extend(solved_module.exposed_vars_by_symbol.iter().map(|x| x.0));
}
if module_id == state.root_id && state.goal_phase == Phase::SolveTypes {
debug_assert!(work.is_empty());
debug_assert!(state.dependencies.solved_all());
@ -1220,6 +1243,19 @@ fn update<'a>(
} => {
let subs = solved_subs.into_inner();
if let Some(pending) = &procs.pending_specializations {
for (symbol, specs) in pending {
let mut existing = match state.all_pending_specializations.entry(*symbol) {
Vacant(entry) => entry.insert(MutMap::default()),
Occupied(entry) => entry.into_mut(),
};
for (layout, pend) in specs {
existing.insert(layout.clone(), pend.clone());
}
}
}
let found_specializations_module = FoundSpecializationsModule {
layout_cache,
module_id,
@ -1251,17 +1287,32 @@ fn update<'a>(
subs,
finished_info,
procedures,
external_specializations_requested,
..
} => {
println!("done specializing {:?}", module_id);
for (module_id, requested) in external_specializations_requested {
let existing = match state
.module_cache
.external_specializations_requested
.entry(module_id)
{
Vacant(entry) => entry.insert(MutMap::default()),
Occupied(entry) => entry.into_mut(),
};
existing.extend(requested);
}
state.procedures.extend(procedures);
dbg!(&state.procedures);
let work = state
.dependencies
.notify(module_id, Phase::MakeSpecializations);
state.constrained_ident_ids.insert(module_id, ident_ids);
if work.is_empty()
&& state.dependencies.solved_all()
&& state.goal_phase == Phase::MakeSpecializations
@ -1273,14 +1324,11 @@ fn update<'a>(
subs,
// TODO thread through mono problems
problems: vec![],
exposed_vars_by_symbol: finished_info.exposed_vars_by_symbol,
exposed_to_host: state.exposed_to_host.clone(),
src: finished_info.src,
})
.map_err(|_| LoadingProblem::MsgChannelDied)?;
// bookkeeping
state.constrained_ident_ids.insert(module_id, ident_ids);
// As far as type-checking goes, once we've solved
// the originally requested module, we're all done!
return Ok(state);
@ -1307,7 +1355,7 @@ fn finish_specialization<'a>(
mut state: State<'a>,
subs: Subs,
problems: Vec<MonoProblem>,
exposed_vars_by_symbol: Vec<(Symbol, Variable)>,
exposed_to_host: MutSet<Symbol>,
src: &'a str,
) -> MonomorphizedModule<'a> {
state.mono_problems.extend(problems);
@ -1334,7 +1382,7 @@ fn finish_specialization<'a>(
can_problems,
mono_problems,
type_problems,
exposed_vars_by_symbol,
exposed_to_host,
module_id: state.root_id,
subs,
interns,
@ -1902,6 +1950,7 @@ fn make_specializations<'a>(
mut subs: Subs,
mut procs: Procs<'a>,
mut layout_cache: LayoutCache<'a>,
specializations_we_must_make: MutMap<Symbol, SolvedType>,
finished_info: FinishedInfo<'a>,
) -> Msg<'a> {
let mut mono_problems = Vec::new();
@ -1914,7 +1963,10 @@ fn make_specializations<'a>(
ident_ids: &mut ident_ids,
};
dbg!(&procs);
procs
.externals_others_need
.extend(specializations_we_must_make);
// TODO: for now this final specialization pass is sequential,
// with no parallelization at all. We should try to parallelize
// this, but doing so will require a redesign of Procs.
@ -1925,8 +1977,8 @@ fn make_specializations<'a>(
// &finished_info.vars_by_symbol,
);
let external_specializations_requested = procs.externals_we_need.clone();
let (procedures, _param_map) = procs.get_specialized_procs_help(mono_env.arena);
dbg!(&procedures);
Msg::MadeSpecializations {
module_id: home,
@ -1936,6 +1988,7 @@ fn make_specializations<'a>(
problems: mono_problems,
subs,
finished_info,
external_specializations_requested,
}
}
@ -1949,9 +2002,12 @@ fn build_pending_specializations<'a>(
// TODO use this?
_module_timing: ModuleTiming,
mut layout_cache: LayoutCache<'a>,
// TODO remove
_pending_specializations: MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization<'a>>>,
finished_info: FinishedInfo<'a>,
) -> Msg<'a> {
let mut procs = Procs::default();
let mut mono_problems = std::vec::Vec::new();
let mut subs = solved_subs.into_inner();
let mut mono_env = roc_mono::ir::Env {
@ -1971,6 +2027,12 @@ fn build_pending_specializations<'a>(
match decl {
Declare(def) | Builtin(def) => match def.loc_pattern.value {
Identifier(symbol) => {
let is_exposed = finished_info
.exposed_vars_by_symbol
.iter()
.find(|(k, _)| *k == symbol)
.is_some();
match def.loc_expr.value {
Closure {
function_type: annotation,
@ -1981,6 +2043,32 @@ fn build_pending_specializations<'a>(
} => {
// this is a non-recursive declaration
let is_tail_recursive = false;
// If this is an exposed symbol, we need to
// register it as such. Otherwise, since it
// never gets called by Roc code, it will never
// get specialized!
if is_exposed {
let mut pattern_vars = bumpalo::collections::Vec::with_capacity_in(
loc_args.len(),
arena,
);
for (var, _) in loc_args.iter() {
pattern_vars.push(*var);
}
let layout = layout_cache.from_var(mono_env.arena, annotation, mono_env.subs).unwrap_or_else(|err|
todo!("TODO gracefully handle the situation where we expose a function to the host which doesn't have a valid layout (e.g. maybe the function wasn't monomorphic): {:?}", err)
);
procs.insert_exposed(
symbol,
layout,
pattern_vars.into_bump_slice(),
annotation,
ret_var,
);
}
procs.insert_named(
&mut mono_env,
@ -1994,6 +2082,20 @@ fn build_pending_specializations<'a>(
);
}
body => {
// If this is an exposed symbol, we need to
// register it as such. Otherwise, since it
// never gets called by Roc code, it will never
// get specialized!
if is_exposed {
let annotation = def.expr_var;
let ret_var = def.expr_var;
let layout = layout_cache.from_var(mono_env.arena, annotation, mono_env.subs).unwrap_or_else(|err|
todo!("TODO gracefully handle the situation where we expose a function to the host which doesn't have a valid layout (e.g. maybe the function wasn't monomorphic): {:?}", err)
);
procs.insert_exposed(symbol, layout, &[], annotation, ret_var);
}
let proc = PartialProc {
annotation: def.expr_var,
// This is a 0-arity thunk, so it has no arguments.
@ -2087,6 +2189,7 @@ fn run_task<'a>(
layout_cache,
solved_subs,
finished_info,
pending_specializations,
} => Ok(build_pending_specializations(
arena,
solved_subs,
@ -2095,6 +2198,7 @@ fn run_task<'a>(
decls,
module_timing,
layout_cache,
pending_specializations,
finished_info,
)),
MakeSpecializations {
@ -2103,6 +2207,7 @@ fn run_task<'a>(
subs,
procs,
layout_cache,
specializations_we_must_make,
finished_info,
} => Ok(make_specializations(
arena,
@ -2111,6 +2216,7 @@ fn run_task<'a>(
subs,
procs,
layout_cache,
specializations_we_must_make,
finished_info,
)),
}?;