Bugfix handle more specialization instances

This commit is contained in:
Ayaz Hafiz 2022-05-05 11:12:50 -04:00
parent de924de266
commit 19e8b37402
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58

View file

@ -770,6 +770,26 @@ impl<'a> Procs<'a> {
needed_symbol_specializations: BumpMap::new_in(arena), needed_symbol_specializations: BumpMap::new_in(arena),
} }
} }
/// Expects and removes a single specialization symbol for the given requested symbol.
fn remove_single_symbol_specialization(&mut self, symbol: Symbol) -> Option<Symbol> {
let mut specialized_symbols = self
.needed_symbol_specializations
.drain_filter(|(sym, _), _| sym == &symbol);
let specialization_symbol = specialized_symbols
.next()
.map(|(_, (_, specialized_symbol))| specialized_symbol);
debug_assert_eq!(
specialized_symbols.count(),
0,
"Symbol {:?} has multiple specializations",
symbol
);
specialization_symbol
}
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -2468,11 +2488,11 @@ fn specialize_external<'a>(
// An argument from the closure list may have taken on a specialized symbol // An argument from the closure list may have taken on a specialized symbol
// name during the evaluation of the def body. If this is the case, load the // name during the evaluation of the def body. If this is the case, load the
// specialized name rather than the original captured name! // specialized name rather than the original captured name!
let get_specialized_name = |symbol, layout| { let mut get_specialized_name = |symbol, layout| {
procs procs
.needed_symbol_specializations .needed_symbol_specializations
.get(&(symbol, layout)) .remove(&(symbol, layout))
.map(|(_, specialized)| *specialized) .map(|(_, specialized)| specialized)
.unwrap_or(symbol) .unwrap_or(symbol)
}; };
@ -3304,7 +3324,7 @@ pub fn with_hole<'a>(
} else { } else {
// this may be a destructure pattern // this may be a destructure pattern
let (mono_pattern, assignments) = let (mono_pattern, assignments) =
match from_can_pattern(env, layout_cache, &def.loc_pattern.value) { match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) {
Ok(v) => v, Ok(v) => v,
Err(_runtime_error) => { Err(_runtime_error) => {
// todo // todo
@ -5492,6 +5512,7 @@ pub fn from_can<'a>(
} }
LetNonRec(def, cont, outer_annotation) => { LetNonRec(def, cont, outer_annotation) => {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
// dbg!(symbol, &def.loc_expr.value);
match def.loc_expr.value { match def.loc_expr.value {
roc_can::expr::Expr::Closure(closure_data) => { roc_can::expr::Expr::Closure(closure_data) => {
register_capturing_closure(env, procs, layout_cache, *symbol, closure_data); register_capturing_closure(env, procs, layout_cache, *symbol, closure_data);
@ -5706,7 +5727,7 @@ pub fn from_can<'a>(
// this may be a destructure pattern // this may be a destructure pattern
let (mono_pattern, assignments) = let (mono_pattern, assignments) =
match from_can_pattern(env, layout_cache, &def.loc_pattern.value) { match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) {
Ok(v) => v, Ok(v) => v,
Err(_) => todo!(), Err(_) => todo!(),
}; };
@ -5737,8 +5758,22 @@ pub fn from_can<'a>(
// layer on any default record fields // layer on any default record fields
for (symbol, variable, expr) in assignments { for (symbol, variable, expr) in assignments {
let specialization_symbol = procs
.remove_single_symbol_specialization(symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(symbol);
let hole = env.arena.alloc(stmt); let hole = env.arena.alloc(stmt);
stmt = with_hole(env, expr, variable, procs, layout_cache, symbol, hole); stmt = with_hole(
env,
expr,
variable,
procs,
layout_cache,
specialization_symbol,
hole,
);
} }
if let roc_can::expr::Expr::Var(outer_symbol) = def.loc_expr.value { if let roc_can::expr::Expr::Var(outer_symbol) = def.loc_expr.value {
@ -5772,6 +5807,7 @@ pub fn from_can<'a>(
fn to_opt_branches<'a>( fn to_opt_branches<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
branches: std::vec::Vec<roc_can::expr::WhenBranch>, branches: std::vec::Vec<roc_can::expr::WhenBranch>,
exhaustive_mark: ExhaustiveMark, exhaustive_mark: ExhaustiveMark,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
@ -5798,7 +5834,7 @@ fn to_opt_branches<'a>(
} }
for loc_pattern in when_branch.patterns { for loc_pattern in when_branch.patterns {
match from_can_pattern(env, layout_cache, &loc_pattern.value) { match from_can_pattern(env, procs, layout_cache, &loc_pattern.value) {
Ok((mono_pattern, assignments)) => { Ok((mono_pattern, assignments)) => {
loc_branches.push(( loc_branches.push((
Loc::at(loc_pattern.region, mono_pattern.clone()), Loc::at(loc_pattern.region, mono_pattern.clone()),
@ -5876,7 +5912,7 @@ fn from_can_when<'a>(
// We can't know what to return! // We can't know what to return!
return Stmt::RuntimeError("Hit a 0-branch when expression"); return Stmt::RuntimeError("Hit a 0-branch when expression");
} }
let opt_branches = to_opt_branches(env, branches, exhaustive_mark, layout_cache); let opt_branches = to_opt_branches(env, procs, branches, exhaustive_mark, layout_cache);
let cond_layout = let cond_layout =
return_on_layout_error!(env, layout_cache.from_var(env.arena, cond_var, env.subs)); return_on_layout_error!(env, layout_cache.from_var(env.arena, cond_var, env.subs));
@ -6341,7 +6377,15 @@ fn store_pattern_help<'a>(
match can_pat { match can_pat {
Identifier(symbol) => { Identifier(symbol) => {
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol); // An identifier in a pattern can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(env.arena, &mut stmt, specialization_symbol, outer_symbol);
} }
Underscore => { Underscore => {
// do nothing // do nothing
@ -6402,7 +6446,18 @@ fn store_pattern_help<'a>(
for destruct in destructs { for destruct in destructs {
match &destruct.typ { match &destruct.typ {
DestructType::Required(symbol) => { DestructType::Required(symbol) => {
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol); let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(
env.arena,
&mut stmt,
specialization_symbol,
outer_symbol,
);
} }
DestructType::Guard(guard_pattern) => { DestructType::Guard(guard_pattern) => {
return store_pattern_help( return store_pattern_help(
@ -6480,10 +6535,11 @@ fn store_tag_pattern<'a>(
match argument { match argument {
Identifier(symbol) => { Identifier(symbol) => {
// TODO: use procs.remove_single_symbol_specialization
let symbol = procs let symbol = procs
.needed_symbol_specializations .needed_symbol_specializations
.get(&(*symbol, arg_layout)) .remove(&(*symbol, arg_layout))
.map(|(_, sym)| *sym) .map(|(_, sym)| sym)
.unwrap_or(*symbol); .unwrap_or(*symbol);
// store immediately in the given symbol // store immediately in the given symbol
@ -6562,8 +6618,19 @@ fn store_newtype_pattern<'a>(
match argument { match argument {
Identifier(symbol) => { Identifier(symbol) => {
// store immediately in the given symbol // store immediately in the given symbol, removing it specialization if it had any
stmt = Stmt::Let(*symbol, load, arg_layout, env.arena.alloc(stmt)); let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
arg_layout,
env.arena.alloc(stmt),
);
is_productive = true; is_productive = true;
} }
Underscore => { Underscore => {
@ -6625,11 +6692,35 @@ fn store_record_destruct<'a>(
match &destruct.typ { match &destruct.typ {
DestructType::Required(symbol) => { DestructType::Required(symbol) => {
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt)); // A destructure can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
} }
DestructType::Guard(guard_pattern) => match &guard_pattern { DestructType::Guard(guard_pattern) => match &guard_pattern {
Identifier(symbol) => { Identifier(symbol) => {
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt)); let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
} }
Underscore => { Underscore => {
// important that this is special-cased to do nothing: mono record patterns will extract all the // important that this is special-cased to do nothing: mono record patterns will extract all the
@ -6816,48 +6907,53 @@ where
return build_rest(env, procs, layout_cache); return build_rest(env, procs, layout_cache);
} }
// Otherwise we're dealing with an alias to something that doesn't need to be specialized, or if procs.partial_procs.contains_key(right) {
// whose usages will already be specialized in the rest of the program. // This is an alias to a function defined in this module.
if procs.is_imported_module_thunk(right) { // Attach the alias, then build the rest of the module, so that we reference and specialize
// the correct proc.
procs.partial_procs.insert_alias(left, right);
return build_rest(env, procs, layout_cache);
}
// Otherwise we're dealing with an alias whose usages will tell us what specializations we
// need. So let's figure those out first.
let result = build_rest(env, procs, layout_cache); let result = build_rest(env, procs, layout_cache);
// The specializations we wanted of the symbol on the LHS of this alias.
let needed_specializations_of_left = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == &left)
.collect::<std::vec::Vec<_>>();
if procs.is_imported_module_thunk(right) {
// if this is an imported symbol, then we must make sure it is // if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer. // specialized, and wrap the original in a function pointer.
let mut result = result;
for (_, (variable, left)) in needed_specializations_of_left.into_iter() {
add_needed_external(procs, env, variable, right); add_needed_external(procs, env, variable, right);
let res_layout = layout_cache.from_var(env.arena, variable, env.subs); let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
let layout = return_on_layout_error!(env, res_layout); let layout = return_on_layout_error!(env, res_layout);
force_thunk(env, right, layout, left, env.arena.alloc(result)) result = force_thunk(env, right, layout, left, env.arena.alloc(result));
}
result
} else if env.is_imported_symbol(right) { } else if env.is_imported_symbol(right) {
let result = build_rest(env, procs, layout_cache);
// if this is an imported symbol, then we must make sure it is // if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer. // specialized, and wrap the original in a function pointer.
add_needed_external(procs, env, variable, right); add_needed_external(procs, env, variable, right);
// then we must construct its closure; since imported symbols have no closure, we use the empty struct // then we must construct its closure; since imported symbols have no closure, we use the empty struct
let_empty_struct(left, env.arena.alloc(result)) let_empty_struct(left, env.arena.alloc(result))
} else if procs.partial_procs.contains_key(right) {
// This is an alias to a function defined in this module.
// Attach the alias, then build the rest of the module, so that we reference and specialize
// the correct proc.
procs.partial_procs.insert_alias(left, right);
build_rest(env, procs, layout_cache)
} else { } else {
// This should be a fully specialized value. Replace the alias with the original symbol.
let mut result = build_rest(env, procs, layout_cache);
// We need to lift all specializations of "left" to be specializations of "right". // We need to lift all specializations of "left" to be specializations of "right".
let to_update = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == &left)
.collect::<std::vec::Vec<_>>();
let mut scratchpad_update_specializations = std::vec::Vec::new(); let mut scratchpad_update_specializations = std::vec::Vec::new();
let left_had_specialization_symbols = !to_update.is_empty(); let left_had_specialization_symbols = !needed_specializations_of_left.is_empty();
for ((_, layout), (specialized_var, specialized_sym)) in to_update.into_iter() { for ((_, layout), (specialized_var, specialized_sym)) in
needed_specializations_of_left.into_iter()
{
let old_specialized_sym = procs let old_specialized_sym = procs
.needed_symbol_specializations .needed_symbol_specializations
.insert((right, layout), (specialized_var, specialized_sym)); .insert((right, layout), (specialized_var, specialized_sym));
@ -6867,6 +6963,7 @@ where
} }
} }
let mut result = result;
if left_had_specialization_symbols { if left_had_specialization_symbols {
// If the symbol is specialized, only the specializations need to be updated. // If the symbol is specialized, only the specializations need to be updated.
for (old_specialized_sym, specialized_sym) in for (old_specialized_sym, specialized_sym) in
@ -7894,6 +7991,7 @@ pub struct WhenBranch<'a> {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn from_can_pattern<'a>( fn from_can_pattern<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
can_pattern: &roc_can::pattern::Pattern, can_pattern: &roc_can::pattern::Pattern,
) -> Result< ) -> Result<
@ -7904,13 +8002,14 @@ fn from_can_pattern<'a>(
RuntimeError, RuntimeError,
> { > {
let mut assignments = Vec::new_in(env.arena); let mut assignments = Vec::new_in(env.arena);
let pattern = from_can_pattern_help(env, layout_cache, can_pattern, &mut assignments)?; let pattern = from_can_pattern_help(env, procs, layout_cache, can_pattern, &mut assignments)?;
Ok((pattern, assignments)) Ok((pattern, assignments))
} }
fn from_can_pattern_help<'a>( fn from_can_pattern_help<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
can_pattern: &roc_can::pattern::Pattern, can_pattern: &roc_can::pattern::Pattern,
assignments: &mut Vec<'a, (Symbol, Variable, roc_can::expr::Expr)>, assignments: &mut Vec<'a, (Symbol, Variable, roc_can::expr::Expr)>,
@ -8105,7 +8204,13 @@ fn from_can_pattern_help<'a>(
let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena);
for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) { for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) {
mono_args.push(( mono_args.push((
from_can_pattern_help(env, layout_cache, &loc_pat.value, assignments)?, from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
)?,
*layout, *layout,
)); ));
} }
@ -8183,6 +8288,7 @@ fn from_can_pattern_help<'a>(
mono_args.push(( mono_args.push((
from_can_pattern_help( from_can_pattern_help(
env, env,
procs,
layout_cache, layout_cache,
&loc_pat.value, &loc_pat.value,
assignments, assignments,
@ -8228,6 +8334,7 @@ fn from_can_pattern_help<'a>(
mono_args.push(( mono_args.push((
from_can_pattern_help( from_can_pattern_help(
env, env,
procs,
layout_cache, layout_cache,
&loc_pat.value, &loc_pat.value,
assignments, assignments,
@ -8271,6 +8378,7 @@ fn from_can_pattern_help<'a>(
mono_args.push(( mono_args.push((
from_can_pattern_help( from_can_pattern_help(
env, env,
procs,
layout_cache, layout_cache,
&loc_pat.value, &loc_pat.value,
assignments, assignments,
@ -8344,6 +8452,7 @@ fn from_can_pattern_help<'a>(
mono_args.push(( mono_args.push((
from_can_pattern_help( from_can_pattern_help(
env, env,
procs,
layout_cache, layout_cache,
&loc_pat.value, &loc_pat.value,
assignments, assignments,
@ -8400,6 +8509,7 @@ fn from_can_pattern_help<'a>(
mono_args.push(( mono_args.push((
from_can_pattern_help( from_can_pattern_help(
env, env,
procs,
layout_cache, layout_cache,
&loc_pat.value, &loc_pat.value,
assignments, assignments,
@ -8430,8 +8540,13 @@ fn from_can_pattern_help<'a>(
let arg_layout = layout_cache let arg_layout = layout_cache
.from_var(env.arena, *arg_var, env.subs) .from_var(env.arena, *arg_var, env.subs)
.unwrap(); .unwrap();
let mono_arg_pattern = let mono_arg_pattern = from_can_pattern_help(
from_can_pattern_help(env, layout_cache, &loc_arg_pattern.value, assignments)?; env,
procs,
layout_cache,
&loc_arg_pattern.value,
assignments,
)?;
Ok(Pattern::OpaqueUnwrap { Ok(Pattern::OpaqueUnwrap {
opaque: *opaque, opaque: *opaque,
argument: Box::new((mono_arg_pattern, arg_layout)), argument: Box::new((mono_arg_pattern, arg_layout)),
@ -8474,6 +8589,7 @@ fn from_can_pattern_help<'a>(
// this field is destructured by the pattern // this field is destructured by the pattern
mono_destructs.push(from_can_record_destruct( mono_destructs.push(from_can_record_destruct(
env, env,
procs,
layout_cache, layout_cache,
&destruct.value, &destruct.value,
field_layout, field_layout,
@ -8565,6 +8681,7 @@ fn from_can_pattern_help<'a>(
fn from_can_record_destruct<'a>( fn from_can_record_destruct<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
can_rd: &roc_can::pattern::RecordDestruct, can_rd: &roc_can::pattern::RecordDestruct,
field_layout: Layout<'a>, field_layout: Layout<'a>,
@ -8581,7 +8698,7 @@ fn from_can_record_destruct<'a>(
DestructType::Required(can_rd.symbol) DestructType::Required(can_rd.symbol)
} }
roc_can::pattern::DestructType::Guard(_, loc_pattern) => DestructType::Guard( roc_can::pattern::DestructType::Guard(_, loc_pattern) => DestructType::Guard(
from_can_pattern_help(env, layout_cache, &loc_pattern.value, assignments)?, from_can_pattern_help(env, procs, layout_cache, &loc_pattern.value, assignments)?,
), ),
}, },
}) })