diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index ecfe8053d6..ffca4f94d6 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -770,6 +770,26 @@ impl<'a> Procs<'a> { needed_symbol_specializations: BumpMap::new_in(arena), } } + + /// Expects and removes a single specialization symbol for the given requested symbol. + fn remove_single_symbol_specialization(&mut self, symbol: Symbol) -> Option { + let mut specialized_symbols = self + .needed_symbol_specializations + .drain_filter(|(sym, _), _| sym == &symbol); + + let specialization_symbol = specialized_symbols + .next() + .map(|(_, (_, specialized_symbol))| specialized_symbol); + + debug_assert_eq!( + specialized_symbols.count(), + 0, + "Symbol {:?} has multiple specializations", + symbol + ); + + specialization_symbol + } } #[derive(Clone, Debug, PartialEq)] @@ -2468,11 +2488,11 @@ fn specialize_external<'a>( // An argument from the closure list may have taken on a specialized symbol // name during the evaluation of the def body. If this is the case, load the // specialized name rather than the original captured name! - let get_specialized_name = |symbol, layout| { + let mut get_specialized_name = |symbol, layout| { procs .needed_symbol_specializations - .get(&(symbol, layout)) - .map(|(_, specialized)| *specialized) + .remove(&(symbol, layout)) + .map(|(_, specialized)| specialized) .unwrap_or(symbol) }; @@ -3304,7 +3324,7 @@ pub fn with_hole<'a>( } else { // this may be a destructure pattern let (mono_pattern, assignments) = - match from_can_pattern(env, layout_cache, &def.loc_pattern.value) { + match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) { Ok(v) => v, Err(_runtime_error) => { // todo @@ -5492,6 +5512,7 @@ pub fn from_can<'a>( } LetNonRec(def, cont, outer_annotation) => { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { + // dbg!(symbol, &def.loc_expr.value); match def.loc_expr.value { roc_can::expr::Expr::Closure(closure_data) => { register_capturing_closure(env, procs, layout_cache, *symbol, closure_data); @@ -5706,7 +5727,7 @@ pub fn from_can<'a>( // this may be a destructure pattern let (mono_pattern, assignments) = - match from_can_pattern(env, layout_cache, &def.loc_pattern.value) { + match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) { Ok(v) => v, Err(_) => todo!(), }; @@ -5737,8 +5758,22 @@ pub fn from_can<'a>( // layer on any default record fields for (symbol, variable, expr) in assignments { + let specialization_symbol = procs + .remove_single_symbol_specialization(symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(symbol); + let hole = env.arena.alloc(stmt); - stmt = with_hole(env, expr, variable, procs, layout_cache, symbol, hole); + stmt = with_hole( + env, + expr, + variable, + procs, + layout_cache, + specialization_symbol, + hole, + ); } if let roc_can::expr::Expr::Var(outer_symbol) = def.loc_expr.value { @@ -5772,6 +5807,7 @@ pub fn from_can<'a>( fn to_opt_branches<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, branches: std::vec::Vec, exhaustive_mark: ExhaustiveMark, layout_cache: &mut LayoutCache<'a>, @@ -5798,7 +5834,7 @@ fn to_opt_branches<'a>( } for loc_pattern in when_branch.patterns { - match from_can_pattern(env, layout_cache, &loc_pattern.value) { + match from_can_pattern(env, procs, layout_cache, &loc_pattern.value) { Ok((mono_pattern, assignments)) => { loc_branches.push(( Loc::at(loc_pattern.region, mono_pattern.clone()), @@ -5876,7 +5912,7 @@ fn from_can_when<'a>( // We can't know what to return! return Stmt::RuntimeError("Hit a 0-branch when expression"); } - let opt_branches = to_opt_branches(env, branches, exhaustive_mark, layout_cache); + let opt_branches = to_opt_branches(env, procs, branches, exhaustive_mark, layout_cache); let cond_layout = return_on_layout_error!(env, layout_cache.from_var(env.arena, cond_var, env.subs)); @@ -6341,7 +6377,15 @@ fn store_pattern_help<'a>( match can_pat { Identifier(symbol) => { - substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol); + // An identifier in a pattern can define at most one specialization! + // Remove any requested specializations for this name now, since this is the definition site. + let specialization_symbol = procs + .remove_single_symbol_specialization(*symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(*symbol); + + substitute_in_exprs(env.arena, &mut stmt, specialization_symbol, outer_symbol); } Underscore => { // do nothing @@ -6402,7 +6446,18 @@ fn store_pattern_help<'a>( for destruct in destructs { match &destruct.typ { DestructType::Required(symbol) => { - substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol); + let specialization_symbol = procs + .remove_single_symbol_specialization(*symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(*symbol); + + substitute_in_exprs( + env.arena, + &mut stmt, + specialization_symbol, + outer_symbol, + ); } DestructType::Guard(guard_pattern) => { return store_pattern_help( @@ -6480,10 +6535,11 @@ fn store_tag_pattern<'a>( match argument { Identifier(symbol) => { + // TODO: use procs.remove_single_symbol_specialization let symbol = procs .needed_symbol_specializations - .get(&(*symbol, arg_layout)) - .map(|(_, sym)| *sym) + .remove(&(*symbol, arg_layout)) + .map(|(_, sym)| sym) .unwrap_or(*symbol); // store immediately in the given symbol @@ -6562,8 +6618,19 @@ fn store_newtype_pattern<'a>( match argument { Identifier(symbol) => { - // store immediately in the given symbol - stmt = Stmt::Let(*symbol, load, arg_layout, env.arena.alloc(stmt)); + // store immediately in the given symbol, removing it specialization if it had any + let specialization_symbol = procs + .remove_single_symbol_specialization(*symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(*symbol); + + stmt = Stmt::Let( + specialization_symbol, + load, + arg_layout, + env.arena.alloc(stmt), + ); is_productive = true; } Underscore => { @@ -6625,11 +6692,35 @@ fn store_record_destruct<'a>( match &destruct.typ { DestructType::Required(symbol) => { - stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt)); + // A destructure can define at most one specialization! + // Remove any requested specializations for this name now, since this is the definition site. + let specialization_symbol = procs + .remove_single_symbol_specialization(*symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(*symbol); + + stmt = Stmt::Let( + specialization_symbol, + load, + destruct.layout, + env.arena.alloc(stmt), + ); } DestructType::Guard(guard_pattern) => match &guard_pattern { Identifier(symbol) => { - stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt)); + let specialization_symbol = procs + .remove_single_symbol_specialization(*symbol) + // Can happen when the symbol was never used under this body, and hence has no + // requested specialization. + .unwrap_or(*symbol); + + stmt = Stmt::Let( + specialization_symbol, + load, + destruct.layout, + env.arena.alloc(stmt), + ); } Underscore => { // important that this is special-cased to do nothing: mono record patterns will extract all the @@ -6816,48 +6907,53 @@ where return build_rest(env, procs, layout_cache); } - // Otherwise we're dealing with an alias to something that doesn't need to be specialized, or - // whose usages will already be specialized in the rest of the program. - if procs.is_imported_module_thunk(right) { - let result = build_rest(env, procs, layout_cache); + if procs.partial_procs.contains_key(right) { + // This is an alias to a function defined in this module. + // Attach the alias, then build the rest of the module, so that we reference and specialize + // the correct proc. + procs.partial_procs.insert_alias(left, right); + return build_rest(env, procs, layout_cache); + } + // Otherwise we're dealing with an alias whose usages will tell us what specializations we + // need. So let's figure those out first. + let result = build_rest(env, procs, layout_cache); + + // The specializations we wanted of the symbol on the LHS of this alias. + let needed_specializations_of_left = procs + .needed_symbol_specializations + .drain_filter(|(s, _), _| s == &left) + .collect::>(); + + if procs.is_imported_module_thunk(right) { // if this is an imported symbol, then we must make sure it is // specialized, and wrap the original in a function pointer. - add_needed_external(procs, env, variable, right); + let mut result = result; + for (_, (variable, left)) in needed_specializations_of_left.into_iter() { + add_needed_external(procs, env, variable, right); - let res_layout = layout_cache.from_var(env.arena, variable, env.subs); - let layout = return_on_layout_error!(env, res_layout); + let res_layout = layout_cache.from_var(env.arena, variable, env.subs); + let layout = return_on_layout_error!(env, res_layout); - force_thunk(env, right, layout, left, env.arena.alloc(result)) + result = force_thunk(env, right, layout, left, env.arena.alloc(result)); + } + result } else if env.is_imported_symbol(right) { - let result = build_rest(env, procs, layout_cache); - // if this is an imported symbol, then we must make sure it is // specialized, and wrap the original in a function pointer. add_needed_external(procs, env, variable, right); // then we must construct its closure; since imported symbols have no closure, we use the empty struct let_empty_struct(left, env.arena.alloc(result)) - } else if procs.partial_procs.contains_key(right) { - // This is an alias to a function defined in this module. - // Attach the alias, then build the rest of the module, so that we reference and specialize - // the correct proc. - procs.partial_procs.insert_alias(left, right); - build_rest(env, procs, layout_cache) } else { - // This should be a fully specialized value. Replace the alias with the original symbol. - let mut result = build_rest(env, procs, layout_cache); - // We need to lift all specializations of "left" to be specializations of "right". - let to_update = procs - .needed_symbol_specializations - .drain_filter(|(s, _), _| s == &left) - .collect::>(); let mut scratchpad_update_specializations = std::vec::Vec::new(); - let left_had_specialization_symbols = !to_update.is_empty(); + let left_had_specialization_symbols = !needed_specializations_of_left.is_empty(); - for ((_, layout), (specialized_var, specialized_sym)) in to_update.into_iter() { + for ((_, layout), (specialized_var, specialized_sym)) in + needed_specializations_of_left.into_iter() + { let old_specialized_sym = procs .needed_symbol_specializations .insert((right, layout), (specialized_var, specialized_sym)); @@ -6867,6 +6963,7 @@ where } } + let mut result = result; if left_had_specialization_symbols { // If the symbol is specialized, only the specializations need to be updated. for (old_specialized_sym, specialized_sym) in @@ -7894,6 +7991,7 @@ pub struct WhenBranch<'a> { #[allow(clippy::type_complexity)] fn from_can_pattern<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, layout_cache: &mut LayoutCache<'a>, can_pattern: &roc_can::pattern::Pattern, ) -> Result< @@ -7904,13 +8002,14 @@ fn from_can_pattern<'a>( RuntimeError, > { let mut assignments = Vec::new_in(env.arena); - let pattern = from_can_pattern_help(env, layout_cache, can_pattern, &mut assignments)?; + let pattern = from_can_pattern_help(env, procs, layout_cache, can_pattern, &mut assignments)?; Ok((pattern, assignments)) } fn from_can_pattern_help<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, layout_cache: &mut LayoutCache<'a>, can_pattern: &roc_can::pattern::Pattern, assignments: &mut Vec<'a, (Symbol, Variable, roc_can::expr::Expr)>, @@ -8105,7 +8204,13 @@ fn from_can_pattern_help<'a>( let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) { mono_args.push(( - from_can_pattern_help(env, layout_cache, &loc_pat.value, assignments)?, + from_can_pattern_help( + env, + procs, + layout_cache, + &loc_pat.value, + assignments, + )?, *layout, )); } @@ -8183,6 +8288,7 @@ fn from_can_pattern_help<'a>( mono_args.push(( from_can_pattern_help( env, + procs, layout_cache, &loc_pat.value, assignments, @@ -8228,6 +8334,7 @@ fn from_can_pattern_help<'a>( mono_args.push(( from_can_pattern_help( env, + procs, layout_cache, &loc_pat.value, assignments, @@ -8271,6 +8378,7 @@ fn from_can_pattern_help<'a>( mono_args.push(( from_can_pattern_help( env, + procs, layout_cache, &loc_pat.value, assignments, @@ -8344,6 +8452,7 @@ fn from_can_pattern_help<'a>( mono_args.push(( from_can_pattern_help( env, + procs, layout_cache, &loc_pat.value, assignments, @@ -8400,6 +8509,7 @@ fn from_can_pattern_help<'a>( mono_args.push(( from_can_pattern_help( env, + procs, layout_cache, &loc_pat.value, assignments, @@ -8430,8 +8540,13 @@ fn from_can_pattern_help<'a>( let arg_layout = layout_cache .from_var(env.arena, *arg_var, env.subs) .unwrap(); - let mono_arg_pattern = - from_can_pattern_help(env, layout_cache, &loc_arg_pattern.value, assignments)?; + let mono_arg_pattern = from_can_pattern_help( + env, + procs, + layout_cache, + &loc_arg_pattern.value, + assignments, + )?; Ok(Pattern::OpaqueUnwrap { opaque: *opaque, argument: Box::new((mono_arg_pattern, arg_layout)), @@ -8474,6 +8589,7 @@ fn from_can_pattern_help<'a>( // this field is destructured by the pattern mono_destructs.push(from_can_record_destruct( env, + procs, layout_cache, &destruct.value, field_layout, @@ -8565,6 +8681,7 @@ fn from_can_pattern_help<'a>( fn from_can_record_destruct<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, layout_cache: &mut LayoutCache<'a>, can_rd: &roc_can::pattern::RecordDestruct, field_layout: Layout<'a>, @@ -8581,7 +8698,7 @@ fn from_can_record_destruct<'a>( DestructType::Required(can_rd.symbol) } roc_can::pattern::DestructType::Guard(_, loc_pattern) => DestructType::Guard( - from_can_pattern_help(env, layout_cache, &loc_pattern.value, assignments)?, + from_can_pattern_help(env, procs, layout_cache, &loc_pattern.value, assignments)?, ), }, })