Pass 1 of polymorphic specialization of defs

This commit is contained in:
Ayaz Hafiz 2022-04-25 18:59:39 -04:00
parent 293bc6b15b
commit 7182180622
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
5 changed files with 309 additions and 257 deletions

View file

@ -26,6 +26,7 @@ use roc_types::subs::{
Content, ExhaustiveMark, FlatType, RedundantMark, StorageSubs, Subs, Variable,
VariableSubsSlice,
};
use roc_unify::unify::Mode;
use std::collections::HashMap;
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
@ -263,80 +264,24 @@ impl<'a> PartialProc<'a> {
}
}
/// An expression that is not lowered where it is defined; instead it is specialized
/// (or inlined) at each place it is used.
#[derive(Clone, Debug)]
enum PolymorphicExpr {
/// A root ability member, which must be specialized at a call site, for example
/// "hash" which must be specialized to an exact symbol implementing "hash" for a type.
AbilityMember(Symbol),
/// A polymorphic expression we inline at the usage site.
/// Holds the canonical expression of the definition and the definition's type variable.
Expr(roc_can::expr::Expr, Variable),
}
/// Newtype around the symbol of a root ability member. A symbol aliasing an ability member
/// is mapped to one of these so that calls through the alias can be resolved to a concrete
/// specialization of the member.
#[derive(Clone, Copy, Debug)]
struct AbilityMember(Symbol);
/// A table of aliases of ability member symbols.
#[derive(Clone, Debug)]
enum PartialExprLink {
/// The root polymorphic expression
Sink(PolymorphicExpr),
/// A hop in a partial expression alias chain
Aliases(Symbol),
}
struct AbilityAliases(BumpMap<Symbol, AbilityMember>);
/// A table of symbols to polymorphic expressions. For example, in the program
///
/// n = 1
///
/// asU8 : U8 -> U8
/// asU8 = \_ -> 1
///
/// asU32 : U32 -> U8
/// asU32 = \_ -> 1
///
/// asU8 n + asU32 n
///
/// The expression bound by `n` doesn't have a definite layout until it is used
/// at the call sites `asU8 n`, `asU32 n`.
///
/// Polymorphic *functions* are stored in `PartialProc`s, since functions are
/// no longer first-class once we finish lowering to the IR.
#[derive(Clone, Debug)]
struct PartialExprs(BumpMap<Symbol, PartialExprLink>);
impl PartialExprs {
impl AbilityAliases {
fn new_in(arena: &Bump) -> Self {
Self(BumpMap::new_in(arena))
}
fn insert(&mut self, symbol: Symbol, expr: PolymorphicExpr) {
self.0.insert(symbol, PartialExprLink::Sink(expr));
fn insert(&mut self, symbol: Symbol, member: AbilityMember) {
self.0.insert(symbol, member);
}
fn insert_alias(&mut self, symbol: Symbol, aliases: Symbol) {
self.0.insert(symbol, PartialExprLink::Aliases(aliases));
}
fn contains(&self, symbol: Symbol) -> bool {
self.0.contains_key(&symbol)
}
fn get(&mut self, mut symbol: Symbol) -> Option<&PolymorphicExpr> {
// In practice the alias chain is very short
loop {
match self.0.get(&symbol) {
None => {
return None;
}
Some(&PartialExprLink::Aliases(real_symbol)) => {
symbol = real_symbol;
}
Some(PartialExprLink::Sink(expr)) => {
return Some(expr);
}
}
}
}
fn remove(&mut self, symbol: Symbol) {
debug_assert!(self.contains(symbol));
self.0.remove(&symbol);
fn get(&self, symbol: Symbol) -> Option<&AbilityMember> {
self.0.get(&symbol)
}
}
@ -801,26 +746,28 @@ impl<'a> Specialized<'a> {
#[derive(Clone, Debug)]
pub struct Procs<'a> {
pub partial_procs: PartialProcs<'a>,
partial_exprs: PartialExprs,
ability_member_aliases: AbilityAliases,
pub imported_module_thunks: &'a [Symbol],
pub module_thunks: &'a [Symbol],
pending_specializations: PendingSpecializations<'a>,
specialized: Specialized<'a>,
pub runtime_errors: BumpMap<Symbol, &'a str>,
pub externals_we_need: BumpMap<ModuleId, ExternalSpecializations>,
pub needed_symbol_specializations: BumpMap<(Symbol, Layout<'a>), (Variable, Symbol)>,
}
impl<'a> Procs<'a> {
pub fn new_in(arena: &'a Bump) -> Self {
Self {
partial_procs: PartialProcs::new_in(arena),
partial_exprs: PartialExprs::new_in(arena),
ability_member_aliases: AbilityAliases::new_in(arena),
imported_module_thunks: &[],
module_thunks: &[],
pending_specializations: PendingSpecializations::Finding(Suspended::new_in(arena)),
specialized: Specialized::default(),
runtime_errors: BumpMap::new_in(arena),
externals_we_need: BumpMap::new_in(arena),
needed_symbol_specializations: BumpMap::new_in(arena),
}
}
}
@ -4081,7 +4028,7 @@ pub fn with_hole<'a>(
}
CopyExisting(index) => {
let record_needs_specialization =
procs.partial_exprs.contains(structure);
procs.ability_member_aliases.get(structure).is_some();
let specialized_structure_sym = if record_needs_specialization {
// We need to specialize the record now; create a new one for it.
// TODO: reuse this symbol for all updates
@ -4308,70 +4255,71 @@ pub fn with_hole<'a>(
unreachable!("calling a non-closure layout")
}
},
UnspecializedExpr(symbol) => match procs.partial_exprs.get(symbol).unwrap()
{
&PolymorphicExpr::AbilityMember(member) => {
let proc_name = get_specialization(env, fn_var, member).expect("Recorded as an ability member, but it doesn't have a specialization");
UnspecializedExpr(symbol) => {
match procs.ability_member_aliases.get(symbol).unwrap() {
&AbilityMember(member) => {
let proc_name = get_specialization(env, fn_var, member).expect("Recorded as an ability member, but it doesn't have a specialization");
// a call by a known name
return call_by_name(
env,
procs,
fn_var,
proc_name,
loc_args,
layout_cache,
assigned,
hole,
);
// a call by a known name
return call_by_name(
env,
procs,
fn_var,
proc_name,
loc_args,
layout_cache,
assigned,
hole,
);
} // TODO(POLYEXPR)
// PolymorphicExpr::Expr(lambda_expr, lambda_expr_var) => {
// match full_layout {
// RawFunctionLayout::Function(
// arg_layouts,
// lambda_set,
// ret_layout,
// ) => {
// let closure_data_symbol = env.unique_symbol();
// result = match_on_lambda_set(
// env,
// lambda_set,
// closure_data_symbol,
// arg_symbols,
// arg_layouts,
// ret_layout,
// assigned,
// hole,
// );
// let snapshot = env.subs.snapshot();
// let cache_snapshot = layout_cache.snapshot();
// let _unified = roc_unify::unify::unify(
// env.subs,
// fn_var,
// *lambda_expr_var,
// roc_unify::unify::Mode::EQ,
// );
// result = with_hole(
// env,
// lambda_expr.clone(),
// fn_var,
// procs,
// layout_cache,
// closure_data_symbol,
// env.arena.alloc(result),
// );
// env.subs.rollback_to(snapshot);
// layout_cache.rollback_to(cache_snapshot);
// }
// RawFunctionLayout::ZeroArgumentThunk(_) => {
// unreachable!("calling a non-closure layout")
// }
// }
// }
}
PolymorphicExpr::Expr(lambda_expr, lambda_expr_var) => {
match full_layout {
RawFunctionLayout::Function(
arg_layouts,
lambda_set,
ret_layout,
) => {
let closure_data_symbol = env.unique_symbol();
result = match_on_lambda_set(
env,
lambda_set,
closure_data_symbol,
arg_symbols,
arg_layouts,
ret_layout,
assigned,
hole,
);
let snapshot = env.subs.snapshot();
let cache_snapshot = layout_cache.snapshot();
let _unified = roc_unify::unify::unify(
env.subs,
fn_var,
*lambda_expr_var,
roc_unify::unify::Mode::EQ,
);
result = with_hole(
env,
lambda_expr.clone(),
fn_var,
procs,
layout_cache,
closure_data_symbol,
env.arena.alloc(result),
);
env.subs.rollback_to(snapshot);
layout_cache.rollback_to(cache_snapshot);
}
RawFunctionLayout::ZeroArgumentThunk(_) => {
unreachable!("calling a non-closure layout")
}
}
}
},
}
NotASymbol => {
// the expression is not a symbol. That means it's an expression
// evaluating to a function value.
@ -4746,8 +4694,10 @@ fn get_specialization<'a>(
#[allow(clippy::too_many_arguments)]
fn construct_closure_data<'a, I>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
// TODO(POLYEXPR): remove?
_procs: &mut Procs<'a>,
// TODO(POLYEXPR): remove?
_layout_cache: &mut LayoutCache<'a>,
lambda_set: LambdaSet<'a>,
name: Symbol,
symbols: I,
@ -4764,7 +4714,7 @@ where
// arguments with a polymorphic type that we have to deal with
let mut polymorphic_arguments = Vec::new_in(env.arena);
let mut result = match lambda_set.layout_for_member(name) {
let result = match lambda_set.layout_for_member(name) {
ClosureRepresentation::Union {
tag_id,
alphabetic_order_fields: field_layouts,
@ -4775,9 +4725,9 @@ where
// them ordered by their alignment requirements
let mut combined = Vec::with_capacity_in(symbols.len(), env.arena);
for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) {
if procs.partial_exprs.contains(*symbol) {
polymorphic_arguments.push((*symbol, *variable));
}
// if procs.partial_exprs.contains(*symbol) {
// polymorphic_arguments.push((*symbol, *variable));
// }
combined.push((*symbol, layout))
}
@ -4810,9 +4760,9 @@ where
// them ordered by their alignment requirements
let mut combined = Vec::with_capacity_in(symbols.len(), env.arena);
for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) {
if procs.partial_exprs.contains(*symbol) {
polymorphic_arguments.push((*symbol, *variable));
}
// if procs.partial_exprs.contains(*symbol) {
// polymorphic_arguments.push((*symbol, *variable));
// }
combined.push((*symbol, layout))
}
@ -4868,9 +4818,12 @@ where
// TODO: this is not quite right. What we should actually be doing is removing references to
// polymorphic expressions from the captured symbols, and allowing the specializations of those
// symbols to be inlined when specializing the closure body elsewhere.
for (symbol, var) in polymorphic_arguments {
result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol);
}
// TODO(POLYEXPR)
// for &&(symbol, var) in symbols {
// if procs.ability_member_aliases.contains(symbol) {
// result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol);
// }
// }
result
}
@ -5232,7 +5185,7 @@ fn sorted_field_symbols<'a>(
let alignment = layout.alignment_bytes(env.target_info);
let symbol = possible_reuse_symbol(env, procs, &arg.value);
let symbol = possible_reuse_symbol_or_spec(env, procs, layout_cache, &arg.value, var);
field_symbols_temp.push((alignment, symbol, ((var, arg), &*env.arena.alloc(symbol))));
}
field_symbols_temp.sort_by(|a, b| b.0.cmp(&a.0));
@ -5366,32 +5319,6 @@ fn register_capturing_closure<'a>(
}
}
fn is_literal_like(expr: &roc_can::expr::Expr) -> bool {
use roc_can::expr::Expr::*;
matches!(
expr,
Num(..)
| Int(..)
| Float(..)
| List { .. }
| Str(_)
| ZeroArgumentTag { .. }
| Tag { .. }
| Record { .. }
| Call(..)
)
}
/// Decides whether a definition body should be treated as a polymorphic expression.
///
/// The body qualifies when it is literal-like (see `is_literal_like`) and the type behind
/// `expr_var` still contains a flex or rigid type variable.
fn expr_is_polymorphic<'a>(
    env: &mut Env<'a, '_>,
    expr: &roc_can::expr::Expr,
    expr_var: Variable,
) -> bool {
    // TODO: I don't think we need the `is_literal_like` check, but taking it slow for now...
    if !is_literal_like(expr) {
        return false;
    }

    env.subs
        .var_contains_content(expr_var, |content: &Content| {
            matches!(content, Content::FlexVar(_) | Content::RigidVar(_))
        })
}
pub fn from_can<'a>(
env: &mut Env<'a, '_>,
variable: Variable,
@ -5662,38 +5589,91 @@ pub fn from_can<'a>(
return from_can(env, variable, new_outer, procs, layout_cache);
}
ref body if expr_is_polymorphic(env, body, def.expr_var) => {
// This is a pattern like
//
// n = 1
// asU8 n
//
// At the definition site `n = 1` we only know `1` to have the type `[Int *]`,
// which won't be refined until the call `asU8 n`. Add it as a partial expression
// that will be specialized at each concrete usage site.
procs.partial_exprs.insert(
*symbol,
PolymorphicExpr::Expr(def.loc_expr.value, def.expr_var),
);
// TODO(POLYEXPR)
// ref body if expr_is_polymorphic(env, body, def.expr_var) => {
// // This is a pattern like
// //
// // n = 1
// // asU8 n
// //
// // At the definition site `n = 1` we only know `1` to have the type `[Int *]`,
// // which won't be refined until the call `asU8 n`. Add it as a partial expression
// // that will be specialized at each concrete usage site.
// procs.ability_member_aliases.insert(
// *symbol,
// PolymorphicExpr::Expr(def.loc_expr.value, def.expr_var),
// );
let result = from_can(env, variable, cont.value, procs, layout_cache);
// let result = from_can(env, variable, cont.value, procs, layout_cache);
// We won't see this symbol again.
procs.partial_exprs.remove(*symbol);
// // We won't see this symbol again.
// procs.ability_member_aliases.remove(*symbol);
return result;
}
// return result;
// }
_ => {
let rest = from_can(env, variable, cont.value, procs, layout_cache);
return with_hole(
env,
def.loc_expr.value,
def.expr_var,
procs,
layout_cache,
*symbol,
env.arena.alloc(rest),
);
let needs_def_specializations = procs
.needed_symbol_specializations
.keys()
.find(|(s, _)| s == symbol)
.is_some();
if !needs_def_specializations {
return with_hole(
env,
def.loc_expr.value,
def.expr_var,
procs,
layout_cache,
*symbol,
env.arena.alloc(rest),
);
}
// We do need specializations
let mut stmt = rest;
let needed_specializations = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == symbol)
.collect::<std::vec::Vec<_>>();
for ((_, wanted_layout), (var, specialized_symbol)) in
needed_specializations
{
// let res =
// roc_unify::unify::unify(env.subs, var, def.expr_var, Mode::EQ);
let content = env.subs.get_content_without_compacting(def.expr_var);
let c = roc_types::subs::SubsFmtContent(content, env.subs);
let content2 = env.subs.get_content_without_compacting(var);
let c2 = roc_types::subs::SubsFmtContent(content2, env.subs);
let layout = layout_cache
.from_var(env.arena, def.expr_var, env.subs)
.unwrap();
dbg!(
specialized_symbol,
c,
c2,
layout,
wanted_layout,
var,
def.expr_var,
);
stmt = with_hole(
env,
def.loc_expr.value.clone(),
// def.expr_var,
var,
procs,
layout_cache,
specialized_symbol,
env.arena.alloc(stmt),
);
}
return stmt;
}
}
}
@ -6335,19 +6315,22 @@ fn store_pattern_help<'a>(
match can_pat {
Identifier(symbol) => {
if let Some(&PolymorphicExpr::Expr(_, var)) = procs.partial_exprs.get(outer_symbol) {
// It might be the case that symbol we're storing hasn't been reified to a value
// yet, if it's polymorphic. Do that now.
stmt = specialize_symbol(
env,
procs,
layout_cache,
Some(var),
*symbol,
stmt,
outer_symbol,
);
}
// TODO(POLYEXPR)
// if let Some(&PolymorphicExpr::Expr(_, var)) =
// procs.ability_member_aliases.get(outer_symbol)
// {
// // It might be the case that symbol we're storing hasn't been reified to a value
// // yet, if it's polymorphic. Do that now.
// stmt = specialize_symbol(
// env,
// procs,
// layout_cache,
// Some(var),
// *symbol,
// stmt,
// outer_symbol,
// );
// }
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
}
@ -6712,7 +6695,7 @@ fn can_reuse_symbol<'a>(
Imported(symbol)
} else if procs.partial_procs.contains_key(symbol) {
LocalFunction(symbol)
} else if procs.partial_exprs.contains(symbol) {
} else if procs.ability_member_aliases.get(symbol).is_some() {
UnspecializedExpr(symbol)
} else {
Value(symbol)
@ -6733,6 +6716,39 @@ fn possible_reuse_symbol<'a>(
}
}
// TODO(POLYEXPR): unify with possible_reuse_symbol
/// Chooses the symbol an argument expression should be bound to, recording a pending
/// specialization when the expression refers to an already-defined value symbol.
///
/// * If `expr` is a plain value symbol, the pair `(symbol, layout of var)` is looked up in
///   `procs.needed_symbol_specializations`. A missing entry is created with a deep copy of
///   `var` and a fresh unique symbol. The specialized symbol stored in the entry is returned,
///   so repeated requests for the same symbol at the same layout share one symbol.
/// * For every other kind of expression a fresh unique symbol is returned.
///
/// # Panics
///
/// Panics if no layout can be computed for `var`.
fn possible_reuse_symbol_or_spec<'a>(
    env: &mut Env<'a, '_>,
    procs: &mut Procs<'a>,
    layout_cache: &mut LayoutCache<'a>,
    expr: &roc_can::expr::Expr,
    var: Variable,
) -> Symbol {
    match can_reuse_symbol(env, procs, expr) {
        ReuseSymbol::Value(symbol) => {
            let wanted_layout = layout_cache.from_var(env.arena, var, env.subs).unwrap();

            // Copy the wanted type into a fresh variable so the recorded specialization is
            // independent of later changes to `var`.
            // NOTE(review): this clones the whole `Subs` for every value argument, which is
            // presumably expensive - consider copying only when the entry is vacant.
            let mut fake_subs = env.subs.clone();
            let new_var = roc_types::subs::deep_copy_var_to(&mut fake_subs, env.subs, var);

            let (_, specialized_symbol) = procs
                .needed_symbol_specializations
                .entry((symbol, wanted_layout))
                .or_insert_with(|| (new_var, env.unique_symbol()));

            *specialized_symbol
        }
        _ => env.unique_symbol(),
    }
}
fn handle_variable_aliasing<'a, BuildRest>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
@ -6747,15 +6763,15 @@ where
{
if env.abilities_store.is_ability_member_name(right) {
procs
.partial_exprs
.insert(left, PolymorphicExpr::AbilityMember(right));
.ability_member_aliases
.insert(left, AbilityMember(right));
return build_rest(env, procs, layout_cache);
}
if procs.partial_exprs.contains(right) {
if let Some(&ability_member) = procs.ability_member_aliases.get(right) {
// If `right` links to a partial expression, make sure we link `left` to it as well, so
// that usages of it will be specialized when building the rest of the program.
procs.partial_exprs.insert_alias(left, right);
procs.ability_member_aliases.insert(left, ability_member);
return build_rest(env, procs, layout_cache);
}
@ -6790,7 +6806,28 @@ where
} else {
// This should be a fully specialized value. Replace the alias with the original symbol.
let mut result = build_rest(env, procs, layout_cache);
// We need to lift all specializations of "left" to be specializations of "right".
let to_update = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == &left)
.collect::<std::vec::Vec<_>>();
let mut scratchpad_update_specializations = std::vec::Vec::new();
for ((_, layout), (specialized_var, specialized_sym)) in to_update.into_iter() {
let old_specialized_sym = procs
.needed_symbol_specializations
.insert((right, layout), (specialized_var, specialized_sym));
if let Some((_, old_specialized_sym)) = old_specialized_sym {
scratchpad_update_specializations.push((old_specialized_sym, specialized_sym));
}
}
substitute_in_exprs(env.arena, &mut result, left, right);
for (old_specialized_sym, specialized_sym) in scratchpad_update_specializations.into_iter()
{
substitute_in_exprs(env.arena, &mut result, old_specialized_sym, specialized_sym);
}
result
}
}
@ -6829,34 +6866,36 @@ fn specialize_symbol<'a>(
result: Stmt<'a>,
original: Symbol,
) -> Stmt<'a> {
if let Some(PolymorphicExpr::Expr(expr, expr_var)) = procs.partial_exprs.get(original) {
// Specialize the expression type now, based off the `arg_var` we've been given.
// TODO: cache the specialized result
let snapshot = env.subs.snapshot();
let cache_snapshot = layout_cache.snapshot();
let _unified = roc_unify::unify::unify(
env.subs,
arg_var.unwrap(),
*expr_var,
roc_unify::unify::Mode::EQ,
);
// TODO(POLYEXPR)
// if let Some(PolymorphicExpr::Expr(expr, expr_var)) = procs.ability_member_aliases.get(original)
// {
// // Specialize the expression type now, based off the `arg_var` we've been given.
// // TODO: cache the specialized result
// let snapshot = env.subs.snapshot();
// let cache_snapshot = layout_cache.snapshot();
// let _unified = roc_unify::unify::unify(
// env.subs,
// arg_var.unwrap(),
// *expr_var,
// roc_unify::unify::Mode::EQ,
// );
let result = with_hole(
env,
expr.clone(),
*expr_var,
procs,
layout_cache,
symbol,
env.arena.alloc(result),
);
// let result = with_hole(
// env,
// expr.clone(),
// *expr_var,
// procs,
// layout_cache,
// symbol,
// env.arena.alloc(result),
// );
// Restore the prior state so as not to interfere with future specializations.
env.subs.rollback_to(snapshot);
layout_cache.rollback_to(cache_snapshot);
// // Restore the prior state so as not to interfere with future specializations.
// env.subs.rollback_to(snapshot);
// layout_cache.rollback_to(cache_snapshot);
return result;
}
// return result;
// }
match procs.get_partial_proc(original) {
None => {
@ -7040,8 +7079,17 @@ fn assign_to_symbol<'a>(
original,
)
}
Value(_) => {
// symbol is already defined; nothing else to do here
Value(_symbol) => {
//let wanted_layout = layout_cache.from_var(env.arena, arg_var, env.subs).unwrap();
//let (_, specialized_symbol) = procs
// .needed_symbol_specializations
// .entry((symbol, wanted_layout))
// .or_insert_with(|| (arg_var, env.unique_symbol()));
//dbg!(symbol, wanted_layout);
//let mut result = result;
//substitute_in_exprs(env.arena, &mut result, symbol, *specialized_symbol);
result
}
NotASymbol => with_hole(
@ -7188,8 +7236,14 @@ fn call_by_name<'a>(
let arena = env.arena;
let arg_symbols = Vec::from_iter_in(
loc_args.iter().map(|(_, arg_expr)| {
possible_reuse_symbol(env, procs, &arg_expr.value)
loc_args.iter().map(|(arg_var, arg_expr)| {
possible_reuse_symbol_or_spec(
env,
procs,
layout_cache,
&arg_expr.value,
*arg_var,
)
}),
arena,
)
@ -7280,11 +7334,9 @@ fn call_by_name_help<'a>(
// the arguments given to the function, stored in symbols
let mut field_symbols = Vec::with_capacity_in(loc_args.len(), arena);
field_symbols.extend(
loc_args
.iter()
.map(|(_, arg_expr)| possible_reuse_symbol(env, procs, &arg_expr.value)),
);
field_symbols.extend(loc_args.iter().map(|(arg_var, arg_expr)| {
possible_reuse_symbol_or_spec(env, procs, layout_cache, &arg_expr.value, *arg_var)
}));
// If required, add an extra argument to the layout that is the captured environment
// afterwards, we MUST make sure the number of arguments in the layout matches the