From 8be230695bd373d66b39ed6129487da15512d59e Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Tue, 28 Jun 2022 17:13:15 -0400 Subject: [PATCH] Get multimorphic lambda captures working --- crates/compiler/collections/src/vec_map.rs | 4 + crates/compiler/load_internal/src/file.rs | 26 +- crates/compiler/module/src/symbol.rs | 25 +- crates/compiler/mono/src/borrow.rs | 2 +- crates/compiler/mono/src/ir.rs | 277 +++-------- crates/compiler/mono/src/layout.rs | 445 ++++++++++++------ .../multimorphic_lambda_set_capture.txt | 49 ++ crates/compiler/test_mono/src/tests.rs | 16 +- crates/compiler/types/src/pretty_print.rs | 37 +- 9 files changed, 458 insertions(+), 423 deletions(-) create mode 100644 crates/compiler/test_mono/generated/multimorphic_lambda_set_capture.txt diff --git a/crates/compiler/collections/src/vec_map.rs b/crates/compiler/collections/src/vec_map.rs index 827f2ad293..7b117ee842 100644 --- a/crates/compiler/collections/src/vec_map.rs +++ b/crates/compiler/collections/src/vec_map.rs @@ -97,6 +97,10 @@ impl VecMap { self.keys.iter().zip(self.values.iter()) } + pub fn iter_mut(&mut self) -> impl ExactSizeIterator { + self.keys.iter().zip(self.values.iter_mut()) + } + pub fn keys(&self) -> impl ExactSizeIterator { self.keys.iter() } diff --git a/crates/compiler/load_internal/src/file.rs b/crates/compiler/load_internal/src/file.rs index e67aa4cfc4..328fd7542f 100644 --- a/crates/compiler/load_internal/src/file.rs +++ b/crates/compiler/load_internal/src/file.rs @@ -29,12 +29,11 @@ use roc_module::symbol::{ IdentIds, IdentIdsByModule, Interns, ModuleId, ModuleIds, PQModuleName, PackageModuleIds, PackageQualified, Symbol, }; -use roc_mono::fresh_multimorphic_symbol; use roc_mono::ir::{ CapturedSymbols, EntryPoint, ExternalSpecializations, PartialProc, Proc, ProcLayout, Procs, ProcsBase, UpdateModeIds, }; -use roc_mono::layout::{LambdaName, Layout, LayoutCache, LayoutProblem}; +use roc_mono::layout::{LambdaName, Layout, LayoutCache, LayoutProblem, MultimorphicNames}; use roc_parse::ast::{self, Defs, ExtractSpaces, Spaced, StrLiteral, TypeAnnotation}; use roc_parse::header::{ExposedName, ImportsEntry, PackageEntry, PlatformHeader, To, TypedIdent}; use roc_parse::header::{HeaderFor, ModuleNameEnum, PackageName}; @@ -429,6 +428,7 @@ fn start_phase<'a>( exposed_to_host: state.exposed_to_host.clone(), abilities_store, derived_symbols, + multimorphic_names: state.multimorphic_names.clone(), } } Phase::MakeSpecializations => { @@ -498,6 +498,7 @@ fn start_phase<'a>( module_timing, world_abilities: state.world_abilities.clone_ref(), derived_symbols, + multimorphic_names: state.multimorphic_names.clone(), } } } @@ -827,6 +828,8 @@ struct State<'a> { pub render: RenderTarget, + pub multimorphic_names: MultimorphicNames, + /// All abilities across all modules. pub world_abilities: WorldAbilities, @@ -877,6 +880,7 @@ impl<'a> State<'a> { layout_caches: std::vec::Vec::with_capacity(number_of_workers), cached_subs: Arc::new(Mutex::new(cached_subs)), render, + multimorphic_names: MultimorphicNames::default(), make_specializations_pass: MakeSpecializationsPass::Pass(1), world_abilities: Default::default(), } @@ -1001,6 +1005,7 @@ enum BuildTask<'a> { exposed_to_host: ExposedToHost, abilities_store: AbilitiesStore, derived_symbols: GlobalDerivedSymbols, + multimorphic_names: MultimorphicNames, }, MakeSpecializations { module_id: ModuleId, @@ -1012,6 +1017,7 @@ enum BuildTask<'a> { module_timing: ModuleTiming, world_abilities: WorldAbilities, derived_symbols: GlobalDerivedSymbols, + multimorphic_names: MultimorphicNames, }, } @@ -4413,6 +4419,7 @@ fn make_specializations<'a>( target_info: TargetInfo, world_abilities: WorldAbilities, derived_symbols: GlobalDerivedSymbols, + mut multimorphic_names: MultimorphicNames, ) -> Msg<'a> { let make_specializations_start = SystemTime::now(); let mut update_mode_ids = UpdateModeIds::new(); @@ -4428,6 +4435,7 @@ fn make_specializations<'a>( call_specialization_counter: 1, abilities: AbilitiesView::World(world_abilities), derived_symbols: &derived_symbols, + multimorphic_names: &mut multimorphic_names, }; let mut procs = Procs::new_in(arena); @@ -4491,6 +4499,7 @@ fn build_pending_specializations<'a>( exposed_to_host: ExposedToHost, // TODO remove abilities_store: AbilitiesStore, derived_symbols: GlobalDerivedSymbols, + mut multimorphic_names: MultimorphicNames, ) -> Msg<'a> { let find_specializations_start = SystemTime::now(); @@ -4520,6 +4529,7 @@ fn build_pending_specializations<'a>( // do we need a global view. abilities: AbilitiesView::Module(&abilities_store), derived_symbols: &derived_symbols, + multimorphic_names: &mut multimorphic_names, }; // Add modules' decls to Procs @@ -4550,7 +4560,7 @@ fn build_pending_specializations<'a>( mono_env.arena, expr_var, mono_env.subs, - fresh_multimorphic_symbol!(mono_env), + mono_env.multimorphic_names, ); // cannot specialize when e.g. main's type contains type variables @@ -4614,7 +4624,7 @@ fn build_pending_specializations<'a>( mono_env.arena, expr_var, mono_env.subs, - fresh_multimorphic_symbol!(mono_env), + mono_env.multimorphic_names, ); // cannot specialize when e.g. main's type contains type variables @@ -4696,7 +4706,7 @@ fn build_pending_specializations<'a>( mono_env.arena, expr_var, mono_env.subs, - fresh_multimorphic_symbol!(mono_env), + mono_env.multimorphic_names, ); // cannot specialize when e.g. main's type contains type variables @@ -4760,7 +4770,7 @@ fn build_pending_specializations<'a>( mono_env.arena, expr_var, mono_env.subs, - fresh_multimorphic_symbol!(mono_env), + mono_env.multimorphic_names, ); // cannot specialize when e.g. main's type contains type variables @@ -4914,6 +4924,7 @@ fn run_task<'a>( exposed_to_host, abilities_store, derived_symbols, + multimorphic_names, } => Ok(build_pending_specializations( arena, solved_subs, @@ -4927,6 +4938,7 @@ fn run_task<'a>( exposed_to_host, abilities_store, derived_symbols, + multimorphic_names, )), MakeSpecializations { module_id, @@ -4938,6 +4950,7 @@ fn run_task<'a>( module_timing, world_abilities, derived_symbols, + multimorphic_names, } => Ok(make_specializations( arena, module_id, @@ -4950,6 +4963,7 @@ fn run_task<'a>( target_info, world_abilities, derived_symbols, + multimorphic_names, )), }?; diff --git a/crates/compiler/module/src/symbol.rs b/crates/compiler/module/src/symbol.rs index 4e2614d286..dd3c534252 100644 --- a/crates/compiler/module/src/symbol.rs +++ b/crates/compiler/module/src/symbol.rs @@ -998,7 +998,10 @@ define_builtins! { // Fake module for storing derived function symbols 1 DERIVED: "#Derived" => { } - 2 NUM: "Num" => { + // Fake module for storing fresh multimorphic function symbol names + 2 MULTIMORPHIC: "#Multimorphic" => { + } + 3 NUM: "Num" => { 0 NUM_NUM: "Num" // the Num.Num type alias 1 NUM_I128: "I128" // the Num.I128 type alias 2 NUM_U128: "U128" // the Num.U128 type alias @@ -1141,7 +1144,7 @@ define_builtins! { 139 NUM_MAX_F64: "maxF64" 140 NUM_MIN_F64: "minF64" } - 3 BOOL: "Bool" => { + 4 BOOL: "Bool" => { 0 BOOL_BOOL: "Bool" // the Bool.Bool type alias 1 BOOL_FALSE: "False" imported // Bool.Bool = [False, True] // NB: not strictly needed; used for finding tag names in error suggestions @@ -1154,7 +1157,7 @@ define_builtins! { 7 BOOL_EQ: "isEq" 8 BOOL_NEQ: "isNotEq" } - 4 STR: "Str" => { + 5 STR: "Str" => { 0 STR_STR: "Str" imported // the Str.Str type alias 1 STR_IS_EMPTY: "isEmpty" 2 STR_APPEND: "#append" // unused @@ -1191,7 +1194,7 @@ define_builtins! { 33 STR_TO_I8: "toI8" 34 STR_TO_SCALARS: "toScalars" } - 5 LIST: "List" => { + 6 LIST: "List" => { 0 LIST_LIST: "List" imported // the List.List type alias 1 LIST_IS_EMPTY: "isEmpty" 2 LIST_GET: "get" @@ -1257,7 +1260,7 @@ define_builtins! { 62 LIST_WITH_CAPACITY: "withCapacity" 63 LIST_ITERATE: "iterate" } - 6 RESULT: "Result" => { + 7 RESULT: "Result" => { 0 RESULT_RESULT: "Result" // the Result.Result type alias 1 RESULT_OK: "Ok" imported // Result.Result a e = [Ok a, Err e] // NB: not strictly needed; used for finding tag names in error suggestions @@ -1270,7 +1273,7 @@ define_builtins! { 7 RESULT_IS_OK: "isOk" 8 RESULT_IS_ERR: "isErr" } - 7 DICT: "Dict" => { + 8 DICT: "Dict" => { 0 DICT_DICT: "Dict" imported // the Dict.Dict type alias 1 DICT_EMPTY: "empty" 2 DICT_SINGLE: "single" @@ -1289,7 +1292,7 @@ define_builtins! { 13 DICT_INTERSECTION: "intersection" 14 DICT_DIFFERENCE: "difference" } - 8 SET: "Set" => { + 9 SET: "Set" => { 0 SET_SET: "Set" imported // the Set.Set type alias 1 SET_EMPTY: "empty" 2 SET_SINGLE: "single" @@ -1306,12 +1309,12 @@ define_builtins! { 13 SET_CONTAINS: "contains" 14 SET_TO_DICT: "toDict" } - 9 BOX: "Box" => { + 10 BOX: "Box" => { 0 BOX_BOX_TYPE: "Box" imported // the Box.Box opaque type 1 BOX_BOX_FUNCTION: "box" // Box.box 2 BOX_UNBOX: "unbox" } - 10 ENCODE: "Encode" => { + 11 ENCODE: "Encode" => { 0 ENCODE_ENCODER: "Encoder" 1 ENCODE_ENCODING: "Encoding" 2 ENCODE_TO_ENCODER: "toEncoder" @@ -1339,9 +1342,9 @@ define_builtins! { 24 ENCODE_APPEND: "append" 25 ENCODE_TO_BYTES: "toBytes" } - 11 JSON: "Json" => { + 12 JSON: "Json" => { 0 JSON_JSON: "Json" } - num_modules: 12 // Keep this count up to date by hand! (TODO: see the mut_map! macro for how we could determine this count correctly in the macro) + num_modules: 13 // Keep this count up to date by hand! (TODO: see the mut_map! macro for how we could determine this count correctly in the macro) } diff --git a/crates/compiler/mono/src/borrow.rs b/crates/compiler/mono/src/borrow.rs index 92f8f39d80..293341c95a 100644 --- a/crates/compiler/mono/src/borrow.rs +++ b/crates/compiler/mono/src/borrow.rs @@ -163,7 +163,7 @@ impl<'a> DeclarationToIndex<'a> { } } unreachable!( - "symbol/layout {:?} {:?} combo must be in DeclarationToIndex", + "symbol/layout {:?} {:#?} combo must be in DeclarationToIndex", needle_symbol, needle_layout ) } diff --git a/crates/compiler/mono/src/ir.rs b/crates/compiler/mono/src/ir.rs index 6eec3a6e5f..3eba5975ed 100644 --- a/crates/compiler/mono/src/ir.rs +++ b/crates/compiler/mono/src/ir.rs @@ -2,7 +2,7 @@ use crate::layout::{ Builtin, ClosureRepresentation, LambdaName, LambdaSet, Layout, LayoutCache, LayoutProblem, - RawFunctionLayout, TagIdIntType, TagOrClosure, UnionLayout, WrappedVariant, + MultimorphicNames, RawFunctionLayout, TagIdIntType, TagOrClosure, UnionLayout, WrappedVariant, }; use bumpalo::collections::{CollectIn, Vec}; use bumpalo::Bump; @@ -706,7 +706,6 @@ impl<'a> Specialized<'a> { } fn mark_in_progress(&mut self, symbol: Symbol, layout: ProcLayout<'a>) { - // dbg!((symbol, layout)); for (i, s) in self.symbols.iter().enumerate() { if *s == symbol && self.proc_layouts[i] == layout { match &self.procedures[i] { @@ -727,7 +726,6 @@ impl<'a> Specialized<'a> { } fn remove_specialized(&mut self, symbol: Symbol, layout: &ProcLayout<'a>) -> bool { - // dbg!((symbol, layout)); let mut index = None; for (i, s) in self.symbols.iter().enumerate() { @@ -746,7 +744,6 @@ impl<'a> Specialized<'a> { } fn insert_specialized(&mut self, symbol: Symbol, layout: ProcLayout<'a>, proc: Proc<'a>) { - // dbg!((symbol, layout)); for (i, s) in self.symbols.iter().enumerate() { if *s == symbol && self.proc_layouts[i] == layout { match &self.procedures[i] { @@ -814,16 +811,6 @@ struct SymbolSpecializations<'a>( VecMap, (Variable, Symbol)>>, ); -#[macro_export] -macro_rules! fresh_multimorphic_symbol { - ($env:expr) => { - &mut || { - let ident_id = $env.ident_ids.gen_unique(); - Symbol::new($env.home, ident_id) - } - }; -} - impl<'a> SymbolSpecializations<'a> { /// Gets a specialization for a symbol, or creates a new one. #[inline(always)] @@ -837,17 +824,13 @@ impl<'a> SymbolSpecializations<'a> { let arena = env.arena; let subs: &Subs = env.subs; - let layout = match layout_cache.from_var( - arena, - specialization_var, - subs, - fresh_multimorphic_symbol!(env), - ) { - Ok(layout) => layout, - // This can happen when the def symbol has a type error. In such cases just use the - // def symbol, which is erroring. - Err(_) => return symbol, - }; + let layout = + match layout_cache.from_var(arena, specialization_var, subs, env.multimorphic_names) { + Ok(layout) => layout, + // This can happen when the def symbol has a type error. In such cases just use the + // def symbol, which is erroring. + Err(_) => return symbol, + }; let is_closure = matches!( subs.get_content_without_compacting(specialization_var), @@ -858,7 +841,7 @@ impl<'a> SymbolSpecializations<'a> { arena, specialization_var, subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(layout) => layout, // This can happen when the def symbol has a type error. In such cases just use the @@ -1041,12 +1024,7 @@ impl<'a> Procs<'a> { layout_cache: &mut LayoutCache<'a>, ) -> Result, RuntimeError> { let raw_layout = layout_cache - .raw_from_var( - env.arena, - annotation, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .raw_from_var(env.arena, annotation, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let top_level = ProcLayout::from_raw(env.arena, raw_layout); @@ -1327,6 +1305,7 @@ pub struct Env<'a, 'i> { pub call_specialization_counter: u32, pub abilities: AbilitiesView<'i>, pub derived_symbols: &'i GlobalDerivedSymbols, + pub multimorphic_names: &'i mut MultimorphicNames, } impl<'a, 'i> Env<'a, 'i> { @@ -2924,12 +2903,7 @@ fn specialize_external<'a>( for (symbol, variable) in host_exposed_variables { let layout = layout_cache - .raw_from_var( - env.arena, - *variable, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .raw_from_var(env.arena, *variable, env.subs, env.multimorphic_names) .unwrap(); let name = env.unique_symbol(); @@ -3257,7 +3231,7 @@ fn build_specialized_proc_from_var<'a>( pattern_symbols: &[Symbol], fn_var: Variable, ) -> Result, LayoutProblem> { - match layout_cache.raw_from_var(env.arena, fn_var, env.subs, fresh_multimorphic_symbol!(env))? { + match layout_cache.raw_from_var(env.arena, fn_var, env.subs, env.multimorphic_names)? { RawFunctionLayout::Function(pattern_layouts, closure_layout, ret_layout) => { let mut pattern_layouts_vec = Vec::with_capacity_in(pattern_layouts.len(), env.arena); pattern_layouts_vec.extend_from_slice(pattern_layouts); @@ -3476,7 +3450,7 @@ where // for debugging only let raw = layout_cache - .raw_from_var(env.arena, fn_var, env.subs, fresh_multimorphic_symbol!(env)) + .raw_from_var(env.arena, fn_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); let raw = if procs.is_module_thunk(proc_name.source_name()) { @@ -3595,12 +3569,7 @@ fn specialize_naked_symbol<'a>( return result; } else if env.is_imported_symbol(symbol) { - match layout_cache.from_var( - env.arena, - variable, - env.subs, - fresh_multimorphic_symbol!(env), - ) { + match layout_cache.from_var(env.arena, variable, env.subs, env.multimorphic_names) { Err(e) => panic!("invalid layout {:?}", e), Ok(_) => { // this is a 0-arity thunk @@ -3983,7 +3952,7 @@ pub fn with_hole<'a>( record_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(fields) => fields, Err(_) => return Stmt::RuntimeError("Can't create record with improper layout"), @@ -4037,12 +4006,7 @@ pub fn with_hole<'a>( // creating a record from the var will unpack it if it's just a single field. let layout = layout_cache - .from_var( - env.arena, - record_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, record_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let field_symbols = field_symbols.into_bump_slice(); @@ -4102,18 +4066,8 @@ pub fn with_hole<'a>( final_else, } => { match ( - layout_cache.from_var( - env.arena, - branch_var, - env.subs, - fresh_multimorphic_symbol!(env), - ), - layout_cache.from_var( - env.arena, - cond_var, - env.subs, - fresh_multimorphic_symbol!(env), - ), + layout_cache.from_var(env.arena, branch_var, env.subs, env.multimorphic_names), + layout_cache.from_var(env.arena, cond_var, env.subs, env.multimorphic_names), ) { (Ok(ret_layout), Ok(cond_layout)) => { // if the hole is a return, then we don't need to merge the two @@ -4212,12 +4166,7 @@ pub fn with_hole<'a>( } let layout = layout_cache - .from_var( - env.arena, - branch_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, branch_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| { panic!("TODO turn fn_var into a RuntimeError {:?}", err) }); @@ -4284,12 +4233,7 @@ pub fn with_hole<'a>( ); let layout = layout_cache - .from_var( - env.arena, - expr_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, expr_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let param = Param { @@ -4312,12 +4256,8 @@ pub fn with_hole<'a>( .. } if loc_elems.is_empty() => { // because an empty list has an unknown element type, it is handled differently - let opt_elem_layout = layout_cache.from_var( - env.arena, - elem_var, - env.subs, - fresh_multimorphic_symbol!(env), - ); + let opt_elem_layout = + layout_cache.from_var(env.arena, elem_var, env.subs, env.multimorphic_names); match opt_elem_layout { Ok(elem_layout) => { @@ -4371,12 +4311,7 @@ pub fn with_hole<'a>( let arg_symbols = arg_symbols.into_bump_slice(); let elem_layout = layout_cache - .from_var( - env.arena, - elem_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, elem_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let expr = Expr::Array { @@ -4412,7 +4347,7 @@ pub fn with_hole<'a>( record_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(fields) => fields, Err(_) => return Stmt::RuntimeError("Can't access record with improper layout"), @@ -4463,12 +4398,7 @@ pub fn with_hole<'a>( }; let layout = layout_cache - .from_var( - env.arena, - field_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, field_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| { panic!("TODO turn fn_var into a RuntimeError {:?}", err) }); @@ -4520,7 +4450,7 @@ pub fn with_hole<'a>( env.arena, function_type, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) ); @@ -4573,7 +4503,7 @@ pub fn with_hole<'a>( record_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(fields) => fields, Err(_) => return Stmt::RuntimeError("Can't update record with improper layout"), @@ -4618,12 +4548,7 @@ pub fn with_hole<'a>( let symbols = symbols.into_bump_slice(); let record_layout = layout_cache - .from_var( - env.arena, - record_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, record_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let field_layouts = match &record_layout { @@ -4738,7 +4663,7 @@ pub fn with_hole<'a>( env.arena, function_type, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ); match return_on_layout_error!(env, raw) { @@ -4880,7 +4805,7 @@ pub fn with_hole<'a>( env.arena, fn_var, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) ); @@ -5064,12 +4989,7 @@ pub fn with_hole<'a>( // layout of the return type let layout = return_on_layout_error!( env, - layout_cache.from_var( - env.arena, - ret_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.from_var(env.arena, ret_var, env.subs, env.multimorphic_names,) ); let call = self::Call { @@ -5107,12 +5027,7 @@ pub fn with_hole<'a>( // layout of the return type let layout = return_on_layout_error!( env, - layout_cache.from_var( - env.arena, - ret_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.from_var(env.arena, ret_var, env.subs, env.multimorphic_names,) ); macro_rules! match_on_closure_argument { @@ -5123,7 +5038,7 @@ pub fn with_hole<'a>( let closure_data_layout = return_on_layout_error!( env, - layout_cache.raw_from_var(env.arena, closure_data_var, env.subs,fresh_multimorphic_symbol!(env),) + layout_cache.raw_from_var(env.arena, closure_data_var, env.subs,env.multimorphic_names,) ); let top_level = ProcLayout::from_raw(env.arena, closure_data_layout); @@ -5371,7 +5286,7 @@ where .into_iter() .map(|(_, var)| { layout_cache - .from_var(env.arena, *var, env.subs, fresh_multimorphic_symbol!(env)) + .from_var(env.arena, *var, env.subs, env.multimorphic_names) .expect("layout problem for capture") }) .collect_in::>(env.arena); @@ -5513,7 +5428,7 @@ fn convert_tag_union<'a>( variant_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ); let variant = match res_variant { Ok(cached) => cached, @@ -5573,12 +5488,7 @@ fn convert_tag_union<'a>( // Layout will unpack this unwrapped tack if it only has one (non-zero-sized) field let layout = layout_cache - .from_var( - env.arena, - variant_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, variant_var, env.subs, env.multimorphic_names) .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); // even though this was originally a Tag, we treat it as a Struct from now on @@ -5606,12 +5516,7 @@ fn convert_tag_union<'a>( // version is not the same as the minimal version. let union_layout = match return_on_layout_error!( env, - layout_cache.from_var( - env.arena, - variant_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.from_var(env.arena, variant_var, env.subs, env.multimorphic_names,) ) { Layout::Union(ul) => ul, _ => unreachable!(), @@ -5809,12 +5714,7 @@ fn tag_union_to_function<'a>( // only need to construct closure data let raw_layout = return_on_layout_error!( env, - layout_cache.raw_from_var( - env.arena, - whole_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.raw_from_var(env.arena, whole_var, env.subs, env.multimorphic_names,) ); match raw_layout { @@ -5858,12 +5758,7 @@ fn sorted_field_symbols<'a>( for (var, mut arg) in args.drain(..) { // Layout will unpack this unwrapped tag if it only has one (non-zero-sized) field - let layout = match layout_cache.from_var( - env.arena, - var, - env.subs, - fresh_multimorphic_symbol!(env), - ) { + let layout = match layout_cache.from_var(env.arena, var, env.subs, env.multimorphic_names) { Ok(cached) => cached, Err(LayoutProblem::UnresolvedTypeVar(_)) => { // this argument has type `forall a. a`, which is isomorphic to @@ -5966,7 +5861,7 @@ fn register_capturing_closure<'a>( env.subs, closure_var, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(lambda_set) => { if let Layout::Struct { @@ -6002,7 +5897,7 @@ fn register_capturing_closure<'a>( env.arena, function_type, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ), env.subs, (function_type, closure_type), @@ -6082,20 +5977,10 @@ pub fn from_can<'a>( final_else, } => { let ret_layout = layout_cache - .from_var( - env.arena, - branch_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, branch_var, env.subs, env.multimorphic_names) .expect("invalid ret_layout"); let cond_layout = layout_cache - .from_var( - env.arena, - cond_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, cond_var, env.subs, env.multimorphic_names) .expect("invalid cond_layout"); let mut stmt = from_can(env, branch_var, final_else.value, procs, layout_cache); @@ -6139,12 +6024,8 @@ pub fn from_can<'a>( let mut layouts = Vec::with_capacity_in(lookups_in_cond.len(), env.arena); for (_, var) in lookups_in_cond { - let res_layout = layout_cache.from_var( - env.arena, - var, - env.subs, - fresh_multimorphic_symbol!(env), - ); + let res_layout = + layout_cache.from_var(env.arena, var, env.subs, env.multimorphic_names); let layout = return_on_layout_error!(env, res_layout); layouts.push(layout); } @@ -6313,22 +6194,12 @@ fn from_can_when<'a>( let cond_layout = return_on_layout_error!( env, - layout_cache.from_var( - env.arena, - cond_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.from_var(env.arena, cond_var, env.subs, env.multimorphic_names,) ); let ret_layout = return_on_layout_error!( env, - layout_cache.from_var( - env.arena, - expr_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.from_var(env.arena, expr_var, env.subs, env.multimorphic_names,) ); let arena = env.arena; @@ -7327,12 +7198,8 @@ where LambdaName::from_non_multimorphic(right), ); - let res_layout = layout_cache.from_var( - env.arena, - variable, - env.subs, - fresh_multimorphic_symbol!(env), - ); + let res_layout = + layout_cache.from_var(env.arena, variable, env.subs, env.multimorphic_names); let layout = return_on_layout_error!(env, res_layout); result = force_thunk(env, right, layout, left, env.arena.alloc(result)); @@ -7431,7 +7298,7 @@ fn specialize_symbol<'a>( env.arena, arg_var, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(v) => v, Err(e) => return_on_layout_error_help!(env, e), @@ -7497,12 +7364,7 @@ fn specialize_symbol<'a>( // to it in the IR. let res_layout = return_on_layout_error!( env, - layout_cache.raw_from_var( - env.arena, - arg_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + layout_cache.raw_from_var(env.arena, arg_var, env.subs, env.multimorphic_names,) ); // we have three kinds of functions really. Plain functions, closures by capture, @@ -7739,7 +7601,7 @@ fn call_by_name<'a>( hole: &'a Stmt<'a>, ) -> Stmt<'a> { // Register a pending_specialization for this function - match layout_cache.raw_from_var(env.arena, fn_var, env.subs, fresh_multimorphic_symbol!(env)) { + match layout_cache.raw_from_var(env.arena, fn_var, env.subs, env.multimorphic_names) { Err(LayoutProblem::UnresolvedTypeVar(var)) => { let msg = format!( "Hit an unresolved type variable {:?} when creating a layout for {:?} (var {:?})", @@ -7896,7 +7758,7 @@ fn call_by_name_help<'a>( // the variables of the given arguments let mut pattern_vars = Vec::with_capacity_in(loc_args.len(), arena); for (var, _) in &loc_args { - match layout_cache.from_var(env.arena, *var, env.subs, fresh_multimorphic_symbol!(env)) { + match layout_cache.from_var(env.arena, *var, env.subs, env.multimorphic_names) { Ok(_) => { pattern_vars.push(*var); } @@ -8568,7 +8430,7 @@ fn from_can_pattern_help<'a>( *whole_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) .map_err(Into::into); @@ -8653,12 +8515,12 @@ fn from_can_pattern_help<'a>( arguments.sort_by(|arg1, arg2| { let size1 = layout_cache - .from_var(env.arena, arg1.0, env.subs, fresh_multimorphic_symbol!(env)) + .from_var(env.arena, arg1.0, env.subs, env.multimorphic_names) .map(|x| x.alignment_bytes(env.target_info)) .unwrap_or(0); let size2 = layout_cache - .from_var(env.arena, arg2.0, env.subs, fresh_multimorphic_symbol!(env)) + .from_var(env.arena, arg2.0, env.subs, env.multimorphic_names) .map(|x| x.alignment_bytes(env.target_info)) .unwrap_or(0); @@ -8694,20 +8556,10 @@ fn from_can_pattern_help<'a>( temp.sort_by(|arg1, arg2| { let layout1 = layout_cache - .from_var( - env.arena, - arg1.0, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, arg1.0, env.subs, env.multimorphic_names) .unwrap(); let layout2 = layout_cache - .from_var( - env.arena, - arg2.0, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, arg2.0, env.subs, env.multimorphic_names) .unwrap(); let size1 = layout1.alignment_bytes(env.target_info); @@ -8727,7 +8579,7 @@ fn from_can_pattern_help<'a>( env.arena, *whole_var, env.subs, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) { Ok(Layout::Union(ul)) => ul, _ => unreachable!(), @@ -9019,12 +8871,7 @@ fn from_can_pattern_help<'a>( } => { let (arg_var, loc_arg_pattern) = &(**argument); let arg_layout = layout_cache - .from_var( - env.arena, - *arg_var, - env.subs, - fresh_multimorphic_symbol!(env), - ) + .from_var(env.arena, *arg_var, env.subs, env.multimorphic_names) .unwrap(); let mono_arg_pattern = from_can_pattern_help( env, @@ -9050,7 +8897,7 @@ fn from_can_pattern_help<'a>( *whole_var, env.subs, env.target_info, - fresh_multimorphic_symbol!(env), + env.multimorphic_names, ) .map_err(RuntimeError::from)?; diff --git a/crates/compiler/mono/src/layout.rs b/crates/compiler/mono/src/layout.rs index 63d9433829..6b32423df6 100644 --- a/crates/compiler/mono/src/layout.rs +++ b/crates/compiler/mono/src/layout.rs @@ -1,14 +1,14 @@ use crate::ir::Parens; -use bumpalo::collections::Vec; +use bumpalo::collections::{CollectIn, Vec}; use bumpalo::Bump; use roc_builtins::bitcode::{FloatWidth, IntWidth}; use roc_collections::all::{default_hasher, MutMap}; +use roc_collections::VecMap; use roc_error_macros::{internal_error, todo_abilities}; use roc_module::ident::{Lowercase, TagName}; -use roc_module::symbol::{Interns, Symbol}; +use roc_module::symbol::{IdentIds, Interns, ModuleId, Symbol}; use roc_problem::can::RuntimeError; use roc_target::{PtrWidth, TargetInfo}; -use roc_types::pretty_print::ResolvedLambdaSet; use roc_types::subs::{ self, Content, FlatType, Label, RecordFields, Subs, UnionTags, UnsortedUnionLabels, Variable, }; @@ -17,6 +17,7 @@ use std::cmp::Ordering; use std::collections::hash_map::{DefaultHasher, Entry}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; use ven_pretty::{DocAllocator, DocBuilder}; // if your changes cause this number to go down, great! @@ -67,8 +68,8 @@ impl<'a> RawFunctionLayout<'a> { matches!(self, RawFunctionLayout::ZeroArgumentThunk(_)) } - fn new_help<'b, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, 'b, F>, + fn new_help<'b>( + env: &mut Env<'a, 'b>, var: Variable, content: Content, ) -> Result { @@ -153,8 +154,8 @@ impl<'a> RawFunctionLayout<'a> { } } - fn layout_from_lambda_set( - _env: &mut Env<'a, '_, F>, + fn layout_from_lambda_set( + _env: &mut Env<'a, '_>, _lset: subs::LambdaSet, ) -> Result { unreachable!() @@ -162,8 +163,8 @@ impl<'a> RawFunctionLayout<'a> { // Self::layout_from_flat_type(env, lset.as_tag_union()) } - fn layout_from_flat_type( - env: &mut Env<'a, '_, F>, + fn layout_from_flat_type( + env: &mut Env<'a, '_>, flat_type: FlatType, ) -> Result { use roc_types::subs::FlatType::*; @@ -189,7 +190,7 @@ impl<'a> RawFunctionLayout<'a> { env.subs, closure_var, env.target_info, - env.fresh_multimorphic_symbol, + env.multimorphic_names, )?; Ok(Self::Function(fn_args, lambda_set, ret)) @@ -221,10 +222,7 @@ impl<'a> RawFunctionLayout<'a> { /// Returns Err(()) if given an error, or Ok(Layout) if given a non-erroneous Structure. /// Panics if given a FlexVar or RigidVar, since those should have been /// monomorphized away already! - fn from_var( - env: &mut Env<'a, '_, F>, - var: Variable, - ) -> Result { + fn from_var(env: &mut Env<'a, '_>, var: Variable) -> Result { if env.is_seen(var) { unreachable!("The initial variable of a signature cannot be seen already") } else { @@ -704,6 +702,158 @@ impl std::fmt::Debug for LambdaSet<'_> { } } +#[derive(Default, Debug)] +struct MultimorphicNamesTable { + /// (source symbol, captures layouts) -> multimorphic alias + /// + /// SAFETY: actually, the `Layout` is alive only as long as the `arena` is alive. We take care + /// to promote new layouts to the owned arena. Since we are using a bump-allocating arena, the + /// references will never be invalidated until the arena is dropped, which happens when this + /// struct is dropped. + /// Also, the `Layout`s we owned are never exposed back via the public API. + inner: VecMap<(Symbol, &'static [Layout<'static>]), Symbol>, + arena: Bump, + ident_ids: IdentIds, +} + +impl MultimorphicNamesTable { + fn get<'b>(&self, name: Symbol, captures_layouts: &'b [Layout<'b>]) -> Option { + self.inner.get(&(name, captures_layouts)).copied() + } + + fn insert<'b>(&mut self, name: Symbol, captures_layouts: &'b [Layout<'b>]) -> Symbol { + debug_assert!(!self.inner.contains_key(&(name, captures_layouts))); + + let new_ident = self.ident_ids.gen_unique(); + let new_symbol = Symbol::new(ModuleId::MULTIMORPHIC, new_ident); + + let captures_layouts = self.promote_layout_slice(captures_layouts); + + self.inner.insert((name, captures_layouts), new_symbol); + + new_symbol + } + + fn alloc_st(&self, v: T) -> &'static T { + unsafe { std::mem::transmute::<_, &'static Bump>(&self.arena) }.alloc(v) + } + + fn promote_layout<'b>(&self, layout: Layout<'b>) -> Layout<'static> { + match layout { + Layout::Builtin(builtin) => Layout::Builtin(self.promote_builtin(builtin)), + Layout::Struct { + field_order_hash, + field_layouts, + } => Layout::Struct { + field_order_hash, + field_layouts: self.promote_layout_slice(field_layouts), + }, + Layout::Boxed(layout) => Layout::Boxed(self.alloc_st(self.promote_layout(*layout))), + Layout::Union(union_layout) => Layout::Union(self.promote_union_layout(union_layout)), + Layout::LambdaSet(lambda_set) => Layout::LambdaSet(self.promote_lambda_set(lambda_set)), + Layout::RecursivePointer => Layout::RecursivePointer, + } + } + + fn promote_layout_slice<'b>(&self, layouts: &'b [Layout<'b>]) -> &'static [Layout<'static>] { + layouts + .iter() + .map(|layout| self.promote_layout(*layout)) + .collect_in::>(unsafe { std::mem::transmute(&self.arena) }) + .into_bump_slice() + } + + fn promote_layout_slice_slices<'b>( + &self, + layout_slices: &'b [&'b [Layout<'b>]], + ) -> &'static [&'static [Layout<'static>]] { + layout_slices + .iter() + .map(|slice| self.promote_layout_slice(slice)) + .collect_in::>(unsafe { std::mem::transmute(&self.arena) }) + .into_bump_slice() + } + + fn promote_builtin(&self, builtin: Builtin) -> Builtin<'static> { + match builtin { + Builtin::Int(w) => Builtin::Int(w), + Builtin::Float(w) => Builtin::Float(w), + Builtin::Bool => Builtin::Bool, + Builtin::Decimal => Builtin::Decimal, + Builtin::Str => Builtin::Str, + Builtin::Dict(k, v) => Builtin::Dict( + self.alloc_st(self.promote_layout(*k)), + self.alloc_st(self.promote_layout(*v)), + ), + Builtin::Set(k) => Builtin::Set(self.alloc_st(self.promote_layout(*k))), + Builtin::List(l) => Builtin::Set(self.alloc_st(self.promote_layout(*l))), + } + } + + fn promote_union_layout(&self, union_layout: UnionLayout) -> UnionLayout<'static> { + match union_layout { + UnionLayout::NonRecursive(slices) => { + UnionLayout::NonRecursive(self.promote_layout_slice_slices(slices)) + } + UnionLayout::Recursive(slices) => { + UnionLayout::Recursive(self.promote_layout_slice_slices(slices)) + } + UnionLayout::NonNullableUnwrapped(slice) => { + UnionLayout::NonNullableUnwrapped(self.promote_layout_slice(slice)) + } + UnionLayout::NullableWrapped { + nullable_id, + other_tags, + } => UnionLayout::NullableWrapped { + nullable_id, + other_tags: self.promote_layout_slice_slices(other_tags), + }, + UnionLayout::NullableUnwrapped { + nullable_id, + other_fields, + } => UnionLayout::NullableUnwrapped { + nullable_id, + other_fields: self.promote_layout_slice(other_fields), + }, + } + } + + fn promote_lambda_set(&self, lambda_set: LambdaSet) -> LambdaSet<'static> { + let LambdaSet { + set, + representation, + } = lambda_set; + let set = set + .iter() + .map(|(name, slice)| (*name, self.promote_layout_slice(slice))) + .collect_in::>(unsafe { std::mem::transmute(&self.arena) }) + .into_bump_slice(); + LambdaSet { + set, + representation: self.alloc_st(self.promote_layout(*representation)), + } + } +} + +#[derive(Default, Debug)] +pub struct MultimorphicNames(Arc>); + +impl Clone for MultimorphicNames { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +impl MultimorphicNames { + fn get<'b>(&self, name: Symbol, captures_layouts: &'b [Layout<'b>]) -> Option { + self.0.lock().unwrap().get(name, captures_layouts) + } + + fn insert<'b>(&mut self, name: Symbol, captures_layouts: &'b [Layout<'b>]) -> Symbol { + self.0.lock().unwrap().insert(name, captures_layouts) + } +} + #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] enum LambdaNameInner { /// Standard lambda name assigned during canonicalize/constrain @@ -806,8 +956,7 @@ impl<'a> LambdaSet<'a> { } /// Does the lambda set contain the given symbol? - /// NOTE: for multimorphic variants, this checks the alias name; the source name will always be - /// the name of the first multimorphic variant. + /// NOTE: for multimorphic variants, this checks the alias name. pub fn contains(&self, symbol: Symbol) -> bool { self.set.iter().any(|(s, _)| match s.0 { LambdaNameInner::Name(name) => name == symbol, @@ -841,12 +990,23 @@ impl<'a> LambdaSet<'a> { self.layout_for_member(comparator) } + fn contains_source(&self, symbol: Symbol) -> bool { + self.set.iter().any(|(s, _)| match s.0 { + LambdaNameInner::Name(name) => name == symbol, + LambdaNameInner::Multimorphic { source, .. } => source == symbol, + }) + } + + /// Finds an alias name for a possible-multimorphic lambda variant in the lambda set. pub fn find_lambda_name( &self, function_symbol: Symbol, captures_layouts: &[Layout], ) -> LambdaName { - debug_assert!(self.contains(function_symbol), "function symbol not in set"); + debug_assert!( + self.contains_source(function_symbol), + "function symbol not in set" + ); let comparator = |other_name: LambdaName, other_captures_layouts: &[Layout]| { let other_name = match other_name.0 { @@ -868,46 +1028,6 @@ impl<'a> LambdaSet<'a> { *name } - // Layout for a single member of the lambda set, when you are constructing a proc, and already - // know the multimorphic name (if any). - // pub fn layout_for_member_constructing_proc( - // &self, - // lambda_name: LambdaName, - // ) -> ClosureRepresentation<'a> { - // debug_assert!( - // self.set.iter().any(|(s, _)| *s == lambda_name), - // "lambda not in set" - // ); - - // let comparator = - // |other_name: LambdaName, _other_captures_layouts: &[Layout]| other_name == lambda_name; - - // self.layout_for_member(comparator) - // } - - // Layout for a single member of the lambda set, when you are constructing a closure - // representation, and maybe need to pick out a multimorphic variant. - // pub fn layout_for_member_constructing_closure_data( - // &self, - // function_symbol: Symbol, - // captures_layouts: &[Layout], - // ) -> ClosureRepresentation<'a> { - // debug_assert!(self.contains(function_symbol), "function symbol not in set"); - - // let comparator = |other_name: LambdaName, other_captures_layouts: &[Layout]| { - // let other_name = match other_name.0 { - // LambdaNameInner::Name(name) => name, - // // Take the source, since we'll want to pick out the multimorphic name if it - // // matches - // LambdaNameInner::Multimorphic { source, .. } => source, - // }; - // other_name == function_symbol - // && captures_layouts.iter().eq(other_captures_layouts.iter()) - // }; - - // self.layout_for_member(comparator) - // } - fn layout_for_member(&self, comparator: F) -> ClosureRepresentation<'a> where F: Fn(LambdaName, &[Layout]) -> bool, @@ -1001,26 +1121,28 @@ impl<'a> LambdaSet<'a> { } } - pub fn from_var( + pub fn from_var( arena: &'a Bump, subs: &Subs, closure_var: Variable, target_info: TargetInfo, - fresh_multimorphic_symbol: &mut F, - ) -> Result - where - F: FreshMultimorphicSymbol, - { - match roc_types::pretty_print::resolve_lambda_set(subs, closure_var) { + multimorphic_names: &mut MultimorphicNames, + ) -> Result { + match resolve_lambda_set(subs, closure_var) { ResolvedLambdaSet::Set(mut lambdas) => { // sort the tags; make sure ordering stays intact! lambdas.sort_by_key(|(sym, _)| *sym); let mut set: Vec<(LambdaName, &[Layout])> = Vec::with_capacity_in(lambdas.len(), arena); + let mut set_for_making_repr: std::vec::Vec<(Symbol, std::vec::Vec)> = + std::vec::Vec::with_capacity(lambdas.len()); let mut last_function_symbol = None; - for (function_symbol, variables) in lambdas.iter() { + let mut lambdas_it = lambdas.iter().peekable(); + + let mut has_multimorphic = false; + while let Some((function_symbol, variables)) = lambdas_it.next() { let mut arguments = Vec::with_capacity_in(variables.len(), arena); let mut env = Env { @@ -1028,39 +1150,59 @@ impl<'a> LambdaSet<'a> { subs, seen: Vec::new_in(arena), target_info, - fresh_multimorphic_symbol, + multimorphic_names, }; for var in variables { arguments.push(Layout::from_var(&mut env, *var)?); } - let lambda_name = match last_function_symbol { - None => LambdaNameInner::Name(*function_symbol), - Some(last_function_symbol) => { - if function_symbol != last_function_symbol { - LambdaNameInner::Name(*function_symbol) - } else { - LambdaNameInner::Multimorphic { - source: *function_symbol, - alias: (*fresh_multimorphic_symbol)(), - } - } + let arguments = arguments.into_bump_slice(); + + let is_multimorphic = match (last_function_symbol, lambdas_it.peek()) { + (None, None) => false, + (Some(sym), None) | (None, Some((sym, _))) => function_symbol == sym, + (Some(sym1), Some((sym2, _))) => { + function_symbol == sym1 || function_symbol == sym2 } }; + + let lambda_name = if is_multimorphic { + let alias = match multimorphic_names.get(*function_symbol, arguments) { + Some(alias) => alias, + None => multimorphic_names.insert(*function_symbol, arguments), + }; + + has_multimorphic = true; + + LambdaNameInner::Multimorphic { + source: *function_symbol, + alias, + } + } else { + LambdaNameInner::Name(*function_symbol) + }; let lambda_name = LambdaName(lambda_name); - set.push((lambda_name, arguments.into_bump_slice())); + set.push((lambda_name, arguments)); + set_for_making_repr.push((lambda_name.call_name(), variables.to_vec())); last_function_symbol = Some(function_symbol); } + if has_multimorphic { + // Must re-sort the set in case we added multimorphic lambdas since they may under + // another name + set.sort_by_key(|(name, _)| name.call_name()); + set_for_making_repr.sort_by_key(|(name, _)| *name); + } + let representation = arena.alloc(Self::make_representation( arena, subs, - lambdas, + set_for_making_repr, target_info, - fresh_multimorphic_symbol, + multimorphic_names, )); Ok(LambdaSet { @@ -1079,22 +1221,16 @@ impl<'a> LambdaSet<'a> { } } - fn make_representation( + fn make_representation( arena: &'a Bump, subs: &Subs, tags: std::vec::Vec<(Symbol, std::vec::Vec)>, target_info: TargetInfo, - fresh_multimorphic_symbol: &mut F, + multimorphic_names: &mut MultimorphicNames, ) -> Layout<'a> { // otherwise, this is a closure with a payload - let variant = union_sorted_tags_help( - arena, - tags, - None, - subs, - target_info, - fresh_multimorphic_symbol, - ); + let variant = + union_sorted_tags_help(arena, tags, None, subs, target_info, multimorphic_names); use UnionVariant::*; match variant { @@ -1149,6 +1285,40 @@ impl<'a> LambdaSet<'a> { } } +enum ResolvedLambdaSet { + Set(std::vec::Vec<(Symbol, std::vec::Vec)>), + /// TODO: figure out if this can happen in a correct program, or is the result of a bug in our + /// compiler. See https://github.com/rtfeldman/roc/issues/3163. + Unbound, +} + +fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet { + let mut set = vec![]; + loop { + match subs.get_content_without_compacting(var) { + Content::LambdaSet(subs::LambdaSet { + solved, + recursion_var: _, + unspecialized, + }) => { + debug_assert!( + unspecialized.is_empty(), + "unspecialized lambda sets left over during resolution: {:?}", + roc_types::subs::SubsFmtContent(subs.get_content_without_compacting(var), subs), + ); + roc_types::pretty_print::push_union(subs, solved, &mut set); + return ResolvedLambdaSet::Set(set); + } + Content::RecursionVar { structure, .. } => { + var = *structure; + } + Content::FlexVar(_) => return ResolvedLambdaSet::Unbound, + + c => internal_error!("called with a non-lambda set {:?}", c), + } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Builtin<'a> { Int(IntWidth), @@ -1161,21 +1331,15 @@ pub enum Builtin<'a> { List(&'a Layout<'a>), } -pub struct Env<'a, 'b, F> -where - F: FreshMultimorphicSymbol, -{ +pub struct Env<'a, 'b> { target_info: TargetInfo, arena: &'a Bump, seen: Vec<'a, Variable>, subs: &'b Subs, - fresh_multimorphic_symbol: &'b mut F, + multimorphic_names: &'b mut MultimorphicNames, } -impl<'a, 'b, F> Env<'a, 'b, F> -where - F: FreshMultimorphicSymbol, -{ +impl<'a, 'b> Env<'a, 'b> { fn is_seen(&self, var: Variable) -> bool { let var = self.subs.get_root_key_without_compacting(var); @@ -1225,8 +1389,8 @@ impl<'a> Layout<'a> { field_order_hash: FieldOrderHash::ZERO_FIELD_HASH, }; - fn new_help<'b, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, 'b, F>, + fn new_help<'b>( + env: &mut Env<'a, 'b>, var: Variable, content: Content, ) -> Result { @@ -1284,10 +1448,7 @@ impl<'a> Layout<'a> { /// Returns Err(()) if given an error, or Ok(Layout) if given a non-erroneous Structure. /// Panics if given a FlexVar or RigidVar, since those should have been /// monomorphized away already! - fn from_var( - env: &mut Env<'a, '_, F>, - var: Variable, - ) -> Result { + fn from_var(env: &mut Env<'a, '_>, var: Variable) -> Result { if env.is_seen(var) { Ok(Layout::RecursivePointer) } else { @@ -1611,16 +1772,13 @@ impl<'a> LayoutCache<'a> { } } - pub fn from_var( + pub fn from_var( &mut self, arena: &'a Bump, var: Variable, subs: &Subs, - fresh_multimorphic_symbol: &mut F, - ) -> Result, LayoutProblem> - where - F: FreshMultimorphicSymbol, - { + multimorphic_names: &mut MultimorphicNames, + ) -> Result, LayoutProblem> { // Store things according to the root Variable, to avoid duplicate work. let var = subs.get_root_key_without_compacting(var); @@ -1629,18 +1787,18 @@ impl<'a> LayoutCache<'a> { subs, seen: Vec::new_in(arena), target_info: self.target_info, - fresh_multimorphic_symbol, + multimorphic_names, }; Layout::from_var(&mut env, var) } - pub fn raw_from_var( + pub fn raw_from_var( &mut self, arena: &'a Bump, var: Variable, subs: &Subs, - fresh_multimorphic_symbol: &mut F, + multimorphic_names: &mut MultimorphicNames, ) -> Result, LayoutProblem> { // Store things according to the root Variable, to avoid duplicate work. let var = subs.get_root_key_without_compacting(var); @@ -1650,7 +1808,7 @@ impl<'a> LayoutCache<'a> { subs, seen: Vec::new_in(arena), target_info: self.target_info, - fresh_multimorphic_symbol, + multimorphic_names, }; RawFunctionLayout::from_var(&mut env, var) } @@ -1905,8 +2063,8 @@ impl<'a> Builtin<'a> { } } -fn layout_from_lambda_set<'a, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn layout_from_lambda_set<'a>( + env: &mut Env<'a, '_>, lset: subs::LambdaSet, ) -> Result, LayoutProblem> { // Lambda set is just a tag union from the layout's perspective. @@ -1935,8 +2093,8 @@ fn layout_from_lambda_set<'a, F: FreshMultimorphicSymbol>( } } -fn layout_from_flat_type<'a, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn layout_from_flat_type<'a>( + env: &mut Env<'a, '_>, flat_type: FlatType, ) -> Result, LayoutProblem> { use roc_types::subs::FlatType::*; @@ -2049,7 +2207,7 @@ fn layout_from_flat_type<'a, F: FreshMultimorphicSymbol>( env.subs, closure_var, env.target_info, - env.fresh_multimorphic_symbol, + env.multimorphic_names, )?; Ok(Layout::LambdaSet(lambda_set)) @@ -2129,19 +2287,19 @@ fn layout_from_flat_type<'a, F: FreshMultimorphicSymbol>( pub type SortedField<'a> = (Lowercase, Variable, Result, Layout<'a>>); -pub fn sort_record_fields<'a, F: FreshMultimorphicSymbol>( +pub fn sort_record_fields<'a>( arena: &'a Bump, var: Variable, subs: &Subs, target_info: TargetInfo, - fresh_multimorphic_symbol: &mut F, + multimorphic_names: &mut MultimorphicNames, ) -> Result>, LayoutProblem> { let mut env = Env { arena, subs, seen: Vec::new_in(arena), target_info, - fresh_multimorphic_symbol, + multimorphic_names, }; let (it, _) = match gather_fields_unsorted_iter(subs, RecordFields::empty(), var) { @@ -2156,8 +2314,8 @@ pub fn sort_record_fields<'a, F: FreshMultimorphicSymbol>( sort_record_fields_help(&mut env, it) } -fn sort_record_fields_help<'a, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn sort_record_fields_help<'a>( + env: &mut Env<'a, '_>, fields_map: impl Iterator)>, ) -> Result>, LayoutProblem> { let target_info = env.target_info; @@ -2343,12 +2501,12 @@ impl<'a> WrappedVariant<'a> { } } -pub fn union_sorted_tags<'a, F: FreshMultimorphicSymbol>( +pub fn union_sorted_tags<'a>( arena: &'a Bump, var: Variable, subs: &Subs, target_info: TargetInfo, - fresh_multimorphic_symbol: &mut F, + multimorphic_names: &mut MultimorphicNames, ) -> Result, LayoutProblem> { let var = if let Content::RecursionVar { structure, .. } = subs.get_content_without_compacting(var) { @@ -2369,7 +2527,7 @@ pub fn union_sorted_tags<'a, F: FreshMultimorphicSymbol>( | Err((_, Content::FlexVar(_) | Content::RigidVar(_))) | Err((_, Content::RecursionVar { .. })) => { let opt_rec_var = get_recursion_var(subs, var); - union_sorted_tags_help(arena, tags_vec, opt_rec_var, subs, target_info, fresh_multimorphic_symbol) + union_sorted_tags_help(arena, tags_vec, opt_rec_var, subs, target_info, multimorphic_names) } Err((_, Content::Error)) => return Err(LayoutProblem::Erroneous), Err(other) => panic!("invalid content in tag union variable: {:?}", other), @@ -2398,8 +2556,8 @@ fn is_recursive_tag_union(layout: &Layout) -> bool { ) } -fn union_sorted_tags_help_new<'a, L, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn union_sorted_tags_help_new<'a, L>( + env: &mut Env<'a, '_>, tags_list: &[(&'_ L, &[Variable])], opt_rec_var: Option, ) -> UnionVariant<'a> @@ -2589,13 +2747,13 @@ where } } -pub fn union_sorted_tags_help<'a, L, F: FreshMultimorphicSymbol>( +pub fn union_sorted_tags_help<'a, L>( arena: &'a Bump, mut tags_vec: std::vec::Vec<(L, std::vec::Vec)>, opt_rec_var: Option, subs: &Subs, target_info: TargetInfo, - fresh_multimorphic_symbol: &mut F, + multimorphic_names: &mut MultimorphicNames, ) -> UnionVariant<'a> where L: Into + Ord + Clone, @@ -2608,7 +2766,7 @@ where subs, seen: Vec::new_in(arena), target_info, - fresh_multimorphic_symbol, + multimorphic_names, }; match tags_vec.len() { @@ -2797,8 +2955,8 @@ where } } -fn layout_from_newtype<'a, L: Label, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn layout_from_newtype<'a, L: Label>( + env: &mut Env<'a, '_>, tags: &UnsortedUnionLabels, ) -> Layout<'a> { debug_assert!(tags.is_newtype_wrapper(env.subs)); @@ -2821,10 +2979,7 @@ fn layout_from_newtype<'a, L: Label, F: FreshMultimorphicSymbol>( } } -fn layout_from_union<'a, L, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, - tags: &UnsortedUnionLabels, -) -> Layout<'a> +fn layout_from_union<'a, L>(env: &mut Env<'a, '_>, tags: &UnsortedUnionLabels) -> Layout<'a> where L: Label + Ord + Into, { @@ -2900,8 +3055,8 @@ where } } -fn layout_from_recursive_union<'a, L, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn layout_from_recursive_union<'a, L>( + env: &mut Env<'a, '_>, rec_var: Variable, tags: &UnsortedUnionLabels, ) -> Result, LayoutProblem> @@ -3082,8 +3237,8 @@ fn layout_from_num_content<'a>( } } -fn dict_layout_from_key_value<'a, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +fn dict_layout_from_key_value<'a>( + env: &mut Env<'a, '_>, key_var: Variable, value_var: Variable, ) -> Result, LayoutProblem> { @@ -3118,8 +3273,8 @@ fn dict_layout_from_key_value<'a, F: FreshMultimorphicSymbol>( pub trait FreshMultimorphicSymbol: FnMut() -> Symbol {} impl FreshMultimorphicSymbol for T where T: FnMut() -> Symbol {} -pub fn list_layout_from_elem<'a, F: FreshMultimorphicSymbol>( - env: &mut Env<'a, '_, F>, +pub fn list_layout_from_elem<'a>( + env: &mut Env<'a, '_>, element_var: Variable, ) -> Result, LayoutProblem> { let is_variable = |content| matches!(content, &Content::FlexVar(_) | &Content::RigidVar(_)); diff --git a/crates/compiler/test_mono/generated/multimorphic_lambda_set_capture.txt b/crates/compiler/test_mono/generated/multimorphic_lambda_set_capture.txt new file mode 100644 index 0000000000..ba2cb0e203 --- /dev/null +++ b/crates/compiler/test_mono/generated/multimorphic_lambda_set_capture.txt @@ -0,0 +1,49 @@ +procedure #Multimorphic.0 (Test.23, #Attr.12): + let Test.4 : Str = UnionAtIndex (Id 0) (Index 0) #Attr.12; + inc Test.4; + dec #Attr.12; + let Test.25 : Str = ""; + ret Test.25; + +procedure #Multimorphic.1 (Test.17, #Attr.12): + let Test.4 : {} = UnionAtIndex (Id 1) (Index 0) #Attr.12; + dec #Attr.12; + let Test.19 : Str = ""; + ret Test.19; + +procedure Test.1 (Test.4): + let Test.16 : [C Str, C {}] = ClosureTag(#Multimorphic.1) Test.4; + ret Test.16; + +procedure Test.1 (Test.4): + let Test.22 : [C Str, C {}] = ClosureTag(#Multimorphic.0) Test.4; + ret Test.22; + +procedure Test.0 (): + let Test.2 : Int1 = true; + joinpoint Test.13 Test.3: + let Test.8 : {} = Struct {}; + let Test.9 : U8 = GetTagId Test.3; + joinpoint Test.10 Test.7: + ret Test.7; + in + switch Test.9: + case 0: + let Test.11 : Str = CallByName #Multimorphic.0 Test.8 Test.3; + jump Test.10 Test.11; + + default: + let Test.12 : Str = CallByName #Multimorphic.1 Test.8 Test.3; + jump Test.10 Test.12; + + in + let Test.26 : Int1 = true; + let Test.27 : Int1 = lowlevel Eq Test.26 Test.2; + if Test.27 then + let Test.15 : {} = Struct {}; + let Test.14 : [C Str, C {}] = CallByName Test.1 Test.15; + jump Test.13 Test.14; + else + let Test.21 : Str = ""; + let Test.20 : [C Str, C {}] = CallByName Test.1 Test.21; + jump Test.13 Test.20; diff --git a/crates/compiler/test_mono/src/tests.rs b/crates/compiler/test_mono/src/tests.rs index 10be2c684e..62a0ba261f 100644 --- a/crates/compiler/test_mono/src/tests.rs +++ b/crates/compiler/test_mono/src/tests.rs @@ -1527,24 +1527,22 @@ fn tail_call_with_different_layout() { } #[mono_test] -#[ignore] -fn lambda_sets_collide_with_captured_var() { +fn multimorphic_lambda_set_capture() { indoc!( r#" capture : a -> ({} -> Str) capture = \val -> - thunk = - \{} -> - when val is - _ -> "" - thunk + \{} -> + when val is + _ -> "" x : [True, False] + x = True fun = when x is - True -> capture 1u8 - False -> capture 1u64 + True -> capture {} + False -> capture "" fun {} "# diff --git a/crates/compiler/types/src/pretty_print.rs b/crates/compiler/types/src/pretty_print.rs index 0203793009..0837e5f20e 100644 --- a/crates/compiler/types/src/pretty_print.rs +++ b/crates/compiler/types/src/pretty_print.rs @@ -4,7 +4,6 @@ use crate::subs::{ }; use crate::types::{name_type_var, RecordField, Uls}; use roc_collections::all::MutMap; -use roc_error_macros::internal_error; use roc_module::ident::{Lowercase, TagName}; use roc_module::symbol::{Interns, ModuleId, Symbol}; @@ -1152,7 +1151,7 @@ fn write_flat_type<'a>( } } -fn push_union<'a, L: Label>( +pub fn push_union<'a, L: Label>( subs: &'a Subs, tags: &UnionLabels, fields: &mut Vec<(L, Vec)>, @@ -1196,40 +1195,6 @@ pub fn chase_ext_tag_union<'a>( } } -pub enum ResolvedLambdaSet { - Set(Vec<(Symbol, Vec)>), - /// TODO: figure out if this can happen in a correct program, or is the result of a bug in our - /// compiler. See https://github.com/rtfeldman/roc/issues/3163. - Unbound, -} - -pub fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet { - let mut set = vec![]; - loop { - match subs.get_content_without_compacting(var) { - Content::LambdaSet(subs::LambdaSet { - solved, - recursion_var: _, - unspecialized, - }) => { - debug_assert!( - unspecialized.is_empty(), - "unspecialized lambda sets left over during resolution: {:?}", - crate::subs::SubsFmtContent(subs.get_content_without_compacting(var), subs), - ); - push_union(subs, solved, &mut set); - return ResolvedLambdaSet::Set(set); - } - Content::RecursionVar { structure, .. } => { - var = *structure; - } - Content::FlexVar(_) => return ResolvedLambdaSet::Unbound, - - c => internal_error!("called with a non-lambda set {:?}", c), - } - } -} - fn write_apply<'a>( env: &Env, ctx: &mut Context<'a>,