make new PartialProcs struct

This commit is contained in:
Folkert 2021-11-03 13:59:00 +01:00
parent 487948e514
commit 74df66a472
2 changed files with 172 additions and 117 deletions

View file

@ -3958,7 +3958,10 @@ fn make_specializations<'a>(
let mut procs = Procs::new_in(arena); let mut procs = Procs::new_in(arena);
procs.partial_procs = procs_base.partial_procs; for (symbol, partial_proc) in procs_base.partial_procs.into_iter() {
procs.partial_procs.insert(symbol, partial_proc);
}
procs.module_thunks = procs_base.module_thunks; procs.module_thunks = procs_base.module_thunks;
procs.runtime_errors = procs_base.runtime_errors; procs.runtime_errors = procs_base.runtime_errors;
procs.imported_module_thunks = procs_base.imported_module_thunks; procs.imported_module_thunks = procs_base.imported_module_thunks;

View file

@ -71,6 +71,61 @@ pub struct EntryPoint<'a> {
pub layout: ProcLayout<'a>, pub layout: ProcLayout<'a>,
} }
#[derive(Clone, Copy, Debug)]
pub struct PartialProcId(usize);
#[derive(Clone, Debug, PartialEq)]
pub struct PartialProcs<'a> {
/// maps a function name (symbol) to an index
symbols: Vec<'a, Symbol>,
partial_procs: Vec<'a, PartialProc<'a>>,
}
impl<'a> PartialProcs<'a> {
fn new_in(arena: &'a Bump) -> Self {
Self {
symbols: Vec::new_in(arena),
partial_procs: Vec::new_in(arena),
}
}
fn contains_key(&self, symbol: Symbol) -> bool {
self.symbol_to_id(symbol).is_some()
}
fn symbol_to_id(&self, symbol: Symbol) -> Option<PartialProcId> {
self.symbols
.iter()
.position(|s| *s == symbol)
.map(PartialProcId)
}
fn get_symbol(&self, symbol: Symbol) -> Option<&PartialProc<'a>> {
let id = self.symbol_to_id(symbol)?;
Some(self.get_id(id))
}
fn get_id(&self, id: PartialProcId) -> &PartialProc<'a> {
&self.partial_procs[id.0]
}
pub fn insert(&mut self, symbol: Symbol, partial_proc: PartialProc<'a>) -> PartialProcId {
debug_assert!(
!self.contains_key(symbol),
"The {:?} is inserted as a partial proc twice: that's a bug!",
symbol,
);
let id = PartialProcId(self.symbols.len());
self.symbols.push(symbol);
self.partial_procs.push(partial_proc);
id
}
}
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct PartialProc<'a> { pub struct PartialProc<'a> {
pub annotation: Variable, pub annotation: Variable,
@ -129,7 +184,7 @@ impl<'a> PartialProc<'a> {
} }
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
pub enum CapturedSymbols<'a> { pub enum CapturedSymbols<'a> {
None, None,
Captured(&'a [(Symbol, Variable)]), Captured(&'a [(Symbol, Variable)]),
@ -418,7 +473,7 @@ impl<'a> ExternalSpecializations<'a> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Procs<'a> { pub struct Procs<'a> {
pub partial_procs: BumpMap<Symbol, PartialProc<'a>>, pub partial_procs: PartialProcs<'a>,
pub imported_module_thunks: &'a [Symbol], pub imported_module_thunks: &'a [Symbol],
pub module_thunks: &'a [Symbol], pub module_thunks: &'a [Symbol],
pub pending_specializations: pub pending_specializations:
@ -432,7 +487,7 @@ pub struct Procs<'a> {
impl<'a> Procs<'a> { impl<'a> Procs<'a> {
pub fn new_in(arena: &'a Bump) -> Self { pub fn new_in(arena: &'a Bump) -> Self {
Self { Self {
partial_procs: BumpMap::new_in(arena), partial_procs: PartialProcs::new_in(arena),
imported_module_thunks: &[], imported_module_thunks: &[],
module_thunks: &[], module_thunks: &[],
pending_specializations: Some(BumpMap::new_in(arena)), pending_specializations: Some(BumpMap::new_in(arena)),
@ -460,7 +515,7 @@ impl<'a> Procs<'a> {
} }
fn get_partial_proc<'b>(&'b self, symbol: Symbol) -> Option<&'b PartialProc<'a>> { fn get_partial_proc<'b>(&'b self, symbol: Symbol) -> Option<&'b PartialProc<'a>> {
self.partial_procs.get(&symbol) self.partial_procs.get_symbol(symbol)
} }
pub fn get_specialized_procs_without_rc( pub fn get_specialized_procs_without_rc(
@ -539,9 +594,9 @@ impl<'a> Procs<'a> {
// register the pending specialization, so this gets code genned later // register the pending specialization, so this gets code genned later
add_pending(pending_specializations, symbol, layout, pending); add_pending(pending_specializations, symbol, layout, pending);
match self.partial_procs.entry(symbol) { match self.partial_procs.symbol_to_id(symbol) {
Entry::Occupied(occupied) => { Some(occupied) => {
let existing = occupied.get(); let existing = self.partial_procs.get_id(occupied);
// if we're adding the same partial proc twice, they must be the actual same! // if we're adding the same partial proc twice, they must be the actual same!
// //
// NOTE we can't skip extra work! we still need to make the specialization for this // NOTE we can't skip extra work! we still need to make the specialization for this
@ -553,7 +608,7 @@ impl<'a> Procs<'a> {
// the partial proc is already in there, do nothing // the partial proc is already in there, do nothing
} }
Entry::Vacant(vacant) => { None => {
let pattern_symbols = pattern_symbols.into_bump_slice(); let pattern_symbols = pattern_symbols.into_bump_slice();
let partial_proc = PartialProc { let partial_proc = PartialProc {
@ -564,7 +619,7 @@ impl<'a> Procs<'a> {
is_self_recursive, is_self_recursive,
}; };
vacant.insert(partial_proc); self.partial_procs.insert(symbol, partial_proc);
} }
} }
} }
@ -576,8 +631,10 @@ impl<'a> Procs<'a> {
let outside_layout = layout; let outside_layout = layout;
let partial_proc; let partial_proc_id = if let Some(partial_proc_id) =
if let Some(existing) = self.get_partial_proc(symbol) { self.partial_procs.symbol_to_id(symbol)
{
let existing = self.partial_procs.get_id(partial_proc_id);
// if we're adding the same partial proc twice, they must be the actual same! // if we're adding the same partial proc twice, they must be the actual same!
// //
// NOTE we can't skip extra work! we still need to make the specialization for this // NOTE we can't skip extra work! we still need to make the specialization for this
@ -587,21 +644,29 @@ impl<'a> Procs<'a> {
debug_assert_eq!(captured_symbols, existing.captured_symbols); debug_assert_eq!(captured_symbols, existing.captured_symbols);
debug_assert_eq!(is_self_recursive, existing.is_self_recursive); debug_assert_eq!(is_self_recursive, existing.is_self_recursive);
partial_proc = existing; partial_proc_id
} else { } else {
let pattern_symbols = pattern_symbols.into_bump_slice(); let pattern_symbols = pattern_symbols.into_bump_slice();
partial_proc = env.arena.alloc(PartialProc { let partial_proc = PartialProc {
annotation, annotation,
pattern_symbols, pattern_symbols,
captured_symbols, captured_symbols,
body: body.value, body: body.value,
is_self_recursive, is_self_recursive,
}); };
}
match specialize(env, self, symbol, layout_cache, pending, partial_proc) self.partial_procs.insert(symbol, partial_proc)
{ };
match specialize(
env,
self,
symbol,
layout_cache,
pending,
partial_proc_id,
) {
Ok((proc, layout)) => { Ok((proc, layout)) => {
let top_level = ProcLayout::from_raw(env.arena, layout); let top_level = ProcLayout::from_raw(env.arena, layout);
@ -666,7 +731,7 @@ impl<'a> Procs<'a> {
None => { None => {
let symbol = name; let symbol = name;
let partial_proc = match self.get_partial_proc(symbol) { let partial_proc_id = match self.partial_procs.symbol_to_id(symbol) {
Some(p) => p, Some(p) => p,
None => panic!("no partial_proc for {:?} in module {:?}", symbol, env.home), None => panic!("no partial_proc for {:?} in module {:?}", symbol, env.home),
}; };
@ -694,7 +759,7 @@ impl<'a> Procs<'a> {
layout_cache, layout_cache,
fn_var, fn_var,
Default::default(), Default::default(),
partial_proc, partial_proc_id,
) { ) {
Ok((proc, _ignore_layout)) => { Ok((proc, _ignore_layout)) => {
// the `layout` is a function pointer, while `_ignore_layout` can be a // the `layout` is a function pointer, while `_ignore_layout` can be a
@ -1725,7 +1790,7 @@ pub fn specialize_all<'a>(
continue; continue;
} }
Entry::Vacant(vacant) => { Entry::Vacant(vacant) => {
match procs.get_partial_proc(name) { match procs.partial_procs.symbol_to_id(name) {
Some(v) => { Some(v) => {
// Mark this proc as in-progress, so if we're dealing with // Mark this proc as in-progress, so if we're dealing with
// mutually recursive functions, we don't loop forever. // mutually recursive functions, we don't loop forever.
@ -1807,7 +1872,7 @@ fn specialize_externals_others_need<'a>(
let name = *symbol; let name = *symbol;
let partial_proc = match procs.get_partial_proc(name) { let partial_proc_id = match procs.partial_procs.symbol_to_id(name) {
Some(v) => v, Some(v) => v,
None => { None => {
panic!("Cannot find a partial proc for {:?}", name); panic!("Cannot find a partial proc for {:?}", name);
@ -1822,7 +1887,7 @@ fn specialize_externals_others_need<'a>(
layout_cache, layout_cache,
solved_type, solved_type,
BumpMap::new_in(env.arena), BumpMap::new_in(env.arena),
partial_proc, partial_proc_id,
) { ) {
Ok((proc, layout)) => { Ok((proc, layout)) => {
let top_level = ProcLayout::from_raw(env.arena, layout); let top_level = ProcLayout::from_raw(env.arena, layout);
@ -1900,32 +1965,28 @@ fn specialize_external<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
fn_var: Variable, fn_var: Variable,
host_exposed_variables: &[(Symbol, Variable)], host_exposed_variables: &[(Symbol, Variable)],
partial_proc: &PartialProc<'a>, partial_proc_id: PartialProcId,
) -> Result<Proc<'a>, LayoutProblem> { ) -> Result<Proc<'a>, LayoutProblem> {
let PartialProc { let partial_proc = procs.partial_procs.get_id(partial_proc_id);
annotation, let captured_symbols = partial_proc.captured_symbols;
pattern_symbols,
captured_symbols,
body,
is_self_recursive,
} = partial_proc;
// unify the called function with the specialized signature, then specialize the function body // unify the called function with the specialized signature, then specialize the function body
let snapshot = env.subs.snapshot(); let snapshot = env.subs.snapshot();
let cache_snapshot = layout_cache.snapshot(); let cache_snapshot = layout_cache.snapshot();
let _unified = roc_unify::unify::unify(env.subs, *annotation, fn_var); let _unified = roc_unify::unify::unify(env.subs, partial_proc.annotation, fn_var);
// This will not hold for programs with type errors // This will not hold for programs with type errors
// let is_valid = matches!(unified, roc_unify::unify::Unified::Success(_)); // let is_valid = matches!(unified, roc_unify::unify::Unified::Success(_));
// debug_assert!(is_valid, "unificaton failure for {:?}", proc_name); // debug_assert!(is_valid, "unificaton failure for {:?}", proc_name);
// if this is a closure, add the closure record argument // if this is a closure, add the closure record argument
let pattern_symbols = match captured_symbols { let pattern_symbols = match partial_proc.captured_symbols {
CapturedSymbols::None => pattern_symbols, CapturedSymbols::None => partial_proc.pattern_symbols,
CapturedSymbols::Captured([]) => pattern_symbols, CapturedSymbols::Captured([]) => partial_proc.pattern_symbols,
CapturedSymbols::Captured(_) => { CapturedSymbols::Captured(_) => {
let mut temp = Vec::from_iter_in(pattern_symbols.iter().copied(), env.arena); let mut temp =
Vec::from_iter_in(partial_proc.pattern_symbols.iter().copied(), env.arena);
temp.push(Symbol::ARG_CLOSURE); temp.push(Symbol::ARG_CLOSURE);
temp.into_bump_slice() temp.into_bump_slice()
} }
@ -2022,13 +2083,14 @@ fn specialize_external<'a>(
} }
}; };
let recursivity = if *is_self_recursive { let recursivity = if partial_proc.is_self_recursive {
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol())) SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
} else { } else {
SelfRecursive::NotSelfRecursive SelfRecursive::NotSelfRecursive
}; };
let mut specialized_body = from_can(env, fn_var, body.clone(), procs, layout_cache); let body = partial_proc.body.clone();
let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache);
match specialized { match specialized {
SpecializedLayout::FunctionPointerBody { SpecializedLayout::FunctionPointerBody {
@ -2418,24 +2480,13 @@ struct SpecializeFailure<'a> {
type SpecializeSuccess<'a> = (Proc<'a>, RawFunctionLayout<'a>); type SpecializeSuccess<'a> = (Proc<'a>, RawFunctionLayout<'a>);
fn specialize2<'a, 'b>(
env: &mut Env<'a, '_>,
partial_proc: &'b PartialProc<'a>,
procs: &'b mut Procs<'a>,
proc_name: Symbol,
layout_cache: &mut LayoutCache<'a>,
pending: PendingSpecialization,
) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> {
todo!()
}
fn specialize<'a, 'b>( fn specialize<'a, 'b>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &'b mut Procs<'a>, procs: &'b mut Procs<'a>,
proc_name: Symbol, proc_name: Symbol,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
pending: PendingSpecialization, pending: PendingSpecialization,
partial_proc: &'b PartialProc<'a>, partial_proc_id: PartialProcId,
) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> { ) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> {
let PendingSpecialization { let PendingSpecialization {
solved_type, solved_type,
@ -2450,7 +2501,7 @@ fn specialize<'a, 'b>(
layout_cache, layout_cache,
&solved_type, &solved_type,
host_exposed_aliases, host_exposed_aliases,
partial_proc, partial_proc_id,
) )
} }
@ -2480,7 +2531,7 @@ fn specialize_solved_type<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
solved_type: &SolvedType, solved_type: &SolvedType,
host_exposed_aliases: BumpMap<Symbol, SolvedType>, host_exposed_aliases: BumpMap<Symbol, SolvedType>,
partial_proc: &PartialProc<'a>, partial_proc_id: PartialProcId,
) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> { ) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> {
specialize_variable_help( specialize_variable_help(
env, env,
@ -2489,7 +2540,7 @@ fn specialize_solved_type<'a>(
layout_cache, layout_cache,
|env| introduce_solved_type_to_subs(env, solved_type), |env| introduce_solved_type_to_subs(env, solved_type),
host_exposed_aliases, host_exposed_aliases,
partial_proc, partial_proc_id,
) )
} }
@ -2500,7 +2551,7 @@ fn specialize_variable<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
fn_var: Variable, fn_var: Variable,
host_exposed_aliases: BumpMap<Symbol, SolvedType>, host_exposed_aliases: BumpMap<Symbol, SolvedType>,
partial_proc: &PartialProc<'a>, partial_proc_id: PartialProcId,
) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> { ) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> {
specialize_variable_help( specialize_variable_help(
env, env,
@ -2509,7 +2560,7 @@ fn specialize_variable<'a>(
layout_cache, layout_cache,
|_| fn_var, |_| fn_var,
host_exposed_aliases, host_exposed_aliases,
partial_proc, partial_proc_id,
) )
} }
@ -2520,7 +2571,7 @@ fn specialize_variable_help<'a, F>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
fn_var_thunk: F, fn_var_thunk: F,
host_exposed_aliases: BumpMap<Symbol, SolvedType>, host_exposed_aliases: BumpMap<Symbol, SolvedType>,
partial_proc: &PartialProc<'a>, partial_proc_id: PartialProcId,
) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>> ) -> Result<SpecializeSuccess<'a>, SpecializeFailure<'a>>
where where
F: FnOnce(&mut Env<'a, '_>) -> Variable, F: FnOnce(&mut Env<'a, '_>) -> Variable,
@ -2551,7 +2602,8 @@ where
}; };
// make sure rigid variables in the annotation are converted to flex variables // make sure rigid variables in the annotation are converted to flex variables
instantiate_rigids(env.subs, partial_proc.annotation); let annotation_var = procs.partial_procs.get_id(partial_proc_id).annotation;
instantiate_rigids(env.subs, annotation_var);
let mut host_exposed_variables = Vec::with_capacity_in(host_exposed_aliases.len(), env.arena); let mut host_exposed_variables = Vec::with_capacity_in(host_exposed_aliases.len(), env.arena);
@ -2568,7 +2620,7 @@ where
layout_cache, layout_cache,
fn_var, fn_var,
&host_exposed_variables, &host_exposed_variables,
partial_proc, partial_proc_id,
); );
match specialized { match specialized {
@ -3824,11 +3876,11 @@ pub fn with_hole<'a>(
// if it's in there, it's a call by name, otherwise it's a call by pointer // if it's in there, it's a call by name, otherwise it's a call by pointer
let is_known = |key| { let is_known = |key| {
// a proc in this module, or an imported symbol // a proc in this module, or an imported symbol
procs.partial_procs.contains_key(key) || env.is_imported_symbol(*key) procs.partial_procs.contains_key(key) || env.is_imported_symbol(key)
}; };
match loc_expr.value { match loc_expr.value {
roc_can::expr::Expr::Var(proc_name) if is_known(&proc_name) => { roc_can::expr::Expr::Var(proc_name) if is_known(proc_name) => {
// a call by a known name // a call by a known name
call_by_name( call_by_name(
env, env,
@ -4807,67 +4859,67 @@ pub fn from_can<'a>(
captured_symbols, captured_symbols,
.. ..
} => { } => {
// Extract Procs, but discard the resulting Expr::Load. if true || !procs.partial_procs.contains_key(*symbol) {
// That Load looks up the pointer, which we won't use here! let loc_body = *boxed_body;
let loc_body = *boxed_body; let is_self_recursive =
!matches!(recursive, roc_can::expr::Recursive::NotRecursive);
let is_self_recursive = // does this function capture any local values?
!matches!(recursive, roc_can::expr::Recursive::NotRecursive); let function_layout =
layout_cache.raw_from_var(env.arena, function_type, env.subs);
// does this function capture any local values? let captured_symbols = match function_layout {
let function_layout = Ok(RawFunctionLayout::Function(_, lambda_set, _)) => {
layout_cache.raw_from_var(env.arena, function_type, env.subs); if let Layout::Struct(&[]) =
lambda_set.runtime_representation()
let captured_symbols = match function_layout { {
Ok(RawFunctionLayout::Function(_, lambda_set, _)) => { CapturedSymbols::None
if let Layout::Struct(&[]) = lambda_set.runtime_representation() } else {
{ let mut temp =
CapturedSymbols::None Vec::from_iter_in(captured_symbols, env.arena);
} else { temp.sort();
let mut temp = CapturedSymbols::Captured(temp.into_bump_slice())
Vec::from_iter_in(captured_symbols, env.arena); }
temp.sort();
CapturedSymbols::Captured(temp.into_bump_slice())
} }
} Ok(RawFunctionLayout::ZeroArgumentThunk(_)) => {
Ok(RawFunctionLayout::ZeroArgumentThunk(_)) => { // top-level thunks cannot capture any variables
// top-level thunks cannot capture any variables debug_assert!(
debug_assert!( captured_symbols.is_empty(),
captured_symbols.is_empty(), "{:?} with layout {:?} {:?} {:?}",
"{:?} with layout {:?} {:?} {:?}", &captured_symbols,
&captured_symbols, function_layout,
function_layout, env.subs,
env.subs, (function_type, closure_type, closure_ext_var),
(function_type, closure_type, closure_ext_var), );
);
CapturedSymbols::None
}
Err(_) => {
// just allow this. see https://github.com/rtfeldman/roc/issues/1585
if captured_symbols.is_empty() {
CapturedSymbols::None CapturedSymbols::None
} else {
let mut temp =
Vec::from_iter_in(captured_symbols, env.arena);
temp.sort();
CapturedSymbols::Captured(temp.into_bump_slice())
} }
} Err(_) => {
}; // just allow this. see https://github.com/rtfeldman/roc/issues/1585
if captured_symbols.is_empty() {
CapturedSymbols::None
} else {
let mut temp =
Vec::from_iter_in(captured_symbols, env.arena);
temp.sort();
CapturedSymbols::Captured(temp.into_bump_slice())
}
}
};
let partial_proc = PartialProc::from_named_function( let partial_proc = PartialProc::from_named_function(
env, env,
layout_cache, layout_cache,
function_type, function_type,
arguments, arguments,
loc_body, loc_body,
captured_symbols, captured_symbols,
is_self_recursive, is_self_recursive,
return_type, return_type,
); );
procs.partial_procs.insert(*symbol, partial_proc); procs.partial_procs.insert(*symbol, partial_proc);
}
return from_can(env, variable, cont.value, procs, layout_cache); return from_can(env, variable, cont.value, procs, layout_cache);
} }
@ -6011,7 +6063,7 @@ fn can_reuse_symbol<'a>(
if env.is_imported_symbol(symbol) { if env.is_imported_symbol(symbol) {
Imported(symbol) Imported(symbol)
} else if procs.partial_procs.contains_key(&symbol) { } else if procs.partial_procs.contains_key(symbol) {
LocalFunction(symbol) LocalFunction(symbol)
} else { } else {
Value(symbol) Value(symbol)
@ -6673,7 +6725,7 @@ fn call_by_name_help<'a>(
assign_to_symbols(env, procs, layout_cache, iter, result) assign_to_symbols(env, procs, layout_cache, iter, result)
} }
None => { None => {
let opt_partial_proc = procs.get_partial_proc(proc_name); let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name);
/* /*
debug_assert_eq!( debug_assert_eq!(
@ -6808,7 +6860,7 @@ fn call_by_name_module_thunk<'a>(
force_thunk(env, proc_name, inner_layout, assigned, hole) force_thunk(env, proc_name, inner_layout, assigned, hole)
} }
None => { None => {
let opt_partial_proc = procs.get_partial_proc(proc_name); let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name);
match opt_partial_proc { match opt_partial_proc {
Some(partial_proc) => { Some(partial_proc) => {
@ -6928,7 +6980,7 @@ fn call_specialized_proc<'a>(
match procs match procs
.partial_procs .partial_procs
.get(&proc_name) .get_symbol(proc_name)
.map(|pp| &pp.captured_symbols) .map(|pp| &pp.captured_symbols)
{ {
Some(&CapturedSymbols::Captured(captured_symbols)) => { Some(&CapturedSymbols::Captured(captured_symbols)) => {