Merge pull request #4912 from roc-lang/remove-polymorphic-expression-compilation

Rip out polymorphic expression compilation
This commit is contained in:
Folkert de Vries 2023-01-24 21:35:08 +01:00 committed by GitHub
commit 8e5efe67b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 225 additions and 388 deletions

View file

@ -872,14 +872,16 @@ impl UseDepth {
}
}
/// When walking a function body, we may encounter specialized usages of polymorphic symbols. For
/// example
type NumberSpecializations<'a> = VecMap<InLayout<'a>, (Symbol, UseDepth)>;
/// When walking a function body, we may encounter specialized usages of polymorphic number symbols.
/// For example
///
/// myTag = A
/// use1 : [A, B]
/// use1 = myTag
/// use2 : [A, B, C]
/// use2 = myTag
/// n = 1
/// use1 : U8
/// use1 = 1
/// use2 : Nat
/// use2 = 2
///
/// We keep track of the specializations of `myTag` and create fresh symbols when there is more
/// than one, so that a unique def can be created for each.
@ -890,55 +892,39 @@ struct SymbolSpecializations<'a>(
// 2. the number of specializations of a symbol in a def is even smaller (almost always only one)
// So, a linear VecMap is preferrable. Use a two-layered one to make (1) extraction of defs easy
// and (2) reads of a certain symbol be determined by its first occurrence, not its last.
VecMap<Symbol, VecMap<SpecializationMark<'a>, (Variable, Symbol, UseDepth)>>,
VecMap<Symbol, NumberSpecializations<'a>>,
);
impl<'a> SymbolSpecializations<'a> {
/// Inserts a known specialization for a symbol. Returns the overwritten specialization, if any.
pub fn get_or_insert_known(
&mut self,
symbol: Symbol,
mark: SpecializationMark<'a>,
specialization_var: Variable,
specialization_symbol: Symbol,
deepest_use: UseDepth,
) -> Option<(Variable, Symbol, UseDepth)> {
self.0.get_or_insert(symbol, Default::default).insert(
mark,
(specialization_var, specialization_symbol, deepest_use),
)
/// Mark a let-generalized symbol eligible for specialization.
/// Only those bound to number literals can be compiled polymorphically.
fn mark_eligible(&mut self, symbol: Symbol) {
let _old = self.0.insert(symbol, VecMap::with_capacity(1));
debug_assert!(
_old.is_none(),
"overwriting specializations for {:?}",
symbol
);
}
/// Removes all specializations for a symbol, returning the type and symbol of each specialization.
pub fn remove(
&mut self,
symbol: Symbol,
) -> impl ExactSizeIterator<Item = (SpecializationMark<'a>, (Variable, Symbol, UseDepth))> {
fn remove(&mut self, symbol: Symbol) -> Option<NumberSpecializations<'a>> {
self.0
.remove(&symbol)
.map(|(_, specializations)| specializations)
.unwrap_or_default()
.into_iter()
}
/// Expects and removes at most a single specialization symbol for the given requested symbol.
/// A symbol may have no specializations if it is never referenced in a body, so it is possible
/// for this to return None.
pub fn remove_single(&mut self, symbol: Symbol) -> Option<Symbol> {
let mut specializations = self.remove(symbol);
debug_assert!(
specializations.len() < 2,
"Symbol {:?} has multiple specializations",
symbol
);
specializations.next().map(|(_, (_, symbol, _))| symbol)
}
pub fn is_empty(&self) -> bool {
fn is_empty(&self) -> bool {
self.0.is_empty()
}
fn maybe_get_specialized(&self, symbol: Symbol, layout: InLayout) -> Symbol {
self.0
.get(&symbol)
.and_then(|m| m.get(&layout))
.map(|x| x.0)
.unwrap_or(symbol)
}
}
#[derive(Clone, Debug, Default)]
@ -1115,7 +1101,7 @@ impl<'a> Procs<'a> {
// if we've already specialized this one, no further work is needed.
if !already_specialized {
if self.is_module_thunk(name.name()) {
debug_assert!(layout.arguments.is_empty());
debug_assert!(layout.arguments.is_empty(), "{:?}", name);
}
let needs_suspended_specialization =
@ -1312,6 +1298,14 @@ impl<'a> Procs<'a> {
symbol: Symbol,
specialization_var: Variable,
) -> Symbol {
let symbol_specializations = match self.symbol_specializations.0.get_mut(&symbol) {
Some(m) => m,
None => {
// Not eligible for multiple specializations
return symbol;
}
};
let arena = env.arena;
let subs: &Subs = env.subs;
@ -1322,32 +1316,6 @@ impl<'a> Procs<'a> {
Err(_) => return symbol,
};
let is_closure = matches!(
subs.get_content_without_compacting(specialization_var),
Content::Structure(FlatType::Func(..))
);
let function_mark = if is_closure {
let fn_layout = match layout_cache.raw_from_var(arena, specialization_var, subs) {
Ok(layout) => layout,
// This can happen when the def symbol has a type error. In such cases just use the
// def symbol, which is erroring.
Err(_) => return symbol,
};
Some(fn_layout)
} else {
None
};
let specialization_mark = SpecializationMark {
layout,
function_mark,
};
let symbol_specializations = self
.symbol_specializations
.0
.get_or_insert(symbol, Default::default);
// For the first specialization, always reuse the current symbol. The vast majority of defs
// only have one instance type, so this preserves readability of the IR.
// TODO: turn me off and see what breaks.
@ -1362,10 +1330,8 @@ impl<'a> Procs<'a> {
};
let current_use = self.specialization_stack.current_use_depth();
let (_var, specialized_symbol, deepest_use) = symbol_specializations
.get_or_insert(specialization_mark, || {
(specialization_var, make_specialized_symbol(), current_use)
});
let (specialized_symbol, deepest_use) = symbol_specializations
.get_or_insert(layout, || (make_specialized_symbol(), current_use));
if deepest_use.is_nested_use_in(&current_use) {
*deepest_use = current_use;
@ -1378,12 +1344,12 @@ impl<'a> Procs<'a> {
pub fn get_symbol_specializations_used_in_body(
&self,
symbol: Symbol,
) -> Option<impl Iterator<Item = (Variable, Symbol)> + '_> {
) -> Option<impl Iterator<Item = Symbol> + '_> {
let this_use = self.specialization_stack.current_use_depth();
self.symbol_specializations.0.get(&symbol).map(move |l| {
l.iter().filter_map(move |(_, (var, sym, deepest_use))| {
l.iter().filter_map(move |(_, (sym, deepest_use))| {
if deepest_use.is_nested_use_in(&this_use) {
Some((*var, *sym))
Some(*sym)
} else {
None
}
@ -2423,6 +2389,14 @@ fn from_can_let<'a>(
lower_rest!(variable, cont.value)
}
Accessor(accessor_data) => {
let fresh_record_symbol = env.unique_symbol();
let closure_data = accessor_data.to_closure_data(fresh_record_symbol);
debug_assert_eq!(*symbol, closure_data.name);
register_noncapturing_closure(env, procs, *symbol, closure_data);
lower_rest!(variable, cont.value)
}
Var(original, _) | AbilityMember(original, _, _)
if procs.get_partial_proc(original).is_none() =>
{
@ -2633,96 +2607,56 @@ fn from_can_let<'a>(
lower_rest!(variable, new_outer)
}
e @ (Int(..) | Float(..) | Num(..)) => {
let (str, val): (Box<str>, IntOrFloatValue) = match e {
Int(_, _, str, val, _) => (str, IntOrFloatValue::Int(val)),
Float(_, _, str, val, _) => (str, IntOrFloatValue::Float(val)),
Num(_, str, val, _) => (str, IntOrFloatValue::Int(val)),
_ => unreachable!(),
};
procs.symbol_specializations.mark_eligible(*symbol);
let mut stmt = lower_rest!(variable, cont.value);
let needed_specializations = procs.symbol_specializations.remove(*symbol).unwrap();
let zero_specialization = if needed_specializations.is_empty() {
let layout = layout_cache
.from_var(env.arena, def.expr_var, env.subs)
.unwrap();
Some((layout, *symbol))
} else {
None
};
// Layer on the specialized numbers
for (layout, sym) in needed_specializations
.into_iter()
.map(|(lay, (sym, _))| (lay, sym))
.chain(zero_specialization)
{
let literal = make_num_literal(&layout_cache.interner, layout, &str, val);
stmt = Stmt::Let(
sym,
Expr::Literal(literal.to_expr_literal()),
layout,
env.arena.alloc(stmt),
);
}
stmt
}
_ => {
let rest = lower_rest!(variable, cont.value);
// Remove all the requested symbol specializations now, since this is the
// def site and hence we won't need them any higher up.
let mut needed_specializations = procs.symbol_specializations.remove(*symbol);
match needed_specializations.len() {
0 => {
// We don't need any specializations, that means this symbol is never
// referenced.
with_hole(
env,
def.loc_expr.value,
def.expr_var,
procs,
layout_cache,
*symbol,
env.arena.alloc(rest),
)
}
// We do need specializations
1 => {
let (_specialization_mark, (var, specialized_symbol, _deepest_use)) =
needed_specializations.next().unwrap();
// Make sure rigid variables in the annotation are converted to flex variables.
instantiate_rigids(env.subs, def.expr_var);
// Unify the expr_var with the requested specialization once.
let _res = env.unify(
procs.externals_we_need.values_mut(),
layout_cache,
var,
def.expr_var,
);
with_hole(
env,
def.loc_expr.value,
def.expr_var,
procs,
layout_cache,
specialized_symbol,
env.arena.alloc(rest),
)
}
_n => {
let mut stmt = rest;
// Make sure rigid variables in the annotation are converted to flex variables.
instantiate_rigids(env.subs, def.expr_var);
// Need to eat the cost and create a specialized version of the body for
// each specialization.
for (_specialization_mark, (var, specialized_symbol, _deepest_use)) in
needed_specializations
{
use roc_can::copy::deep_copy_type_vars_into_expr;
let (new_def_expr_var, specialized_expr) = deep_copy_type_vars_into_expr(
env.subs,
def.expr_var,
&def.loc_expr.value,
)
.expect(
"expr marked as having specializations, but it has no type variables!",
);
let _res = env.unify(
procs.externals_we_need.values_mut(),
layout_cache,
var,
new_def_expr_var,
);
stmt = with_hole(
env,
specialized_expr,
new_def_expr_var,
procs,
layout_cache,
specialized_symbol,
env.arena.alloc(stmt),
);
}
stmt
}
}
with_hole(
env,
def.loc_expr.value,
def.expr_var,
procs,
layout_cache,
*symbol,
env.arena.alloc(rest),
)
}
};
}
@ -2739,23 +2673,8 @@ fn from_can_let<'a>(
// layer on any default record fields
for (symbol, variable, expr) in assignments {
let specialization_symbol = procs
.symbol_specializations
.remove_single(symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(symbol);
let hole = env.arena.alloc(stmt);
stmt = with_hole(
env,
expr,
variable,
procs,
layout_cache,
specialization_symbol,
hole,
);
stmt = with_hole(env, expr, variable, procs, layout_cache, symbol, hole);
}
match def.loc_expr.value {
@ -3500,8 +3419,7 @@ fn specialize_proc_help<'a>(
match specs_used_in_body {
Some(mut specs) => {
let spec_symbol =
specs.next().map(|(_, sym)| sym).unwrap_or(symbol);
let spec_symbol = specs.next().unwrap_or(symbol);
if specs.next().is_some() {
internal_error!(
"polymorphic symbol captures not supported yet"
@ -3653,12 +3571,11 @@ fn specialize_proc_help<'a>(
_ => unreachable!("to closure or not to closure?"),
}
proc_args.iter_mut().for_each(|(_layout, symbol)| {
proc_args.iter_mut().for_each(|(layout, symbol)| {
// Grab the specialization symbol, if it exists.
*symbol = procs
.symbol_specializations
.remove_single(*symbol)
.unwrap_or(*symbol);
.maybe_get_specialized(*symbol, *layout)
});
let closure_data_layout = match opt_closure_layout {
@ -7413,16 +7330,7 @@ fn store_pattern_help<'a>(
match can_pat {
Identifier(symbol) => {
// An identifier in a pattern can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(env.arena, &mut stmt, specialization_symbol, outer_symbol);
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
}
Underscore => {
// do nothing
@ -7437,16 +7345,7 @@ fn store_pattern_help<'a>(
StorePattern::NotProductive(stmt) => stmt,
};
// An identifier in a pattern can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(env.arena, &mut stmt, specialization_symbol, outer_symbol);
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
return StorePattern::Productive(stmt);
}
@ -7528,19 +7427,7 @@ fn store_pattern_help<'a>(
for destruct in destructs {
match &destruct.typ {
DestructType::Required(symbol) => {
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(
env.arena,
&mut stmt,
specialization_symbol,
outer_symbol,
);
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
}
DestructType::Guard(guard_pattern) => {
return store_pattern_help(
@ -7708,15 +7595,9 @@ fn store_list_pattern<'a>(
Identifier(symbol) => {
let (load, needed_stores) = compute_element_load(env);
// Pattern can define only one specialization
let symbol = procs
.symbol_specializations
.remove_single(*symbol)
.unwrap_or(*symbol);
// store immediately in the given symbol
(
Stmt::Let(symbol, load, element_layout, env.arena.alloc(stmt)),
Stmt::Let(*symbol, load, element_layout, env.arena.alloc(stmt)),
needed_stores,
)
}
@ -7803,14 +7684,8 @@ fn store_tag_pattern<'a>(
match argument {
Identifier(symbol) => {
// Pattern can define only one specialization
let symbol = procs
.symbol_specializations
.remove_single(*symbol)
.unwrap_or(*symbol);
// store immediately in the given symbol
stmt = Stmt::Let(symbol, load, arg_layout, env.arena.alloc(stmt));
stmt = Stmt::Let(*symbol, load, arg_layout, env.arena.alloc(stmt));
is_productive = true;
}
Underscore => {
@ -7885,20 +7760,7 @@ fn store_newtype_pattern<'a>(
match argument {
Identifier(symbol) => {
// store immediately in the given symbol, removing it specialization if it had any
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
arg_layout,
env.arena.alloc(stmt),
);
stmt = Stmt::Let(*symbol, load, arg_layout, env.arena.alloc(stmt));
is_productive = true;
}
Underscore => {
@ -7960,37 +7822,11 @@ fn store_record_destruct<'a>(
match &destruct.typ {
DestructType::Required(symbol) => {
// A destructure can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt));
}
DestructType::Guard(guard_pattern) => match &guard_pattern {
Identifier(symbol) => {
let specialization_symbol = procs
.symbol_specializations
.remove_single(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt));
}
Underscore => {
// important that this is special-cased to do nothing: mono record patterns will extract all the
@ -8137,54 +7973,17 @@ where
// See my git blame for details.
debug_assert!(!procs.partial_procs.contains_key(right));
// Otherwise we're dealing with an alias whose usages will tell us what specializations we
// need. So let's figure those out first.
let result = build_rest(env, procs, layout_cache);
// The specializations we wanted of the symbol on the LHS of this alias.
let needed_specializations_of_left = procs.symbol_specializations.remove(left);
if procs.is_imported_module_thunk(right) {
// if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer.
let mut result = result;
add_needed_external(procs, env, variable, LambdaName::no_niche(right));
let no_specializations_needed = needed_specializations_of_left.len() == 0;
let needed_specializations_of_left = needed_specializations_of_left
.map(|(_, spec)| Some(spec))
// HACK: sometimes specializations can be lost, for example for `x` in
// x = Bool.true
// p = \_ -> x == 1
// that's because when specializing `p`, we collect specializations for `x`, but then
// drop all of them when leaving the body of `p`, because `x` is an argument of `p` in
// such a case.
// So, if we have no recorded specializations, suppose we are in a case like this, and
// generate the default implementation.
//
// TODO: we should fix this properly. I think the way to do it is to only have proc
// specialization only drop specializations of non-captured symbols. That's because
// captured symbols can only ever be specialized outside the closure.
// After that is done, remove this hack.
.chain(if no_specializations_needed {
[Some((
variable,
left,
procs.specialization_stack.current_use_depth(),
))]
} else {
[None]
})
.flatten();
let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
let layout = return_on_layout_error!(env, res_layout, "handle_variable_aliasing");
for (variable, left, _deepest_use) in needed_specializations_of_left {
add_needed_external(procs, env, variable, LambdaName::no_niche(right));
let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
let layout = return_on_layout_error!(env, res_layout, "handle_variable_aliasing");
result = force_thunk(env, right, layout, left, env.arena.alloc(result));
}
result
force_thunk(env, right, layout, left, env.arena.alloc(result))
} else if env.is_imported_symbol(right) {
// if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer.
@ -8193,41 +7992,8 @@ where
// then we must construct its closure; since imported symbols have no closure, we use the empty struct
let_empty_struct(left, env.arena.alloc(result))
} else {
// Otherwise, we are referencing a non-proc value.
// We need to lift all specializations of "left" to be specializations of "right".
let mut scratchpad_update_specializations = std::vec::Vec::new();
let left_had_specialization_symbols = needed_specializations_of_left.len() > 0;
for (specialization_mark, (specialized_var, specialized_sym, deepest_use)) in
needed_specializations_of_left
{
let old_specialized_sym = procs.symbol_specializations.get_or_insert_known(
right,
specialization_mark,
specialized_var,
specialized_sym,
deepest_use,
);
if let Some((_, old_specialized_sym, _)) = old_specialized_sym {
scratchpad_update_specializations.push((old_specialized_sym, specialized_sym));
}
}
let mut result = result;
if left_had_specialization_symbols {
// If the symbol is specialized, only the specializations need to be updated.
for (old_specialized_sym, specialized_sym) in
scratchpad_update_specializations.into_iter()
{
substitute_in_exprs(env.arena, &mut result, old_specialized_sym, specialized_sym);
}
} else {
substitute_in_exprs(env.arena, &mut result, left, right);
}
substitute_in_exprs(env.arena, &mut result, left, right);
result
}
}
@ -10165,6 +9931,7 @@ fn from_can_record_destruct<'a>(
})
}
#[derive(Debug, Clone, Copy)]
enum IntOrFloatValue {
Int(IntValue),
Float(f64),