infer closure size!

This commit is contained in:
Folkert 2020-10-14 23:26:44 +02:00
parent 7f1dd80392
commit b440ab90f6
2 changed files with 58 additions and 17 deletions

View file

@ -1372,16 +1372,13 @@ fn specialize_external<'a>(
let specialized_body = from_can(env, body, procs, layout_cache); let specialized_body = from_can(env, body, procs, layout_cache);
let (proc_args, ret_layout) = let (proc_args, closes_over, ret_layout) =
build_specialized_proc_from_var(env, layout_cache, pattern_symbols, fn_var)?; build_specialized_proc_from_var(env, layout_cache, pattern_symbols, fn_var)?;
// reset subs, so we don't get type errors when specializing for a different signature // reset subs, so we don't get type errors when specializing for a different signature
layout_cache.rollback_to(cache_snapshot); layout_cache.rollback_to(cache_snapshot);
env.subs.rollback_to(snapshot); env.subs.rollback_to(snapshot);
// TODO WRONG
let closes_over_layout = Layout::Struct(&[]);
let recursivity = if is_self_recursive { let recursivity = if is_self_recursive {
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol())) SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
} else { } else {
@ -1392,7 +1389,7 @@ fn specialize_external<'a>(
name: proc_name, name: proc_name,
args: proc_args, args: proc_args,
body: specialized_body, body: specialized_body,
closes_over: closes_over_layout, closes_over,
ret_layout, ret_layout,
is_self_recursive: recursivity, is_self_recursive: recursivity,
}; };
@ -1406,10 +1403,17 @@ fn build_specialized_proc_from_var<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
pattern_symbols: &[Symbol], pattern_symbols: &[Symbol],
fn_var: Variable, fn_var: Variable,
) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>), LayoutProblem> { ) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>, Layout<'a>), LayoutProblem> {
match env.subs.get_without_compacting(fn_var).content { match env.subs.get_without_compacting(fn_var).content {
Content::Structure(FlatType::Func(pattern_vars, _closure_var, ret_var)) => { Content::Structure(FlatType::Func(pattern_vars, closure_var, ret_var)) => {
build_specialized_proc(env, layout_cache, pattern_symbols, &pattern_vars, ret_var) build_specialized_proc(
env,
layout_cache,
pattern_symbols,
&pattern_vars,
Some(closure_var),
ret_var,
)
} }
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args))
if !pattern_symbols.is_empty() => if !pattern_symbols.is_empty() =>
@ -1421,8 +1425,7 @@ fn build_specialized_proc_from_var<'a>(
} }
_ => { _ => {
// a top-level constant 0-argument thunk // a top-level constant 0-argument thunk
build_specialized_proc(env, layout_cache, pattern_symbols, &[], None, fn_var)
build_specialized_proc(env, layout_cache, pattern_symbols, &[], fn_var)
} }
} }
} }
@ -1433,8 +1436,9 @@ fn build_specialized_proc<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
pattern_symbols: &[Symbol], pattern_symbols: &[Symbol],
pattern_vars: &[Variable], pattern_vars: &[Variable],
closure_var: Option<Variable>,
ret_var: Variable, ret_var: Variable,
) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>), LayoutProblem> { ) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>, Layout<'a>), LayoutProblem> {
let mut proc_args = Vec::with_capacity_in(pattern_vars.len(), &env.arena); let mut proc_args = Vec::with_capacity_in(pattern_vars.len(), &env.arena);
debug_assert_eq!( debug_assert_eq!(
@ -1451,11 +1455,18 @@ fn build_specialized_proc<'a>(
let proc_args = proc_args.into_bump_slice(); let proc_args = proc_args.into_bump_slice();
let closes_over = match closure_var {
Some(cvar) => layout_cache
.from_var(&env.arena, cvar, env.subs)
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)),
None => Layout::Struct(&[]),
};
let ret_layout = layout_cache let ret_layout = layout_cache
.from_var(&env.arena, ret_var, env.subs) .from_var(&env.arena, ret_var, env.subs)
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
Ok((proc_args, ret_layout)) Ok((proc_args, closes_over, ret_layout))
} }
fn specialize<'a>( fn specialize<'a>(

View file

@ -129,6 +129,7 @@ impl Pools {
struct State { struct State {
env: Env, env: Env,
mark: Mark, mark: Mark,
vars_by_symbol: MutMap<Symbol, Variable>,
} }
pub fn run( pub fn run(
@ -141,9 +142,10 @@ pub fn run(
let state = State { let state = State {
env: env.clone(), env: env.clone(),
mark: Mark::NONE.next(), mark: Mark::NONE.next(),
vars_by_symbol: MutMap::default(),
}; };
let rank = Rank::toplevel(); let rank = Rank::toplevel();
let state = solve( let mut state = solve(
env, env,
state, state,
rank, rank,
@ -154,6 +156,10 @@ pub fn run(
constraint, constraint,
); );
// by default, state.vars_by_symbol only gives back top-level symbols and their variable
// for closure size inference, we need all of the symbols, we do that here
state.env.vars_by_symbol.extend(state.vars_by_symbol);
(Solved(subs), state.env) (Solved(subs), state.env)
} }
@ -168,9 +174,10 @@ pub fn run_in_place(
let state = State { let state = State {
env: env.clone(), env: env.clone(),
mark: Mark::NONE.next(), mark: Mark::NONE.next(),
vars_by_symbol: MutMap::default(),
}; };
let rank = Rank::toplevel(); let rank = Rank::toplevel();
let state = solve( let mut state = solve(
env, env,
state, state,
rank, rank,
@ -181,13 +188,17 @@ pub fn run_in_place(
constraint, constraint,
); );
// by default, state.vars_by_symbol only gives back top-level symbols and their variable
// for closure size inference, we need all of the symbols, we do that here
state.env.vars_by_symbol.extend(state.vars_by_symbol);
state.env state.env
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn solve( fn solve(
env: &Env, env: &Env,
state: State, mut state: State,
rank: Rank, rank: Rank,
pools: &mut Pools, pools: &mut Pools,
problems: &mut Vec<TypeError>, problems: &mut Vec<TypeError>,
@ -197,12 +208,22 @@ fn solve(
) -> State { ) -> State {
match constraint { match constraint {
// True => state, // True => state,
True | SaveTheEnvironment => { True => {
state
.vars_by_symbol
.extend(env.vars_by_symbol.iter().map(|(x, y)| (*x, *y)));
state
}
SaveTheEnvironment => {
// NOTE deviation: elm only copies the env into the state on SaveTheEnvironment // NOTE deviation: elm only copies the env into the state on SaveTheEnvironment
let mut copy = state; let mut copy = state;
copy.env = env.clone(); copy.env = env.clone();
copy.vars_by_symbol
.extend(env.vars_by_symbol.iter().map(|(x, y)| (*x, *y)));
copy copy
} }
Eq(typ, expectation, category, region) => { Eq(typ, expectation, category, region) => {
@ -389,7 +410,7 @@ fn solve(
) )
} }
ret_con if let_con.rigid_vars.is_empty() && let_con.flex_vars.is_empty() => { ret_con if let_con.rigid_vars.is_empty() && let_con.flex_vars.is_empty() => {
let state = solve( let mut state = solve(
env, env,
state, state,
rank, rank,
@ -422,6 +443,10 @@ fn solve(
} }
} }
state
.vars_by_symbol
.extend(new_env.vars_by_symbol.iter().map(|(x, y)| (*x, *y)));
let new_state = solve( let new_state = solve(
&new_env, &new_env,
state, state,
@ -480,6 +505,10 @@ fn solve(
); );
} }
state
.vars_by_symbol
.extend(new_env.vars_by_symbol.iter().map(|(x, y)| (*x, *y)));
// run solver in next pool // run solver in next pool
// Solve the assignments' constraints first. // Solve the assignments' constraints first.
@ -565,6 +594,7 @@ fn solve(
let temp_state = State { let temp_state = State {
env: new_state.env, env: new_state.env,
mark: final_mark, mark: final_mark,
vars_by_symbol: new_state.vars_by_symbol,
}; };
// Now solve the body, using the new vars_by_symbol which includes // Now solve the body, using the new vars_by_symbol which includes