Extract specializations from PartialProc

This commit is contained in:
Richard Feldman 2020-04-18 11:24:25 -04:00
parent ee481e6713
commit f0d76825d2

View file

@ -10,10 +10,27 @@ use roc_region::all::{Located, Region};
use roc_types::subs::{Content, ContentHash, FlatType, Subs, Variable}; use roc_types::subs::{Content, ContentHash, FlatType, Subs, Variable};
use std::hash::Hash; use std::hash::Hash;
#[derive(Clone, Debug, PartialEq)]
pub struct PartialProc<'a> {
pub annotation: Variable,
pub patterns: Vec<'a, Symbol>,
pub body: roc_can::expr::Expr,
}
#[derive(Clone, Debug, PartialEq)]
pub struct Proc<'a> {
pub name: Symbol,
pub args: &'a [(Layout<'a>, Symbol)],
pub body: Expr<'a>,
pub closes_over: Layout<'a>,
pub ret_layout: Layout<'a>,
}
#[derive(Clone, Debug, PartialEq, Default)] #[derive(Clone, Debug, PartialEq, Default)]
pub struct Procs<'a> { pub struct Procs<'a> {
user_defined: MutMap<Symbol, PartialProc<'a>>, user_defined: MutMap<Symbol, PartialProc<'a>>,
anonymous: MutMap<Symbol, Option<Proc<'a>>>, anonymous: MutMap<Symbol, Option<Proc<'a>>>,
specializations: MutMap<ContentHash, (Symbol, Option<Proc<'a>>)>,
builtin: MutSet<Symbol>, builtin: MutSet<Symbol>,
} }
@ -28,14 +45,11 @@ impl<'a> Procs<'a> {
fn insert_specialization( fn insert_specialization(
&mut self, &mut self,
symbol: Symbol,
hash: ContentHash, hash: ContentHash,
spec_name: Symbol, spec_name: Symbol,
proc: Option<Proc<'a>>, proc: Option<Proc<'a>>,
) { ) {
self.user_defined self.specializations.insert(hash, (spec_name, proc));
.get_mut(&symbol)
.map(|partial_proc| partial_proc.specializations.insert(hash, (spec_name, proc)));
} }
fn get_user_defined(&self, symbol: Symbol) -> Option<&PartialProc<'a>> { fn get_user_defined(&self, symbol: Symbol) -> Option<&PartialProc<'a>> {
@ -44,11 +58,7 @@ impl<'a> Procs<'a> {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
let anonymous: usize = self.anonymous.len(); let anonymous: usize = self.anonymous.len();
let user_defined: usize = self let user_defined: usize = self.specializations.len();
.user_defined
.values()
.map(|v| v.specializations.len())
.sum();
anonymous + user_defined anonymous + user_defined
} }
@ -64,10 +74,8 @@ impl<'a> Procs<'a> {
pub fn as_map(&self) -> MutMap<Symbol, Option<Proc<'a>>> { pub fn as_map(&self) -> MutMap<Symbol, Option<Proc<'a>>> {
let mut result = MutMap::default(); let mut result = MutMap::default();
for partial_proc in self.user_defined.values() { for (symbol, opt_proc) in self.specializations.values() {
for (_, (symbol, opt_proc)) in partial_proc.specializations.clone().into_iter() { result.insert(*symbol, opt_proc.clone());
result.insert(symbol, opt_proc);
}
} }
for (symbol, proc) in self.anonymous.clone().into_iter() { for (symbol, proc) in self.anonymous.clone().into_iter() {
@ -82,23 +90,6 @@ impl<'a> Procs<'a> {
} }
} }
#[derive(Clone, Debug, PartialEq)]
pub struct PartialProc<'a> {
pub annotation: Variable,
pub patterns: Vec<'a, Symbol>,
pub body: roc_can::expr::Expr,
pub specializations: MutMap<ContentHash, (Symbol, Option<Proc<'a>>)>,
}
#[derive(Clone, Debug, PartialEq)]
pub struct Proc<'a> {
pub name: Symbol,
pub args: &'a [(Layout<'a>, Symbol)],
pub body: Expr<'a>,
pub closes_over: Layout<'a>,
pub ret_layout: Layout<'a>,
}
pub struct Env<'a, 'i> { pub struct Env<'a, 'i> {
pub arena: &'a Bump, pub arena: &'a Bump,
pub subs: &'a mut Subs, pub subs: &'a mut Subs,
@ -458,7 +449,6 @@ fn from_can<'a>(
annotation, annotation,
patterns: arg_symbols, patterns: arg_symbols,
body: body.value, body: body.value,
specializations: MutMap::default(),
}, },
); );
symbol symbol
@ -1305,36 +1295,42 @@ fn call_by_name<'a>(
Vec<'a, Symbol>, Vec<'a, Symbol>,
)>; )>;
let specialized_proc_name = if let Some(partial_proc) = procs.get_user_defined(proc_name) { let specialized_proc_name = match procs.get_user_defined(proc_name) {
let content_hash = ContentHash::from_var(fn_var, env.subs); Some(partial_proc) => {
let content_hash = ContentHash::from_var(fn_var, env.subs);
if let Some(specialization) = partial_proc.specializations.get(&content_hash) { match procs.specializations.get(&content_hash) {
Some(specialization) => {
opt_specialize_body = None;
// a specialization with this type hash already exists, use its symbol
specialization.0
}
None => {
opt_specialize_body = Some((
content_hash,
partial_proc.annotation,
partial_proc.body.clone(),
partial_proc.patterns.clone(),
));
// generate a symbol for this specialization
env.fresh_symbol()
}
}
}
None => {
opt_specialize_body = None; opt_specialize_body = None;
// a specialization with this type hash already exists, use its symbol // This happens for built-in symbols (they are never defined as a Closure)
specialization.0 procs.insert_builtin(proc_name);
} else { proc_name
opt_specialize_body = Some((
content_hash,
partial_proc.annotation,
partial_proc.body.clone(),
partial_proc.patterns.clone(),
));
// generate a symbol for this specialization
env.fresh_symbol()
} }
} else {
opt_specialize_body = None;
// This happens for built-in symbols (they are never defined as a Closure)
procs.insert_builtin(proc_name);
proc_name
}; };
if let Some((content_hash, annotation, body, loc_patterns)) = opt_specialize_body { if let Some((content_hash, annotation, body, loc_patterns)) = opt_specialize_body {
// register proc, so specialization doesn't loop infinitely // register proc, so specialization doesn't loop infinitely
procs.insert_specialization(proc_name, content_hash, specialized_proc_name, None); procs.insert_specialization(content_hash, specialized_proc_name, None);
let arg_vars = loc_args.iter().map(|v| v.0).collect::<std::vec::Vec<_>>(); let arg_vars = loc_args.iter().map(|v| v.0).collect::<std::vec::Vec<_>>();
@ -1351,7 +1347,7 @@ fn call_by_name<'a>(
) )
.ok(); .ok();
procs.insert_specialization(proc_name, content_hash, specialized_proc_name, proc); procs.insert_specialization(content_hash, specialized_proc_name, proc);
} }
// generate actual call // generate actual call