diff --git a/Cargo.lock b/Cargo.lock index 6f505e2dc7..35806e3e11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2250,6 +2250,7 @@ dependencies = [ "roc_solve", "roc_types", "roc_unify", + "ven_pretty", ] [[package]] diff --git a/cli/src/repl.rs b/cli/src/repl.rs index d95ebe26dc..e48c0cffdf 100644 --- a/cli/src/repl.rs +++ b/cli/src/repl.rs @@ -1,7 +1,6 @@ use bumpalo::Bump; use inkwell::context::Context; use inkwell::execution_engine::JitFunction; -use inkwell::passes::PassManager; use inkwell::types::BasicType; use inkwell::OptimizationLevel; use roc_builtins::unique::uniq_stdlib; @@ -19,7 +18,7 @@ use roc_gen::llvm::build::{build_proc, build_proc_header, OptLevel}; use roc_gen::llvm::convert::basic_type_from_layout; use roc_module::ident::Ident; use roc_module::symbol::{IdentIds, Interns, ModuleId, ModuleIds, Symbol}; -use roc_mono::expr::Procs; +use roc_mono::ir::Procs; use roc_mono::layout::{Layout, LayoutCache}; use roc_parse::ast::{self, Attempting}; use roc_parse::blankspace::space0_before; @@ -209,13 +208,9 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S } let context = Context::create(); - let module = roc_gen::llvm::build::module_from_builtins(&context, "app"); + let module = arena.alloc(roc_gen::llvm::build::module_from_builtins(&context, "app")); let builder = context.create_builder(); - let fpm = PassManager::create(&module); - - roc_gen::llvm::build::add_passes(&fpm, opt_level); - - fpm.initialize(); + let (mpm, fpm) = roc_gen::llvm::build::construct_optimization_passes(module, opt_level); // pretty-print the expr type string for later. name_all_type_vars(var, &mut subs); @@ -243,8 +238,9 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S builder: &builder, context: &context, interns, - module: arena.alloc(module), + module, ptr_bytes, + leak: false, }; let mut procs = Procs::default(); let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); @@ -252,7 +248,7 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S // Populate Procs and get the low-level Expr from the canonical Expr let mut mono_problems = Vec::new(); - let mut mono_env = roc_mono::expr::Env { + let mut mono_env = roc_mono::ir::Env { arena: &arena, subs: &mut subs, problems: &mut mono_problems, @@ -260,7 +256,9 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S ident_ids: &mut ident_ids, }; - let main_body = roc_mono::expr::Expr::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = + roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); let mut headers = { let num_headers = match &procs.pending_specializations { Some(map) => map.len(), @@ -270,7 +268,7 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S Vec::with_capacity(num_headers) }; let mut layout_cache = LayoutCache::default(); - let mut procs = roc_mono::expr::specialize_all(&mut mono_env, procs, &mut layout_cache); + let mut procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut layout_cache); assert_eq!( procs.runtime_errors, @@ -285,8 +283,10 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S // Add all the Proc headers to the module. // We have to do this in a separate pass first, // because their bodies may reference each other. + + let mut gen_scope = roc_gen::llvm::build::Scope::default(); for ((symbol, layout), proc) in procs.specialized.drain() { - use roc_mono::expr::InProgressProc::*; + use roc_mono::ir::InProgressProc::*; match proc { InProgress => { @@ -331,10 +331,10 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S builder.position_at_end(basic_block); - let ret = roc_gen::llvm::build::build_expr( + let ret = roc_gen::llvm::build::build_exp_stmt( &env, &mut layout_ids, - &ImMap::default(), + &mut gen_scope, main_fn, &main_body, ); @@ -350,6 +350,8 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S panic!("Main function {} failed LLVM verification. Uncomment things near this error message for more details.", main_fn_name); } + mpm.run_on(module); + // Verify the module if let Err(errors) = env.module.verify() { panic!("Errors defining module: {:?}", errors); @@ -367,7 +369,9 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S .ok_or(format!("Unable to JIT compile `{}`", main_fn_name)) .expect("errored"); - Ok((format!("{}", main.call()), expr_type_str)) + let result = main.call(); + let output = format!("{}", result); + Ok((output, expr_type_str)) } } diff --git a/cli/tests/cli_run.rs b/cli/tests/cli_run.rs index d9fd90c41e..9f452412c7 100644 --- a/cli/tests/cli_run.rs +++ b/cli/tests/cli_run.rs @@ -101,4 +101,17 @@ mod cli_run { assert!(&out.stdout.ends_with("Hello, World!\n")); assert!(out.status.success()); } + + #[test] + fn run_quicksort() { + let out = run_roc(&[ + "run", + example_file("quicksort", "Quicksort.roc").to_str().unwrap(), + "--optimize", + ]); + + assert_eq!(&out.stderr, ""); + assert!(&out.stdout.ends_with("[4, 7, 19, 21]\n")); + assert!(out.status.success()); + } } diff --git a/compiler/build/src/program.rs b/compiler/build/src/program.rs index 2071b9a24b..135540404d 100644 --- a/compiler/build/src/program.rs +++ b/compiler/build/src/program.rs @@ -1,10 +1,8 @@ use bumpalo::Bump; use inkwell::context::Context; use inkwell::module::Linkage; -use inkwell::passes::PassManager; use inkwell::types::BasicType; use inkwell::OptimizationLevel; -use roc_collections::all::ImMap; use roc_gen::layout_id::LayoutIds; use roc_gen::llvm::build::{ build_proc, build_proc_header, get_call_conventions, module_from_builtins, OptLevel, @@ -12,7 +10,7 @@ use roc_gen::llvm::build::{ use roc_gen::llvm::convert::basic_type_from_layout; use roc_load::file::LoadedModule; use roc_module::symbol::Symbol; -use roc_mono::expr::{Env, Expr, PartialProc, Procs}; +use roc_mono::ir::{Env, PartialProc, Procs, Stmt}; use roc_mono::layout::{Layout, LayoutCache}; use inkwell::targets::{ @@ -128,13 +126,9 @@ pub fn gen( // Generate the binary let context = Context::create(); - let module = module_from_builtins(&context, "app"); + let module = arena.alloc(module_from_builtins(&context, "app")); let builder = context.create_builder(); - let fpm = PassManager::create(&module); - - roc_gen::llvm::build::add_passes(&fpm, opt_level); - - fpm.initialize(); + let (mpm, fpm) = roc_gen::llvm::build::construct_optimization_passes(module, opt_level); // Compute main_fn_type before moving subs to Env let layout = Layout::new(&arena, content, &subs).unwrap_or_else(|err| { @@ -155,8 +149,9 @@ pub fn gen( builder: &builder, context: &context, interns: loaded.interns, - module: arena.alloc(module), + module, ptr_bytes, + leak: false, }; let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); let mut layout_ids = LayoutIds::default(); @@ -202,7 +197,9 @@ pub fn gen( let proc = PartialProc { annotation: def.expr_var, // This is a 0-arity thunk, so it has no arguments. - pattern_symbols: &[], + pattern_symbols: bumpalo::collections::Vec::new_in( + mono_env.arena, + ), body, }; @@ -226,7 +223,9 @@ pub fn gen( } // Populate Procs further and get the low-level Expr from the canonical Expr - let main_body = Expr::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = Stmt::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = + roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); let mut headers = { let num_headers = match &procs.pending_specializations { Some(map) => map.len(), @@ -235,7 +234,7 @@ pub fn gen( Vec::with_capacity(num_headers) }; - let mut procs = roc_mono::expr::specialize_all(&mut mono_env, procs, &mut layout_cache); + let mut procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut layout_cache); assert_eq!( procs.runtime_errors, @@ -251,7 +250,7 @@ pub fn gen( // We have to do this in a separate pass first, // because their bodies may reference each other. for ((symbol, layout), proc) in procs.specialized.drain() { - use roc_mono::expr::InProgressProc::*; + use roc_mono::ir::InProgressProc::*; match proc { InProgress => { @@ -296,10 +295,10 @@ pub fn gen( builder.position_at_end(basic_block); - let ret = roc_gen::llvm::build::build_expr( + let ret = roc_gen::llvm::build::build_exp_stmt( &env, &mut layout_ids, - &ImMap::default(), + &mut roc_gen::llvm::build::Scope::default(), main_fn, &main_body, ); @@ -315,6 +314,8 @@ pub fn gen( panic!("Function {} failed LLVM verification.", main_fn_name); } + mpm.run_on(module); + // Verify the module if let Err(errors) = env.module.verify() { panic!("😱 LLVM errors when defining module: {:?}", errors); diff --git a/compiler/can/src/expr.rs b/compiler/can/src/expr.rs index ea14a43040..7582abdab5 100644 --- a/compiler/can/src/expr.rs +++ b/compiler/can/src/expr.rs @@ -58,6 +58,7 @@ pub enum Expr { Str(Box), BlockStr(Box), List { + list_var: Variable, // required for uniqueness of the list elem_var: Variable, loc_elems: Vec>, }, @@ -256,6 +257,7 @@ pub fn canonicalize_expr<'a>( if loc_elems.is_empty() { ( List { + list_var: var_store.fresh(), elem_var: var_store.fresh(), loc_elems: Vec::new(), }, @@ -283,6 +285,7 @@ pub fn canonicalize_expr<'a>( ( List { + list_var: var_store.fresh(), elem_var: var_store.fresh(), loc_elems: can_elems, }, @@ -1052,6 +1055,7 @@ pub fn inline_calls(var_store: &mut VarStore, scope: &mut Scope, expr: Expr) -> | other @ RunLowLevel { .. } => other, List { + list_var, elem_var, loc_elems, } => { @@ -1067,6 +1071,7 @@ pub fn inline_calls(var_store: &mut VarStore, scope: &mut Scope, expr: Expr) -> } List { + list_var, elem_var, loc_elems: new_elems, } diff --git a/compiler/can/src/pattern.rs b/compiler/can/src/pattern.rs index 70dbae9c67..eef2b0532e 100644 --- a/compiler/can/src/pattern.rs +++ b/compiler/can/src/pattern.rs @@ -77,7 +77,12 @@ pub fn symbols_from_pattern_help(pattern: &Pattern, symbols: &mut Vec) { } RecordDestructure { destructs, .. } => { for destruct in destructs { - symbols.push(destruct.value.symbol); + // when a record field has a pattern guard, only symbols in the guard are introduced + if let DestructType::Guard(_, subpattern) = &destruct.value.typ { + symbols_from_pattern_help(&subpattern.value, symbols); + } else { + symbols.push(destruct.value.symbol); + } } } diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index 2737a1e956..f7746c676f 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -203,7 +203,7 @@ pub fn constrain_expr( List { elem_var, loc_elems, - .. + list_var: _unused, } => { if loc_elems.is_empty() { exists( diff --git a/compiler/constrain/src/uniq.rs b/compiler/constrain/src/uniq.rs index d786e585b3..3728b7eb35 100644 --- a/compiler/constrain/src/uniq.rs +++ b/compiler/constrain/src/uniq.rs @@ -639,6 +639,7 @@ pub fn constrain_expr( exists(vars, And(arg_cons)) } List { + list_var, elem_var, loc_elems, } => { @@ -676,9 +677,12 @@ pub fn constrain_expr( } let inferred = list_type(Bool::variable(uniq_var), entry_type); - constraints.push(Eq(inferred, expected, Category::List, region)); + constraints.push(Eq(inferred, expected.clone(), Category::List, region)); - exists(vec![*elem_var, uniq_var], And(constraints)) + let stored = Type::Variable(*list_var); + constraints.push(Eq(stored, expected, Category::Storage, region)); + + exists(vec![*elem_var, *list_var, uniq_var], And(constraints)) } } Var(symbol) => { diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 13eff5351e..6a9b2094ee 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -5,6 +5,7 @@ use crate::llvm::convert::{ }; use bumpalo::collections::Vec; use bumpalo::Bump; +use inkwell::basic_block::BasicBlock; use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::memory_buffer::MemoryBuffer; @@ -18,8 +19,8 @@ use inkwell::{IntPredicate, OptimizationLevel}; use roc_collections::all::ImMap; use roc_module::low_level::LowLevel; use roc_module::symbol::{Interns, Symbol}; -use roc_mono::expr::{Expr, Proc}; -use roc_mono::layout::{Builtin, Layout}; +use roc_mono::ir::JoinPointId; +use roc_mono::layout::{Builtin, Layout, MemoryMode}; use target_lexicon::CallingConvention; /// This is for Inkwell's FunctionValue::verify - we want to know the verification @@ -30,12 +31,44 @@ const PRINT_FN_VERIFICATION_OUTPUT: bool = true; #[cfg(not(debug_assertions))] const PRINT_FN_VERIFICATION_OUTPUT: bool = false; +#[derive(Debug, Clone, Copy)] pub enum OptLevel { Normal, Optimize, } -pub type Scope<'a, 'ctx> = ImMap, PointerValue<'ctx>)>; +// pub type Scope<'a, 'ctx> = ImMap, PointerValue<'ctx>)>; +#[derive(Default, Debug, Clone, PartialEq)] +pub struct Scope<'a, 'ctx> { + symbols: ImMap, PointerValue<'ctx>)>, + join_points: ImMap, &'a [PointerValue<'ctx>])>, +} + +impl<'a, 'ctx> Scope<'a, 'ctx> { + fn get(&self, symbol: &Symbol) -> Option<&(Layout<'a>, PointerValue<'ctx>)> { + self.symbols.get(symbol) + } + fn insert(&mut self, symbol: Symbol, value: (Layout<'a>, PointerValue<'ctx>)) { + self.symbols.insert(symbol, value); + } + fn remove(&mut self, symbol: &Symbol) { + self.symbols.remove(symbol); + } + /* + fn get_join_point(&self, symbol: &JoinPointId) -> Option<&PhiValue<'ctx>> { + self.join_points.get(symbol) + } + fn remove_join_point(&mut self, symbol: &JoinPointId) { + self.join_points.remove(symbol); + } + fn get_mut_join_point(&mut self, symbol: &JoinPointId) -> Option<&mut PhiValue<'ctx>> { + self.join_points.get_mut(symbol) + } + fn insert_join_point(&mut self, symbol: JoinPointId, value: PhiValue<'ctx>) { + self.join_points.insert(symbol, value); + } + */ +} pub struct Env<'a, 'ctx, 'env> { pub arena: &'a Bump, @@ -44,6 +77,7 @@ pub struct Env<'a, 'ctx, 'env> { pub module: &'ctx Module<'ctx>, pub interns: Interns, pub ptr_bytes: u32, + pub leak: bool, } impl<'a, 'ctx, 'env> Env<'a, 'ctx, 'env> { @@ -121,14 +155,18 @@ fn add_intrinsic<'ctx>( fn_val } -pub fn add_passes(fpm: &PassManager>, opt_level: OptLevel) { +pub fn construct_optimization_passes<'a>( + module: &'a Module, + opt_level: OptLevel, +) -> (PassManager>, PassManager>) { + let mpm = PassManager::create(()); + let fpm = PassManager::create(module); + // tail-call elimination is always on fpm.add_instruction_combining_pass(); fpm.add_tail_call_elimination_pass(); let pmb = PassManagerBuilder::create(); - - // Enable more optimizations when running cargo test --release match opt_level { OptLevel::Normal => { pmb.set_optimization_level(OptimizationLevel::None); @@ -138,215 +176,48 @@ pub fn add_passes(fpm: &PassManager>, opt_level: OptLevel) { // // See https://llvm.org/doxygen/CodeGen_8h_source.html pmb.set_optimization_level(OptimizationLevel::Aggressive); + pmb.set_inliner_with_threshold(4); - // TODO figure out how enabling these individually differs from - // the broad "aggressive optimizations" setting. + // TODO figure out which of these actually help - // fpm.add_reassociate_pass(); - // fpm.add_basic_alias_analysis_pass(); - // fpm.add_promote_memory_to_register_pass(); - // fpm.add_cfg_simplification_pass(); - // fpm.add_gvn_pass(); - // TODO figure out why enabling any of these (even alone) causes LLVM to segfault - // fpm.add_strip_dead_prototypes_pass(); - // fpm.add_dead_arg_elimination_pass(); - // fpm.add_function_inlining_pass(); - // pmb.set_inliner_with_threshold(4); + // function passes + fpm.add_basic_alias_analysis_pass(); + fpm.add_memcpy_optimize_pass(); + fpm.add_jump_threading_pass(); + fpm.add_instruction_combining_pass(); + fpm.add_licm_pass(); + fpm.add_loop_unroll_pass(); + fpm.add_scalar_repl_aggregates_pass_ssa(); + + // module passes + mpm.add_cfg_simplification_pass(); + mpm.add_jump_threading_pass(); + mpm.add_instruction_combining_pass(); + mpm.add_memcpy_optimize_pass(); + mpm.add_promote_memory_to_register_pass(); } } + pmb.populate_module_pass_manager(&mpm); pmb.populate_function_pass_manager(&fpm); + + fpm.initialize(); + + // For now, we have just one of each + (mpm, fpm) } -#[allow(clippy::cognitive_complexity)] -pub fn build_expr<'a, 'ctx, 'env>( +pub fn build_exp_literal<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - scope: &Scope<'a, 'ctx>, - parent: FunctionValue<'ctx>, - expr: &Expr<'a>, + literal: &roc_mono::ir::Literal<'a>, ) -> BasicValueEnum<'ctx> { - use roc_mono::expr::Expr::*; + use roc_mono::ir::Literal::*; - match expr { + match literal { Int(num) => env.context.i64_type().const_int(*num as u64, true).into(), Float(num) => env.context.f64_type().const_float(*num).into(), Bool(b) => env.context.bool_type().const_int(*b as u64, false).into(), Byte(b) => env.context.i8_type().const_int(*b as u64, false).into(), - Cond { - branch_symbol, - pass: (pass_stores, pass_expr), - fail: (fail_stores, fail_expr), - ret_layout, - .. - } => { - let pass = env.arena.alloc(Expr::Store(pass_stores, pass_expr)); - let fail = env.arena.alloc(Expr::Store(fail_stores, fail_expr)); - - let ret_type = - basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); - - let cond_expr = load_symbol(env, scope, branch_symbol); - - match cond_expr { - IntValue(value) => { - // This is a call tobuild_basic_phi2, except inlined to prevent - // problems with lifetimes and closures involving layout_ids. - let builder = env.builder; - let context = env.context; - - // build blocks - let then_block = context.append_basic_block(parent, "then"); - let else_block = context.append_basic_block(parent, "else"); - let cont_block = context.append_basic_block(parent, "branchcont"); - - builder.build_conditional_branch(value, then_block, else_block); - - // build then block - builder.position_at_end(then_block); - let then_val = build_expr(env, layout_ids, scope, parent, pass); - builder.build_unconditional_branch(cont_block); - - let then_block = builder.get_insert_block().unwrap(); - - // build else block - builder.position_at_end(else_block); - let else_val = build_expr(env, layout_ids, scope, parent, fail); - builder.build_unconditional_branch(cont_block); - - let else_block = builder.get_insert_block().unwrap(); - - // emit merge block - builder.position_at_end(cont_block); - - let phi = builder.build_phi(ret_type, "branch"); - - phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); - - phi.as_basic_value() - } - _ => panic!( - "Tried to make a branch out of an invalid condition: cond_expr = {:?}", - cond_expr, - ), - } - } - Switch { - cond, - branches, - default_branch: (default_stores, default_expr), - ret_layout, - cond_layout, - } => { - let ret_type = - basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); - - let default_branch = env.arena.alloc(Expr::Store(default_stores, default_expr)); - - let mut combined = Vec::with_capacity_in(branches.len(), env.arena); - - for (int, stores, expr) in branches.iter() { - combined.push((*int, Expr::Store(stores, expr))); - } - - let switch_args = SwitchArgs { - cond_layout: cond_layout.clone(), - cond_expr: cond, - branches: combined.into_bump_slice(), - default_branch, - ret_type, - }; - - build_switch(env, layout_ids, scope, parent, switch_args) - } - Store(stores, ret) => { - let mut scope = im_rc::HashMap::clone(scope); - let context = &env.context; - - for (symbol, layout, expr) in stores.iter() { - let val = build_expr(env, layout_ids, &scope, parent, &expr); - let expr_bt = basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes); - let alloca = create_entry_block_alloca( - env, - parent, - expr_bt, - symbol.ident_string(&env.interns), - ); - - env.builder.build_store(alloca, val); - - // Make a new scope which includes the binding we just encountered. - // This should be done *after* compiling the bound expr, since any - // recursive (in the LetRec sense) bindings should already have - // been extracted as procedures. Nothing in here should need to - // access itself! - scope = im_rc::HashMap::clone(&scope); - - scope.insert(*symbol, (layout.clone(), alloca)); - } - - build_expr(env, layout_ids, &scope, parent, ret) - } - CallByName { name, layout, args } => { - let mut arg_tuples: Vec<(BasicValueEnum, &'a Layout<'a>)> = - Vec::with_capacity_in(args.len(), env.arena); - - for (arg, arg_layout) in args.iter() { - arg_tuples.push((build_expr(env, layout_ids, scope, parent, arg), arg_layout)); - } - - call_with_args( - env, - layout_ids, - layout, - *name, - parent, - arg_tuples.into_bump_slice(), - ) - } - FunctionPointer(symbol, layout) => { - let fn_name = layout_ids - .get(*symbol, layout) - .to_symbol_string(*symbol, &env.interns); - let ptr = env - .module - .get_function(fn_name.as_str()) - .unwrap_or_else(|| panic!("Could not get pointer to unknown function {:?}", symbol)) - .as_global_value() - .as_pointer_value(); - - BasicValueEnum::PointerValue(ptr) - } - CallByPointer(sub_expr, args, _var) => { - let mut arg_vals: Vec = Vec::with_capacity_in(args.len(), env.arena); - - for arg in args.iter() { - arg_vals.push(build_expr(env, layout_ids, scope, parent, arg)); - } - - let call = match build_expr(env, layout_ids, scope, parent, sub_expr) { - BasicValueEnum::PointerValue(ptr) => { - env.builder.build_call(ptr, arg_vals.as_slice(), "tmp") - } - non_ptr => { - panic!( - "Tried to call by pointer, but encountered a non-pointer: {:?}", - non_ptr - ); - } - }; - - // TODO FIXME this should not be hardcoded! - // Need to look up what calling convention is the right one for that function. - // If this is an external-facing function, it'll use the C calling convention. - // If it's an internal-only function, it should (someday) use the fast calling conention. - call.set_call_convention(C_CALL_CONV); - - call.try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) - } - Load(symbol) => load_symbol(env, scope, symbol), Str(str_literal) => { if str_literal.is_empty() { panic!("TODO build an empty string in LLVM"); @@ -386,75 +257,80 @@ pub fn build_expr<'a, 'ctx, 'env>( BasicValueEnum::PointerValue(ptr) } } - EmptyArray => { - let struct_type = collection(env.context, env.ptr_bytes); + } +} - // The pointer should be null (aka zero) and the length should be zero, - // so the whole struct should be a const_zero - BasicValueEnum::StructValue(struct_type.const_zero()) - } - Array { elem_layout, elems } => { - let ctx = env.context; - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let builder = env.builder; +pub fn build_exp_expr<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + expr: &roc_mono::ir::Expr<'a>, +) -> BasicValueEnum<'ctx> { + use roc_mono::ir::CallType::*; + use roc_mono::ir::Expr::*; - if elems.is_empty() { - empty_list(env) - } else { - let len_u64 = elems.len() as u64; - let elem_bytes = elem_layout.stack_size(env.ptr_bytes) as u64; + match expr { + Literal(literal) => build_exp_literal(env, literal), + RunLowLevel(op, symbols) => run_low_level(env, scope, parent, *op, symbols), - let ptr = { - let bytes_len = elem_bytes * len_u64; - let len_type = env.ptr_int(); - let len = len_type.const_int(bytes_len, false); + FunctionCall { + call_type: ByName(name), + layout, + args, + .. + } => { + let mut arg_tuples: Vec = Vec::with_capacity_in(args.len(), env.arena); - env.builder - .build_array_malloc(elem_type, len, "create_list_ptr") - .unwrap() - - // TODO check if malloc returned null; if so, runtime error for OOM! - }; - - // Copy the elements from the list literal into the array - for (index, elem) in elems.iter().enumerate() { - let index_val = ctx.i64_type().const_int(index as u64, false); - let elem_ptr = - unsafe { builder.build_in_bounds_gep(ptr, &[index_val], "index") }; - let val = build_expr(env, layout_ids, &scope, parent, &elem); - - builder.build_store(elem_ptr, val); - } - - let ptr_bytes = env.ptr_bytes; - let int_type = ptr_int(ctx, ptr_bytes); - let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); - let struct_type = collection(ctx, ptr_bytes); - let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false)); - let mut struct_val; - - // Store the pointer - struct_val = builder - .build_insert_value( - struct_type.get_undef(), - ptr_as_int, - Builtin::WRAPPER_PTR, - "insert_ptr", - ) - .unwrap(); - - // Store the length - struct_val = builder - .build_insert_value(struct_val, len, Builtin::WRAPPER_LEN, "insert_len") - .unwrap(); - - // Bitcast to an array of raw bytes - builder.build_bitcast( - struct_val.into_struct_value(), - collection(ctx, ptr_bytes), - "cast_collection", - ) + for symbol in args.iter() { + arg_tuples.push(load_symbol(env, scope, symbol)); } + + call_with_args( + env, + layout_ids, + layout, + *name, + parent, + arg_tuples.into_bump_slice(), + ) + } + + FunctionCall { + call_type: ByPointer(name), + layout: _, + args, + .. + } => { + let sub_expr = load_symbol(env, scope, name); + + let mut arg_vals: Vec = Vec::with_capacity_in(args.len(), env.arena); + + for arg in args.iter() { + arg_vals.push(load_symbol(env, scope, arg)); + } + + let call = match sub_expr { + BasicValueEnum::PointerValue(ptr) => { + env.builder.build_call(ptr, arg_vals.as_slice(), "tmp") + } + non_ptr => { + panic!( + "Tried to call by pointer, but encountered a non-pointer: {:?}", + non_ptr + ); + } + }; + + // TODO FIXME this should not be hardcoded! + // Need to look up what calling convention is the right one for that function. + // If this is an external-facing function, it'll use the C calling convention. + // If it's an internal-only function, it should (someday) use the fast calling conention. + call.set_call_convention(C_CALL_CONV); + + call.try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) } Struct(sorted_fields) => { @@ -467,9 +343,10 @@ pub fn build_expr<'a, 'ctx, 'env>( let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - for (field_expr, field_layout) in sorted_fields.iter() { + for symbol in sorted_fields.iter() { // Zero-sized fields have no runtime representation. // The layout of the struct expects them to be dropped! + let (field_expr, field_layout) = load_symbol_and_layout(env, scope, symbol); if field_layout.stack_size(ptr_bytes) != 0 { field_types.push(basic_type_from_layout( env.arena, @@ -478,7 +355,7 @@ pub fn build_expr<'a, 'ctx, 'env>( env.ptr_bytes, )); - field_vals.push(build_expr(env, layout_ids, &scope, parent, field_expr)); + field_vals.push(field_expr); } } @@ -501,6 +378,7 @@ pub fn build_expr<'a, 'ctx, 'env>( BasicValueEnum::StructValue(struct_val.into_struct_value()) } } + Tag { union_size, arguments, @@ -517,11 +395,11 @@ pub fn build_expr<'a, 'ctx, 'env>( let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - for (field_expr, field_layout) in it { + for field_symbol in it { + let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol); // Zero-sized fields have no runtime representation. // The layout of the struct expects them to be dropped! if field_layout.stack_size(ptr_bytes) != 0 { - let val = build_expr(env, layout_ids, &scope, parent, field_expr); let field_type = basic_type_from_layout( env.arena, env.context, @@ -553,14 +431,16 @@ pub fn build_expr<'a, 'ctx, 'env>( BasicValueEnum::StructValue(struct_val.into_struct_value()) } } + Tag { arguments, tag_layout, + union_size, .. } => { + debug_assert!(*union_size > 1); let ptr_size = env.ptr_bytes; - let whole_size = tag_layout.stack_size(ptr_size); let mut filler = tag_layout.stack_size(ptr_size); let ctx = env.context; @@ -571,15 +451,15 @@ pub fn build_expr<'a, 'ctx, 'env>( let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - for (field_expr, field_layout) in arguments.iter() { + for field_symbol in arguments.iter() { + let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol); let field_size = field_layout.stack_size(ptr_size); // Zero-sized fields have no runtime representation. // The layout of the struct expects them to be dropped! if field_size != 0 { - let val = build_expr(env, layout_ids, &scope, parent, field_expr); let field_type = - basic_type_from_layout(env.arena, env.context, &field_layout, ptr_size); + basic_type_from_layout(env.arena, env.context, field_layout, ptr_size); field_types.push(field_type); field_vals.push(val); @@ -628,27 +508,18 @@ pub fn build_expr<'a, 'ctx, 'env>( // This tricks comes from // https://github.com/raviqqe/ssf/blob/bc32aae68940d5bddf5984128e85af75ca4f4686/ssf-llvm/src/expression_compiler.rs#L116 - let array_type = ctx.i8_type().array_type(whole_size); + let internal_type = + basic_type_from_layout(env.arena, env.context, tag_layout, env.ptr_bytes); - let result = cast_basic_basic( + cast_basic_basic( builder, struct_val.into_struct_value().into(), - array_type.into(), - ); - - // For unclear reasons, we can't cast an array to a struct on the other side. - // the solution is to wrap the array in a struct (yea...) - let wrapper_type = ctx.struct_type(&[array_type.into()], false); - let mut wrapper_val = wrapper_type.const_zero().into(); - wrapper_val = builder - .build_insert_value(wrapper_val, result, 0, "insert_field") - .unwrap(); - - wrapper_val.into_struct_value().into() + internal_type, + ) } AccessAtIndex { index, - expr, + structure, is_unwrapped, .. } if *is_unwrapped => { @@ -661,7 +532,7 @@ pub fn build_expr<'a, 'ctx, 'env>( // right away. However, that struct might have only one field which // is not zero-sized, which would make it unwrapped. If that happens, // we must be - match build_expr(env, layout_ids, &scope, parent, expr) { + match load_symbol(env, scope, structure) { StructValue(argument) => builder .build_extract_value( argument, @@ -679,7 +550,7 @@ pub fn build_expr<'a, 'ctx, 'env>( AccessAtIndex { index, - expr, + structure, field_layouts, .. } => { @@ -702,7 +573,7 @@ pub fn build_expr<'a, 'ctx, 'env>( .struct_type(field_types.into_bump_slice(), false); // cast the argument bytes into the desired shape for this tag - let argument = build_expr(env, layout_ids, &scope, parent, expr).into_struct_value(); + let argument = load_symbol(env, scope, structure).into_struct_value(); let struct_value = cast_struct_struct(builder, argument, struct_type); @@ -710,16 +581,486 @@ pub fn build_expr<'a, 'ctx, 'env>( .build_extract_value(struct_value, *index as u32, "") .expect("desired field did not decode") } - RuntimeErrorFunction(_) => { - todo!("LLVM build runtime error function of {:?}", expr); + EmptyArray => empty_polymorphic_list(env), + Array { elem_layout, elems } => list_literal(env, scope, elem_layout, elems), + FunctionPointer(symbol, layout) => { + let fn_name = layout_ids + .get(*symbol, layout) + .to_symbol_string(*symbol, &env.interns); + let ptr = env + .module + .get_function(fn_name.as_str()) + .unwrap_or_else(|| panic!("Could not get pointer to unknown function {:?}", symbol)) + .as_global_value() + .as_pointer_value(); + + BasicValueEnum::PointerValue(ptr) } - RuntimeError(_) => { - todo!("LLVM build runtime error of {:?}", expr); - } - RunLowLevel(op, args) => run_low_level(env, layout_ids, scope, parent, *op, args), + RuntimeErrorFunction(_) => todo!(), } } +pub fn build_exp_stmt<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &mut Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + stmt: &roc_mono::ir::Stmt<'a>, +) -> BasicValueEnum<'ctx> { + use roc_mono::ir::Stmt::*; + + match stmt { + Let(symbol, expr, layout, cont) => { + let context = &env.context; + + let val = build_exp_expr(env, layout_ids, &scope, parent, &expr); + let expr_bt = basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes); + let alloca = + create_entry_block_alloca(env, parent, expr_bt, symbol.ident_string(&env.interns)); + + env.builder.build_store(alloca, val); + + // Make a new scope which includes the binding we just encountered. + // This should be done *after* compiling the bound expr, since any + // recursive (in the LetRec sense) bindings should already have + // been extracted as procedures. Nothing in here should need to + // access itself! + // scope = scope.clone(); + + scope.insert(*symbol, (layout.clone(), alloca)); + let result = build_exp_stmt(env, layout_ids, scope, parent, cont); + scope.remove(symbol); + + result + } + Ret(symbol) => load_symbol(env, scope, symbol), + + Cond { + branching_symbol, + pass: pass_stmt, + fail: fail_stmt, + ret_layout, + .. + } => { + let ret_type = + basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); + + let cond_expr = load_symbol(env, scope, branching_symbol); + + match cond_expr { + IntValue(value) => { + // This is a call tobuild_basic_phi2, except inlined to prevent + // problems with lifetimes and closures involving layout_ids. + let builder = env.builder; + let context = env.context; + + // build blocks + let then_block = context.append_basic_block(parent, "then"); + let else_block = context.append_basic_block(parent, "else"); + let mut blocks: std::vec::Vec<( + &dyn inkwell::values::BasicValue<'_>, + inkwell::basic_block::BasicBlock<'_>, + )> = std::vec::Vec::with_capacity(2); + let cont_block = context.append_basic_block(parent, "branchcont"); + + builder.build_conditional_branch(value, then_block, else_block); + + // build then block + builder.position_at_end(then_block); + let then_val = build_exp_stmt(env, layout_ids, scope, parent, pass_stmt); + if then_block.get_terminator().is_none() { + builder.build_unconditional_branch(cont_block); + let then_block = builder.get_insert_block().unwrap(); + blocks.push((&then_val, then_block)); + } + + // build else block + builder.position_at_end(else_block); + let else_val = build_exp_stmt(env, layout_ids, scope, parent, fail_stmt); + if else_block.get_terminator().is_none() { + let else_block = builder.get_insert_block().unwrap(); + builder.build_unconditional_branch(cont_block); + blocks.push((&else_val, else_block)); + } + + // emit merge block + if blocks.is_empty() { + // SAFETY there are no other references to this block in this case + unsafe { + cont_block.delete().unwrap(); + } + + // return garbage value + context.i64_type().const_int(0, false).into() + } else { + builder.position_at_end(cont_block); + + let phi = builder.build_phi(ret_type, "branch"); + + // phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); + phi.add_incoming(&blocks); + + phi.as_basic_value() + } + } + _ => panic!( + "Tried to make a branch out of an invalid condition: cond_expr = {:?}", + cond_expr, + ), + } + } + + Switch { + branches, + default_branch, + ret_layout, + cond_layout, + cond_symbol, + } => { + let ret_type = + basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); + + let switch_args = SwitchArgsIr { + cond_layout: cond_layout.clone(), + cond_symbol: *cond_symbol, + branches, + default_branch, + ret_type, + }; + + build_switch_ir(env, layout_ids, scope, parent, switch_args) + } + Join { + id, + parameters, + remainder, + continuation, + } => { + let builder = env.builder; + let context = env.context; + + let mut joinpoint_args = Vec::with_capacity_in(parameters.len(), env.arena); + + for param in parameters.iter() { + let btype = + basic_type_from_layout(env.arena, env.context, ¶m.layout, env.ptr_bytes); + joinpoint_args.push(create_entry_block_alloca( + env, + parent, + btype, + "joinpointarg", + )); + } + + // create new block + let cont_block = context.append_basic_block(parent, "joinpointcont"); + + // store this join point + let joinpoint_args = joinpoint_args.into_bump_slice(); + scope.join_points.insert(*id, (cont_block, joinpoint_args)); + + // construct the blocks that may jump to this join point + build_exp_stmt(env, layout_ids, scope, parent, remainder); + + // remove this join point again + scope.join_points.remove(&id); + + for (ptr, param) in joinpoint_args.iter().zip(parameters.iter()) { + scope.insert(param.symbol, (param.layout.clone(), *ptr)); + } + + let phi_block = builder.get_insert_block().unwrap(); + + // put the cont block at the back + builder.position_at_end(cont_block); + + // put the continuation in + let result = build_exp_stmt(env, layout_ids, scope, parent, continuation); + + cont_block.move_after(phi_block).unwrap(); + + result + } + Jump(join_point, arguments) => { + let builder = env.builder; + let context = env.context; + let (cont_block, argument_pointers) = scope.join_points.get(join_point).unwrap(); + + for (pointer, argument) in argument_pointers.iter().zip(arguments.iter()) { + let value = load_symbol(env, scope, argument); + builder.build_store(*pointer, value); + } + + builder.build_unconditional_branch(*cont_block); + + // This doesn't currently do anything + context.i64_type().const_zero().into() + } + Inc(symbol, cont) => { + let (value, layout) = load_symbol_and_layout(env, scope, symbol); + let layout = layout.clone(); + + match layout { + Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => { + increment_refcount_list(env, value.into_struct_value()); + build_exp_stmt(env, layout_ids, scope, parent, cont) + } + _ => build_exp_stmt(env, layout_ids, scope, parent, cont), + } + } + Dec(symbol, cont) => { + let (value, layout) = load_symbol_and_layout(env, scope, symbol); + let layout = layout.clone(); + + if layout.contains_refcounted() { + decrement_refcount_layout(env, parent, value, &layout); + } + + build_exp_stmt(env, layout_ids, scope, parent, cont) + } + _ => todo!("unsupported expr {:?}", stmt), + } +} + +fn refcount_is_one_comparison<'ctx>( + builder: &Builder<'ctx>, + context: &'ctx Context, + refcount: IntValue<'ctx>, +) -> IntValue<'ctx> { + let refcount_one: IntValue<'ctx> = context.i64_type().const_int((std::usize::MAX) as _, false); + // Note: Check for refcount < refcount_1 as the "true" condition, + // to avoid misprediction. (In practice this should usually pass, + // and CPUs generally default to predicting that a forward jump + // shouldn't be taken; that is, they predict "else" won't be taken.) + builder.build_int_compare( + IntPredicate::EQ, + refcount, + refcount_one, + "refcount_one_check", + ) +} + +#[allow(dead_code)] +fn list_get_refcount_ptr<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + list_wrapper: StructValue<'ctx>, +) -> PointerValue<'ctx> { + let builder = env.builder; + let ctx = env.context; + + // pointer to usize + let ptr_bytes = env.ptr_bytes; + let int_type = ptr_int(ctx, ptr_bytes); + + // fetch the pointer to the array data, as an integer + let ptr_as_int = builder + .build_extract_value(list_wrapper, Builtin::WRAPPER_PTR, "read_list_ptr") + .unwrap() + .into_int_value(); + + // subtract ptr_size, to access the refcount + let refcount_ptr = builder.build_int_sub( + ptr_as_int, + ctx.i64_type().const_int(env.ptr_bytes as u64, false), + "make_refcount_ptr", + ); + + builder.build_int_to_ptr( + refcount_ptr, + int_type.ptr_type(AddressSpace::Generic), + "get_refcount_ptr", + ) +} + +fn decrement_refcount_layout<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + value: BasicValueEnum<'ctx>, + layout: &Layout<'a>, +) { + use Layout::*; + + match layout { + Builtin(builtin) => decrement_refcount_builtin(env, parent, value, builtin), + Struct(layouts) => { + let wrapper_struct = value.into_struct_value(); + + for (i, field_layout) in layouts.iter().enumerate() { + if field_layout.contains_refcounted() { + let field_ptr = env + .builder + .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") + .unwrap(); + + decrement_refcount_layout(env, parent, field_ptr, field_layout) + } + } + } + Union(tags) => { + debug_assert!(!tags.is_empty()); + let wrapper_struct = value.into_struct_value(); + + // read the tag_id + let tag_id = env + .builder + .build_extract_value(wrapper_struct, 0, "read_tag_id") + .unwrap() + .into_int_value(); + + // next, make a jump table for all possible values of the tag_id + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + let merge_block = env.context.append_basic_block(parent, "decrement_merge"); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + let block = env.context.append_basic_block(parent, "tag_id_decrement"); + env.builder.position_at_end(block); + + for (i, field_layout) in field_layouts.iter().enumerate() { + if field_layout.contains_refcounted() { + let field_ptr = env + .builder + .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") + .unwrap(); + + decrement_refcount_layout(env, parent, field_ptr, field_layout) + } + } + + env.builder.build_unconditional_branch(merge_block); + + cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); + } + + let (_, default_block) = cases.pop().unwrap(); + + env.builder.build_switch(tag_id, default_block, &cases); + + env.builder.position_at_end(merge_block); + } + + FunctionPointer(_, _) | Pointer(_) => {} + } +} + +#[inline(always)] +fn decrement_refcount_builtin<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + value: BasicValueEnum<'ctx>, + builtin: &Builtin<'a>, +) { + use Builtin::*; + + match builtin { + List(MemoryMode::Refcounted, element_layout) => { + if element_layout.contains_refcounted() { + // TODO decrement all values + } + let wrapper_struct = value.into_struct_value(); + decrement_refcount_list(env, parent, wrapper_struct); + } + List(MemoryMode::Unique, _element_layout) => { + // do nothing + } + Set(element_layout) => { + if element_layout.contains_refcounted() { + // TODO decrement all values + } + let wrapper_struct = value.into_struct_value(); + decrement_refcount_list(env, parent, wrapper_struct); + } + Map(key_layout, value_layout) => { + if key_layout.contains_refcounted() || value_layout.contains_refcounted() { + // TODO decrement all values + } + + let wrapper_struct = value.into_struct_value(); + decrement_refcount_list(env, parent, wrapper_struct); + } + _ => {} + } +} + +fn increment_refcount_list<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + original_wrapper: StructValue<'ctx>, +) { + let builder = env.builder; + let ctx = env.context; + + let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); + + let refcount = env + .builder + .build_load(refcount_ptr, "get_refcount") + .into_int_value(); + + // our refcount 0 is actually usize::MAX, so incrementing the refcount means decrementing this value. + let decremented = env.builder.build_int_sub( + refcount, + ctx.i64_type().const_int(1 as u64, false), + "incremented_refcount", + ); + + // Mutate the new array in-place to change the element. + builder.build_store(refcount_ptr, decremented); +} + +fn decrement_refcount_list<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + original_wrapper: StructValue<'ctx>, +) { + let builder = env.builder; + let ctx = env.context; + + let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); + + let refcount = env + .builder + .build_load(refcount_ptr, "get_refcount") + .into_int_value(); + + let comparison = refcount_is_one_comparison(builder, env.context, refcount); + + // build blocks + let then_block = ctx.append_basic_block(parent, "then"); + let else_block = ctx.append_basic_block(parent, "else"); + let cont_block = ctx.append_basic_block(parent, "branchcont"); + + builder.build_conditional_branch(comparison, then_block, else_block); + + // build then block + { + builder.position_at_end(then_block); + // our refcount 0 is actually usize::MAX, so decrementing the refcount means incrementing this value. + let decremented = env.builder.build_int_add( + ctx.i64_type().const_int(1 as u64, false), + refcount, + "decremented_refcount", + ); + + // Mutate the new array in-place to change the element. + builder.build_store(refcount_ptr, decremented); + + builder.build_unconditional_branch(cont_block); + } + + // build else block + { + builder.position_at_end(else_block); + if !env.leak { + let free = builder.build_free(refcount_ptr); + builder.insert_instruction(&free, None); + } + builder.build_unconditional_branch(cont_block); + } + + // emit merge block + builder.position_at_end(cont_block); +} + fn load_symbol<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, scope: &Scope<'a, 'ctx>, @@ -733,6 +1074,21 @@ fn load_symbol<'a, 'ctx, 'env>( } } +fn load_symbol_and_layout<'a, 'ctx, 'env, 'b>( + env: &Env<'a, 'ctx, 'env>, + scope: &'b Scope<'a, 'ctx>, + symbol: &Symbol, +) -> (BasicValueEnum<'ctx>, &'b Layout<'a>) { + match scope.get(symbol) { + Some((layout, ptr)) => ( + env.builder + .build_load(*ptr, symbol.ident_string(&env.interns)), + layout, + ), + None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), + } +} + /// Cast a struct to another struct of the same (or smaller?) size fn cast_struct_struct<'ctx>( builder: &Builder<'ctx>, @@ -758,7 +1114,7 @@ fn cast_basic_basic<'ctx>( .build_bitcast( argument_pointer, to_type.ptr_type(inkwell::AddressSpace::Generic), - "", + "cast_basic_basic", ) .into_pointer_value(); @@ -781,33 +1137,38 @@ fn extract_tag_discriminant<'a, 'ctx, 'env>( .into_int_value() } -struct SwitchArgs<'a, 'ctx> { - pub cond_expr: &'a Expr<'a>, +struct SwitchArgsIr<'a, 'ctx> { + pub cond_symbol: Symbol, pub cond_layout: Layout<'a>, - pub branches: &'a [(u64, Expr<'a>)], - pub default_branch: &'a Expr<'a>, + pub branches: &'a [(u64, roc_mono::ir::Stmt<'a>)], + pub default_branch: &'a roc_mono::ir::Stmt<'a>, pub ret_type: BasicTypeEnum<'ctx>, } -fn build_switch<'a, 'ctx, 'env>( +fn build_switch_ir<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, - switch_args: SwitchArgs<'a, 'ctx>, + switch_args: SwitchArgsIr<'a, 'ctx>, ) -> BasicValueEnum<'ctx> { let arena = env.arena; let builder = env.builder; let context = env.context; - let SwitchArgs { + let SwitchArgsIr { branches, - cond_expr, + cond_symbol, mut cond_layout, default_branch, ret_type, .. } = switch_args; + let mut copy = scope.clone(); + let scope = &mut copy; + + let cond_symbol = &cond_symbol; + let cont_block = context.append_basic_block(parent, "cont"); // Build the condition @@ -815,7 +1176,7 @@ fn build_switch<'a, 'ctx, 'env>( Layout::Builtin(Builtin::Float64) => { // float matches are done on the bit pattern cond_layout = Layout::Builtin(Builtin::Int64); - let full_cond = build_expr(env, layout_ids, scope, parent, cond_expr); + let full_cond = load_symbol(env, scope, cond_symbol); builder .build_bitcast(full_cond, env.context.i64_type(), "") @@ -824,14 +1185,11 @@ fn build_switch<'a, 'ctx, 'env>( Layout::Union(_) => { // we match on the discriminant, not the whole Tag cond_layout = Layout::Builtin(Builtin::Int64); - let full_cond = - build_expr(env, layout_ids, scope, parent, cond_expr).into_struct_value(); + let full_cond = load_symbol(env, scope, cond_symbol).into_struct_value(); extract_tag_discriminant(env, full_cond) } - Layout::Builtin(_) => { - build_expr(env, layout_ids, scope, parent, cond_expr).into_int_value() - } + Layout::Builtin(_) => load_symbol(env, scope, cond_symbol).into_int_value(), other => todo!("Build switch value from layout: {:?}", other), }; @@ -871,32 +1229,42 @@ fn build_switch<'a, 'ctx, 'env>( for ((_, branch_expr), (_, block)) in branches.iter().zip(cases) { builder.position_at_end(block); - let branch_val = build_expr(env, layout_ids, scope, parent, branch_expr); + let branch_val = build_exp_stmt(env, layout_ids, scope, parent, branch_expr); - builder.build_unconditional_branch(cont_block); - - incoming.push((branch_val, block)); + if block.get_terminator().is_none() { + builder.build_unconditional_branch(cont_block); + incoming.push((branch_val, block)); + } } // The block for the conditional's default branch. builder.position_at_end(default_block); - let default_val = build_expr(env, layout_ids, scope, parent, default_branch); + let default_val = build_exp_stmt(env, layout_ids, scope, parent, default_branch); - builder.build_unconditional_branch(cont_block); - - incoming.push((default_val, default_block)); - - // emit merge block - builder.position_at_end(cont_block); - - let phi = builder.build_phi(ret_type, "branch"); - - for (branch_val, block) in incoming { - phi.add_incoming(&[(&Into::::into(branch_val), block)]); + if default_block.get_terminator().is_none() { + builder.build_unconditional_branch(cont_block); + incoming.push((default_val, default_block)); } - phi.as_basic_value() + // emit merge block + if incoming.is_empty() { + unsafe { + cont_block.delete().unwrap(); + } + // produce unused garbage value + context.i64_type().const_zero().into() + } else { + builder.position_at_end(cont_block); + + let phi = builder.build_phi(ret_type, "branch"); + + for (branch_val, block) in incoming { + phi.add_incoming(&[(&Into::::into(branch_val), block)]); + } + + phi.as_basic_value() + } } fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>( @@ -980,7 +1348,7 @@ pub fn build_proc_header<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, symbol: Symbol, layout: &Layout<'a>, - proc: &Proc<'a>, + proc: &roc_mono::ir::Proc<'a>, ) -> (FunctionValue<'ctx>, Vec<'a, BasicTypeEnum<'ctx>>) { let args = proc.args; let arena = env.arena; @@ -1013,7 +1381,7 @@ pub fn build_proc_header<'a, 'ctx, 'env>( pub fn build_proc<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - proc: Proc<'a>, + proc: roc_mono::ir::Proc<'a>, fn_val: FunctionValue<'ctx>, arg_basic_types: Vec<'a, BasicTypeEnum<'ctx>>, ) { @@ -1026,7 +1394,7 @@ pub fn build_proc<'a, 'ctx, 'env>( builder.position_at_end(entry); - let mut scope = ImMap::default(); + let mut scope = Scope::default(); // Add args to scope for ((arg_val, arg_type), (layout, arg_symbol)) in @@ -1042,7 +1410,7 @@ pub fn build_proc<'a, 'ctx, 'env>( scope.insert(*arg_symbol, (layout.clone(), alloca)); } - let body = build_expr(env, layout_ids, &scope, fn_val, &proc.body); + let body = build_exp_stmt(env, layout_ids, &mut scope, fn_val, &proc.body); builder.build_return(Some(&body)); } @@ -1057,6 +1425,67 @@ pub fn verify_fn(fn_val: FunctionValue<'_>) { } } +pub fn allocate_list<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + elem_layout: &Layout<'a>, + length: IntValue<'ctx>, +) -> PointerValue<'ctx> { + let builder = env.builder; + let ctx = env.context; + + let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); + let elem_bytes = elem_layout.stack_size(env.ptr_bytes) as u64; + + let len_type = env.ptr_int(); + // bytes per element + let bytes_len = len_type.const_int(elem_bytes, false); + let offset = (env.ptr_bytes as u64).max(elem_bytes); + + let ptr = { + let len = builder.build_int_mul(bytes_len, length, "data_length"); + let len = + builder.build_int_add(len, len_type.const_int(offset, false), "add_refcount_space"); + + env.builder + .build_array_malloc(ctx.i8_type(), len, "create_list_ptr") + .unwrap() + + // TODO check if malloc returned null; if so, runtime error for OOM! + }; + + // We must return a pointer to the first element: + let ptr_bytes = env.ptr_bytes; + let int_type = ptr_int(ctx, ptr_bytes); + let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); + let incremented = builder.build_int_add( + ptr_as_int, + ctx.i64_type().const_int(offset, false), + "increment_list_ptr", + ); + + let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); + let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "list_cast_ptr"); + + // subtract ptr_size, to access the refcount + let refcount_ptr = builder.build_int_sub( + incremented, + ctx.i64_type().const_int(env.ptr_bytes as u64, false), + "refcount_ptr", + ); + + let refcount_ptr = builder.build_int_to_ptr( + refcount_ptr, + int_type.ptr_type(AddressSpace::Generic), + "make ptr", + ); + + // put our "refcount 0" in the first slot + let ref_count_zero = ctx.i64_type().const_int(std::usize::MAX as u64, false); + builder.build_store(refcount_ptr, ref_count_zero); + + list_element_ptr +} + /// List.single : a -> List a fn list_single<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, @@ -1066,20 +1495,9 @@ fn list_single<'a, 'ctx, 'env>( let builder = env.builder; let ctx = env.context; - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let elem_bytes = elem_layout.stack_size(env.ptr_bytes) as u64; - - let ptr = { - let bytes_len = elem_bytes; - let len_type = env.ptr_int(); - let len = len_type.const_int(bytes_len, false); - - env.builder - .build_array_malloc(elem_type, len, "create_list_ptr") - .unwrap() - - // TODO check if malloc returned null; if so, runtime error for OOM! - }; + // allocate a list of size 1 on the heap + let size = ctx.i64_type().const_int(1, false); + let ptr = allocate_list(env, elem_layout, size); // Put the element into the list let elem_ptr = unsafe { @@ -1136,7 +1554,6 @@ fn list_repeat<'a, 'ctx, 'env>( ) -> BasicValueEnum<'ctx> { let builder = env.builder; let ctx = env.context; - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); // list_len > 0 // We have to do a loop below, continuously adding the `elem` @@ -1154,10 +1571,7 @@ fn list_repeat<'a, 'ctx, 'env>( let build_then = || { // Allocate space for the new array that we'll copy into. - let list_ptr = builder - .build_array_malloc(elem_type, list_len, "create_list_ptr") - .unwrap(); - + let list_ptr = allocate_list(env, elem_layout, list_len); // TODO check if malloc returned null; if so, runtime error for OOM! let index_name = "#index"; @@ -1231,7 +1645,7 @@ fn list_repeat<'a, 'ctx, 'env>( ) }; - let build_else = || empty_list(env); + let build_else = || empty_polymorphic_list(env); let struct_type = collection(ctx, env.ptr_bytes); @@ -1245,19 +1659,20 @@ fn list_repeat<'a, 'ctx, 'env>( ) } +// #[allow(clippy::cognitive_complexity)] #[inline(always)] -#[allow(clippy::cognitive_complexity)] fn call_with_args<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, symbol: Symbol, _parent: FunctionValue<'ctx>, - args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], + args: &[BasicValueEnum<'ctx>], ) -> BasicValueEnum<'ctx> { let fn_name = layout_ids .get(symbol, layout) .to_symbol_string(symbol, &env.interns); + let fn_val = env .module .get_function(fn_name.as_str()) @@ -1268,15 +1683,8 @@ fn call_with_args<'a, 'ctx, 'env>( panic!("Unrecognized non-builtin function: {:?}", symbol) } }); - let mut arg_vals: Vec = Vec::with_capacity_in(args.len(), env.arena); - for (arg, _layout) in args.iter() { - arg_vals.push(*arg); - } - - let call = env - .builder - .build_call(fn_val, arg_vals.into_bump_slice(), "call"); + let call = env.builder.build_call(fn_val, args, "call"); call.set_call_convention(fn_val.get_call_conventions()); @@ -1367,13 +1775,11 @@ fn clone_nonempty_list<'a, 'ctx, 'env>( .const_int(elem_layout.stack_size(env.ptr_bytes) as u64, false); let size = env .builder - .build_int_mul(elem_bytes, list_len, "mul_len_by_elem_bytes"); + .build_int_mul(elem_bytes, list_len, "clone_mul_len_by_elem_bytes"); // Allocate space for the new array that we'll copy into. - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let clone_ptr = builder - .build_array_malloc(elem_type, list_len, "list_ptr") - .unwrap(); + let clone_ptr = allocate_list(env, elem_layout, list_len); + let int_type = ptr_int(ctx, ptr_bytes); let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr"); @@ -1420,11 +1826,84 @@ fn clone_nonempty_list<'a, 'ctx, 'env>( (answer, clone_ptr) } +#[derive(Debug)] enum InPlace { InPlace, Clone, } +fn empty_polymorphic_list<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>) -> BasicValueEnum<'ctx> { + let ctx = env.context; + + let struct_type = collection(ctx, env.ptr_bytes); + + // The pointer should be null (aka zero) and the length should be zero, + // so the whole struct should be a const_zero + BasicValueEnum::StructValue(struct_type.const_zero()) +} + +fn list_literal<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + scope: &Scope<'a, 'ctx>, + elem_layout: &Layout<'a>, + elems: &&[Symbol], +) -> BasicValueEnum<'ctx> { + let ctx = env.context; + let builder = env.builder; + + let len_u64 = elems.len() as u64; + let elem_bytes = elem_layout.stack_size(env.ptr_bytes) as u64; + + let ptr = { + let bytes_len = elem_bytes * len_u64; + let len_type = env.ptr_int(); + let len = len_type.const_int(bytes_len, false); + + allocate_list(env, elem_layout, len) + + // TODO check if malloc returned null; if so, runtime error for OOM! + }; + + // Copy the elements from the list literal into the array + for (index, symbol) in elems.iter().enumerate() { + let val = load_symbol(env, scope, symbol); + let index_val = ctx.i64_type().const_int(index as u64, false); + let elem_ptr = unsafe { builder.build_in_bounds_gep(ptr, &[index_val], "index") }; + + builder.build_store(elem_ptr, val); + } + + let ptr_bytes = env.ptr_bytes; + let int_type = ptr_int(ctx, ptr_bytes); + let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); + let struct_type = collection(ctx, ptr_bytes); + let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false)); + let mut struct_val; + + // Store the pointer + struct_val = builder + .build_insert_value( + struct_type.get_undef(), + ptr_as_int, + Builtin::WRAPPER_PTR, + "insert_ptr", + ) + .unwrap(); + + // Store the length + struct_val = builder + .build_insert_value(struct_val, len, Builtin::WRAPPER_LEN, "insert_len") + .unwrap(); + + // Bitcast to an array of raw bytes + builder.build_bitcast( + struct_val.into_struct_value(), + collection(ctx, ptr_bytes), + "cast_collection", + ) +} + +// TODO investigate: does this cause problems when the layout is known? this value is now not refcounted! fn empty_list<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>) -> BasicValueEnum<'ctx> { let ctx = env.context; @@ -1485,9 +1964,7 @@ fn list_append<'a, 'ctx, 'env>( .build_int_mul(elem_bytes, list_len, "mul_old_len_by_elem_bytes"); // Allocate space for the new array that we'll copy into. - let clone_ptr = builder - .build_array_malloc(elem_type, new_list_len, "list_ptr") - .unwrap(); + let clone_ptr = allocate_list(env, elem_layout, new_list_len); let int_type = ptr_int(ctx, ptr_bytes); let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr"); @@ -1551,9 +2028,10 @@ fn list_join<'a, 'ctx, 'env>( // If the input list is empty, or if it is a list of empty lists // then simply return an empty list Layout::Builtin(Builtin::EmptyList) - | Layout::Builtin(Builtin::List(Layout::Builtin(Builtin::EmptyList))) => empty_list(env), - Layout::Builtin(Builtin::List(Layout::Builtin(Builtin::List(elem_layout)))) => { - let inner_list_layout = Layout::Builtin(Builtin::List(elem_layout)); + | Layout::Builtin(Builtin::List(_, Layout::Builtin(Builtin::EmptyList))) => empty_list(env), + Layout::Builtin(Builtin::List(_, Layout::Builtin(Builtin::List(_, elem_layout)))) => { + let inner_list_layout = + Layout::Builtin(Builtin::List(MemoryMode::Refcounted, elem_layout)); let builder = env.builder; let ctx = env.context; @@ -2036,11 +2514,10 @@ pub static COLD_CALL_CONV: u32 = 9; fn run_low_level<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, op: LowLevel, - args: &[(Expr<'a>, Layout<'a>)], + args: &[Symbol], ) -> BasicValueEnum<'ctx> { use LowLevel::*; @@ -2049,7 +2526,7 @@ fn run_low_level<'a, 'ctx, 'env>( // List.len : List * -> Int debug_assert_eq!(args.len(), 1); - let arg = build_expr(env, layout_ids, scope, parent, &args[0].0); + let arg = load_symbol(env, scope, &args[0]); load_list_len(env.builder, arg.into_struct_value()).into() } @@ -2057,17 +2534,16 @@ fn run_low_level<'a, 'ctx, 'env>( // List.single : a -> List a debug_assert_eq!(args.len(), 1); - let arg = build_expr(env, layout_ids, scope, parent, &args[0].0); + let (arg, arg_layout) = load_symbol_and_layout(env, scope, &args[0]); - list_single(env, arg, &args[0].1) + list_single(env, arg, arg_layout) } ListRepeat => { // List.repeat : Int, elem -> List elem debug_assert_eq!(args.len(), 2); - let list_len = build_expr(env, layout_ids, scope, parent, &args[0].0).into_int_value(); - let elem = build_expr(env, layout_ids, scope, parent, &args[1].0); - let elem_layout = &args[1].1; + let list_len = load_symbol(env, scope, &args[0]).into_int_value(); + let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); list_repeat(env, parent, list_len, elem, elem_layout) } @@ -2075,12 +2551,12 @@ fn run_low_level<'a, 'ctx, 'env>( // List.reverse : List elem -> List elem debug_assert_eq!(args.len(), 1); - let (list, list_layout) = &args[0]; + let list = &args[0]; + let (_, list_layout) = load_symbol_and_layout(env, scope, &args[0]); match list_layout { - Layout::Builtin(Builtin::List(elem_layout)) => { - let wrapper_struct = - build_expr(env, layout_ids, scope, parent, list).into_struct_value(); + Layout::Builtin(Builtin::List(_, elem_layout)) => { + let wrapper_struct = load_symbol(env, scope, list).into_struct_value(); let builder = env.builder; let ctx = env.context; @@ -2104,10 +2580,7 @@ fn run_low_level<'a, 'ctx, 'env>( let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - let reversed_list_ptr = env - .builder - .build_array_malloc(elem_type, list_len, "create_reversed_list_ptr") - .unwrap(); + let reversed_list_ptr = allocate_list(env, elem_layout, list_len); // TODO check if malloc returned null; if so, runtime error for OOM! @@ -2215,7 +2688,7 @@ fn run_low_level<'a, 'ctx, 'env>( ) }; - let build_else = || empty_list(env); + let build_else = || empty_polymorphic_list(env); let struct_type = collection(ctx, env.ptr_bytes); @@ -2228,21 +2701,19 @@ fn run_low_level<'a, 'ctx, 'env>( BasicTypeEnum::StructType(struct_type), ) } - Layout::Builtin(Builtin::EmptyList) => empty_list(env), + Layout::Builtin(Builtin::EmptyList) => empty_polymorphic_list(env), _ => { unreachable!("Invalid List layout for List.reverse {:?}", list_layout); } } } - ListConcat => list_concat(env, layout_ids, scope, parent, args), + ListConcat => list_concat(env, scope, parent, args), ListAppend => { // List.append : List elem, elem -> List elem debug_assert_eq!(args.len(), 2); - let original_wrapper = - build_expr(env, layout_ids, scope, parent, &args[0].0).into_struct_value(); - let elem = build_expr(env, layout_ids, scope, parent, &args[1].0); - let elem_layout = &args[1].1; + let original_wrapper = load_symbol(env, scope, &args[0]).into_struct_value(); + let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); list_append(env, original_wrapper, elem, elem_layout) } @@ -2250,10 +2721,8 @@ fn run_low_level<'a, 'ctx, 'env>( // List.prepend : List elem, elem -> List elem debug_assert_eq!(args.len(), 2); - let original_wrapper = - build_expr(env, layout_ids, scope, parent, &args[0].0).into_struct_value(); - let elem = build_expr(env, layout_ids, scope, parent, &args[1].0); - let elem_layout = &args[1].1; + let original_wrapper = load_symbol(env, scope, &args[0]).into_struct_value(); + let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); list_prepend(env, original_wrapper, elem, elem_layout) } @@ -2261,18 +2730,15 @@ fn run_low_level<'a, 'ctx, 'env>( // List.join : List (List elem) -> List elem debug_assert_eq!(args.len(), 1); - let (list, outer_list_layout) = &args[0]; - - let outer_wrapper_struct = - build_expr(env, layout_ids, scope, parent, list).into_struct_value(); + let (list, outer_list_layout) = load_symbol_and_layout(env, scope, &args[0]); + let outer_wrapper_struct = list.into_struct_value(); list_join(env, parent, outer_wrapper_struct, outer_list_layout) } NumAbs | NumNeg | NumRound | NumSqrtUnchecked | NumSin | NumCos | NumToFloat => { debug_assert_eq!(args.len(), 1); - let arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let arg_layout = &args[0].1; + let (arg, arg_layout) = load_symbol_and_layout(env, scope, &args[0]); match arg_layout { Layout::Builtin(arg_builtin) => { @@ -2302,10 +2768,8 @@ fn run_low_level<'a, 'ctx, 'env>( | NumDivUnchecked => { debug_assert_eq!(args.len(), 2); - let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let lhs_layout = &args[0].1; - let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); - let rhs_layout = &args[1].1; + let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); + let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); match (lhs_layout, rhs_layout) { (Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin)) @@ -2343,20 +2807,16 @@ fn run_low_level<'a, 'ctx, 'env>( Eq => { debug_assert_eq!(args.len(), 2); - let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let lhs_layout = &args[0].1; - let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); - let rhs_layout = &args[1].1; + let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); + let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); build_eq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout) } NotEq => { debug_assert_eq!(args.len(), 2); - let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let lhs_layout = &args[0].1; - let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); - let rhs_layout = &args[1].1; + let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); + let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); build_neq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout) } @@ -2364,8 +2824,8 @@ fn run_low_level<'a, 'ctx, 'env>( // The (&&) operator debug_assert_eq!(args.len(), 2); - let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); + let lhs_arg = load_symbol(env, scope, &args[0]); + let rhs_arg = load_symbol(env, scope, &args[1]); let bool_val = env.builder.build_and( lhs_arg.into_int_value(), rhs_arg.into_int_value(), @@ -2378,8 +2838,8 @@ fn run_low_level<'a, 'ctx, 'env>( // The (||) operator debug_assert_eq!(args.len(), 2); - let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); - let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); + let lhs_arg = load_symbol(env, scope, &args[0]); + let rhs_arg = load_symbol(env, scope, &args[1]); let bool_val = env.builder.build_or( lhs_arg.into_int_value(), rhs_arg.into_int_value(), @@ -2392,7 +2852,7 @@ fn run_low_level<'a, 'ctx, 'env>( // The (!) operator debug_assert_eq!(args.len(), 1); - let arg = build_expr(env, layout_ids, scope, parent, &args[0].0); + let arg = load_symbol(env, scope, &args[0]); let bool_val = env.builder.build_not(arg.into_int_value(), "bool_not"); BasicValueEnum::IntValue(bool_val) @@ -2402,14 +2862,12 @@ fn run_low_level<'a, 'ctx, 'env>( debug_assert_eq!(args.len(), 2); let builder = env.builder; - let (_, list_layout) = &args[0]; - let wrapper_struct = - build_expr(env, layout_ids, scope, parent, &args[0].0).into_struct_value(); - let elem_index = - build_expr(env, layout_ids, scope, parent, &args[1].0).into_int_value(); + let (wrapper_struct, list_layout) = load_symbol_and_layout(env, scope, &args[0]); + let wrapper_struct = wrapper_struct.into_struct_value(); + let elem_index = load_symbol(env, scope, &args[1]).into_int_value(); match list_layout { - Layout::Builtin(Builtin::List(elem_layout)) => { + Layout::Builtin(Builtin::List(_, elem_layout)) => { let ctx = env.context; let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); @@ -2436,18 +2894,9 @@ fn run_low_level<'a, 'ctx, 'env>( ListSet => list_set( parent, &[ - ( - build_expr(env, layout_ids, scope, parent, &args[0].0), - &args[0].1, - ), - ( - build_expr(env, layout_ids, scope, parent, &args[1].0), - &args[1].1, - ), - ( - build_expr(env, layout_ids, scope, parent, &args[2].0), - &args[2].1, - ), + (load_symbol_and_layout(env, scope, &args[0])), + (load_symbol_and_layout(env, scope, &args[1])), + (load_symbol_and_layout(env, scope, &args[2])), ], env, InPlace::Clone, @@ -2455,18 +2904,9 @@ fn run_low_level<'a, 'ctx, 'env>( ListSetInPlace => list_set( parent, &[ - ( - build_expr(env, layout_ids, scope, parent, &args[0].0), - &args[0].1, - ), - ( - build_expr(env, layout_ids, scope, parent, &args[1].0), - &args[1].1, - ), - ( - build_expr(env, layout_ids, scope, parent, &args[2].0), - &args[2].1, - ), + (load_symbol_and_layout(env, scope, &args[0])), + (load_symbol_and_layout(env, scope, &args[1])), + (load_symbol_and_layout(env, scope, &args[2])), ], env, InPlace::InPlace, @@ -2505,10 +2945,9 @@ fn build_int_binop<'a, 'ctx, 'env>( fn list_concat<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, - args: &[(Expr<'a>, Layout<'a>)], + args: &[Symbol], ) -> BasicValueEnum<'ctx> { // List.concat : List elem, List elem -> List elem debug_assert_eq!(args.len(), 2); @@ -2546,20 +2985,19 @@ fn list_concat<'a, 'ctx, 'env>( let builder = env.builder; let ctx = env.context; - let (first_list, first_list_layout) = &args[0]; + let (first_list, first_list_layout) = load_symbol_and_layout(env, scope, &args[0]); - let (second_list, second_list_layout) = &args[1]; + let (second_list, second_list_layout) = load_symbol_and_layout(env, scope, &args[1]); - let second_list_wrapper = - build_expr(env, layout_ids, scope, parent, second_list).into_struct_value(); + let second_list_wrapper = second_list.into_struct_value(); let second_list_len = load_list_len(builder, second_list_wrapper); match first_list_layout { Layout::Builtin(Builtin::EmptyList) => { match second_list_layout { - Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(elem_layout)) => { + Layout::Builtin(Builtin::EmptyList) => empty_polymorphic_list(env), + Layout::Builtin(Builtin::List(_, elem_layout)) => { // THIS IS A COPY AND PASTE // All the code under the Layout::Builtin(Builtin::List()) match branch // is the same as what is under `if_first_list_is_empty`. Re-using @@ -2586,7 +3024,7 @@ fn list_concat<'a, 'ctx, 'env>( BasicValueEnum::StructValue(new_wrapper) }; - let build_second_list_else = || empty_list(env); + let build_second_list_else = || empty_polymorphic_list(env); build_basic_phi2( env, @@ -2605,9 +3043,8 @@ fn list_concat<'a, 'ctx, 'env>( } } } - Layout::Builtin(Builtin::List(elem_layout)) => { - let first_list_wrapper = - build_expr(env, layout_ids, scope, parent, first_list).into_struct_value(); + Layout::Builtin(Builtin::List(_, elem_layout)) => { + let first_list_wrapper = first_list.into_struct_value(); let first_list_len = load_list_len(builder, first_list_wrapper); @@ -2620,8 +3057,6 @@ fn list_concat<'a, 'ctx, 'env>( let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let if_second_list_is_empty = || { let (new_wrapper, _) = clone_nonempty_list( env, @@ -2644,7 +3079,7 @@ fn list_concat<'a, 'ctx, 'env>( BasicValueEnum::StructValue(new_wrapper) } - Layout::Builtin(Builtin::List(_)) => { + Layout::Builtin(Builtin::List(_, _)) => { // second_list_len > 0 // We do this check to avoid allocating memory. If the second input // list is empty, then we can just return the first list cloned @@ -2658,14 +3093,8 @@ fn list_concat<'a, 'ctx, 'env>( "add_list_lengths", ); - let combined_list_ptr = env - .builder - .build_array_malloc( - elem_type, - combined_list_len, - "create_combined_list_ptr", - ) - .unwrap(); + let combined_list_ptr = + allocate_list(env, elem_layout, combined_list_len); let index_name = "#index"; let index_alloca = builder.build_alloca(ctx.i64_type(), index_name); @@ -2882,8 +3311,8 @@ fn list_concat<'a, 'ctx, 'env>( let if_first_list_is_empty = || { match second_list_layout { - Layout::Builtin(Builtin::EmptyList) => empty_list(env), - Layout::Builtin(Builtin::List(elem_layout)) => { + Layout::Builtin(Builtin::EmptyList) => empty_polymorphic_list(env), + Layout::Builtin(Builtin::List(_, elem_layout)) => { // second_list_len > 0 // We do this check to avoid allocating memory. If the second input // list is empty, then we can just return the first list cloned @@ -2905,7 +3334,7 @@ fn list_concat<'a, 'ctx, 'env>( BasicValueEnum::StructValue(new_wrapper) }; - let build_second_list_else = || empty_list(env); + let build_second_list_else = || empty_polymorphic_list(env); build_basic_phi2( env, diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index 5d6b91bdd6..9fad55a36b 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -112,12 +112,31 @@ pub fn basic_type_from_layout<'ctx>( let ptr_size = std::mem::size_of::(); let union_size = layout.stack_size(ptr_size as u32); - let array_type = context - .i8_type() - .array_type(union_size) - .as_basic_type_enum(); + // The memory layout of Union is a bit tricky. + // We have tags with different memory layouts, that are part of the same type. + // For llvm, all tags must have the same memory layout. + // + // So, we convert all tags to a layout of bytes of some size. + // It turns out that encoding to i64 for as many elements as possible is + // a nice optimization, the remainder is encoded as bytes. - context.struct_type(&[array_type], false).into() + let num_i64 = union_size / 8; + let num_i8 = union_size % 8; + + let i64_array_type = context.i64_type().array_type(num_i64).as_basic_type_enum(); + + if num_i8 == 0 { + // the object fits perfectly in some number of i64's + // (i.e. the size is a multiple of 8 bytes) + context.struct_type(&[i64_array_type], false).into() + } else { + // there are some trailing bytes at the end + let i8_array_type = context.i8_type().array_type(num_i8).as_basic_type_enum(); + + context + .struct_type(&[i64_array_type, i8_array_type], false) + .into() + } } Builtin(builtin) => match builtin { @@ -137,7 +156,7 @@ pub fn basic_type_from_layout<'ctx>( .as_basic_type_enum(), Map(_, _) | EmptyMap => panic!("TODO layout_to_basic_type for Builtin::Map"), Set(_) | EmptySet => panic!("TODO layout_to_basic_type for Builtin::Set"), - List(_) => collection(context, ptr_bytes).into(), + List(_, _) => collection(context, ptr_bytes).into(), EmptyList => BasicTypeEnum::StructType(collection(context, ptr_bytes)), }, } diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index e9e2499604..71f9895f54 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -13,25 +13,18 @@ mod helpers; #[cfg(test)] mod gen_list { - use crate::helpers::{can_expr, infer_expr, uniq_expr, with_larger_debug_stack, CanExprOut}; - use bumpalo::Bump; - use inkwell::context::Context; - use inkwell::execution_engine::JitFunction; - use inkwell::passes::PassManager; - use inkwell::types::BasicType; - use inkwell::OptimizationLevel; - use roc_collections::all::ImMap; - use roc_gen::llvm::build::{build_proc, build_proc_header}; - use roc_gen::llvm::convert::basic_type_from_layout; - use roc_mono::expr::{Expr, Procs}; - use roc_mono::layout::Layout; - use roc_types::subs::Subs; + use crate::helpers::with_larger_debug_stack; #[test] fn empty_list_literal() { assert_evals_to!("[]", &[], &'static [i64]); } + #[test] + fn int_singleton_list_literal() { + assert_evals_to!("[1]", &[1], &'static [i64]); + } + #[test] fn int_list_literal() { assert_evals_to!("[ 12, 9, 6, 3 ]", &[12, 9, 6, 3], &'static [i64]); @@ -141,7 +134,7 @@ mod gen_list { empty : List Float empty = [] - + List.join [ [ 0.2, 11.11 ], empty ] "# ), @@ -248,7 +241,7 @@ mod gen_list { assert_evals_to!("List.concat [] [ 23, 24 ]", &[23, 24], &'static [i64]); assert_evals_to!( - "List.concat [ 1, 2 ] [ 3, 4 ]", + "List.concat [1, 2 ] [ 3, 4 ]", &[1, 2, 3, 4], &'static [i64] ); @@ -272,7 +265,9 @@ mod gen_list { assert_evals_to!( &format!("List.concat {} {}", slice_str1, slice_str2), expected_slice, - &'static [i64] + &'static [i64], + |x| x, + true ); } @@ -312,6 +307,13 @@ mod gen_list { assert_concat_worked(2, 3); assert_concat_worked(3, 3); assert_concat_worked(4, 4); + } + + #[test] + fn list_concat_large() { + // these values produce mono ASTs so large that + // it can cause a stack overflow. This has been solved + // for current code, but may become a problem again in the future. assert_concat_worked(150, 150); assert_concat_worked(129, 350); assert_concat_worked(350, 129); @@ -634,6 +636,120 @@ mod gen_list { ); } + #[test] + fn gen_swap() { + assert_evals_to!( + indoc!( + r#" + swap : Int, Int, List a -> List a + swap = \i, j, list -> + when Pair (List.get list i) (List.get list j) is + Pair (Ok atI) (Ok atJ) -> + list + |> List.set i atJ + |> List.set j atI + + _ -> + [] + swap 0 1 [ 1, 2 ] + "# + ), + &[2, 1], + &'static [i64] + ); + } + + // #[test] + // fn gen_partition() { + // assert_evals_to!( + // indoc!( + // r#" + // swap : Int, Int, List a -> List a + // swap = \i, j, list -> + // when Pair (List.get list i) (List.get list j) is + // Pair (Ok atI) (Ok atJ) -> + // list + // |> List.set i atJ + // |> List.set j atI + // + // _ -> + // [] + // partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + // partition = \low, high, initialList -> + // when List.get initialList high is + // Ok pivot -> + // when partitionHelp (low - 1) low initialList high pivot is + // Pair newI newList -> + // Pair (newI + 1) (swap (newI + 1) high newList) + // + // Err _ -> + // Pair (low - 1) initialList + // + // + // partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] + // partitionHelp = \i, j, list, high, pivot -> + // if j < high then + // when List.get list j is + // Ok value -> + // if value <= pivot then + // partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + // else + // partitionHelp i (j + 1) list high pivot + // + // Err _ -> + // Pair i list + // else + // Pair i list + // + // # when partition 0 0 [ 1,2,3,4,5 ] is + // # Pair list _ -> list + // [ 1,3 ] + // "# + // ), + // &[2, 1], + // &'static [i64] + // ); + // } + + // #[test] + // fn gen_partition() { + // assert_evals_to!( + // indoc!( + // r#" + // swap : Int, Int, List a -> List a + // swap = \i, j, list -> + // when Pair (List.get list i) (List.get list j) is + // Pair (Ok atI) (Ok atJ) -> + // list + // |> List.set i atJ + // |> List.set j atI + // + // _ -> + // [] + // partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + // partition = \low, high, initialList -> + // when List.get initialList high is + // Ok pivot -> + // when partitionHelp (low - 1) low initialList high pivot is + // Pair newI newList -> + // Pair (newI + 1) (swap (newI + 1) high newList) + // + // Err _ -> + // Pair (low - 1) initialList + // + // + // partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] + // + // # when partition 0 0 [ 1,2,3,4,5 ] is + // # Pair list _ -> list + // [ 1,3 ] + // "# + // ), + // &[2, 1], + // &'static [i64] + // ); + // } + #[test] fn gen_quicksort() { with_larger_debug_stack(|| { @@ -642,7 +758,8 @@ mod gen_list { r#" quicksort : List (Num a) -> List (Num a) quicksort = \list -> - quicksortHelp list 0 (List.len list - 1) + n = List.len list + quicksortHelp list 0 (n - 1) quicksortHelp : List (Num a), Int, Int -> List (Num a) @@ -680,7 +797,7 @@ mod gen_list { Pair (low - 1) initialList - partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] + partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] partitionHelp = \i, j, list, high, pivot -> if j < high then when List.get list j is @@ -695,14 +812,257 @@ mod gen_list { else Pair i list - - quicksort [ 7, 4, 21, 19 ] "# ), &[4, 7, 19, 21], - &'static [i64] + &'static [i64], + |x| x, + true ); }) } + + // #[test] + // fn foobar2() { + // with_larger_debug_stack(|| { + // assert_evals_to!( + // indoc!( + // r#" + // quicksort : List (Num a) -> List (Num a) + // quicksort = \list -> + // quicksortHelp list 0 (List.len list - 1) + // + // + // quicksortHelp : List (Num a), Int, Int -> List (Num a) + // quicksortHelp = \list, low, high -> + // if low < high then + // when partition low high list is + // Pair partitionIndex partitioned -> + // partitioned + // |> quicksortHelp low (partitionIndex - 1) + // |> quicksortHelp (partitionIndex + 1) high + // else + // list + // + // + // swap : Int, Int, List a -> List a + // swap = \i, j, list -> + // when Pair (List.get list i) (List.get list j) is + // Pair (Ok atI) (Ok atJ) -> + // list + // |> List.set i atJ + // |> List.set j atI + // + // _ -> + // [] + // + // partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + // partition = \low, high, initialList -> + // when List.get initialList high is + // Ok pivot -> + // when partitionHelp (low - 1) low initialList high pivot is + // Pair newI newList -> + // Pair (newI + 1) (swap (newI + 1) high newList) + // + // Err _ -> + // Pair (low - 1) initialList + // + // + // partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] + // partitionHelp = \i, j, list, high, pivot -> + // # if j < high then + // if False then + // when List.get list j is + // Ok value -> + // if value <= pivot then + // partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + // else + // partitionHelp i (j + 1) list high pivot + // + // Err _ -> + // Pair i list + // else + // Pair i list + // + // + // + // quicksort [ 7, 4, 21, 19 ] + // "# + // ), + // &[19, 7, 4, 21], + // &'static [i64], + // |x| x, + // true + // ); + // }) + // } + + // #[test] + // fn foobar() { + // with_larger_debug_stack(|| { + // assert_evals_to!( + // indoc!( + // r#" + // quicksort : List (Num a) -> List (Num a) + // quicksort = \list -> + // quicksortHelp list 0 (List.len list - 1) + // + // + // quicksortHelp : List (Num a), Int, Int -> List (Num a) + // quicksortHelp = \list, low, high -> + // if low < high then + // when partition low high list is + // Pair partitionIndex partitioned -> + // partitioned + // |> quicksortHelp low (partitionIndex - 1) + // |> quicksortHelp (partitionIndex + 1) high + // else + // list + // + // + // swap : Int, Int, List a -> List a + // swap = \i, j, list -> + // when Pair (List.get list i) (List.get list j) is + // Pair (Ok atI) (Ok atJ) -> + // list + // |> List.set i atJ + // |> List.set j atI + // + // _ -> + // [] + // + // partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + // partition = \low, high, initialList -> + // when List.get initialList high is + // Ok pivot -> + // when partitionHelp (low - 1) low initialList high pivot is + // Pair newI newList -> + // Pair (newI + 1) (swap (newI + 1) high newList) + // + // Err _ -> + // Pair (low - 1) initialList + // + // + // partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] + // partitionHelp = \i, j, list, high, pivot -> + // if j < high then + // when List.get list j is + // Ok value -> + // if value <= pivot then + // partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + // else + // partitionHelp i (j + 1) list high pivot + // + // Err _ -> + // Pair i list + // else + // Pair i list + // + // + // + // when List.first (quicksort [0x1]) is + // _ -> 4 + // "# + // ), + // 4, + // i64, + // |x| x, + // false + // ); + // }) + // } + + #[test] + fn empty_list_increment_decrement() { + assert_evals_to!( + indoc!( + r#" + x : List Int + x = [] + + List.len x + List.len x + "# + ), + 0, + i64 + ); + } + + #[test] + fn list_literal_increment_decrement() { + assert_evals_to!( + indoc!( + r#" + x : List Int + x = [1,2,3] + + List.len x + List.len x + "# + ), + 6, + i64 + ); + } + + #[test] + fn list_pass_to_function() { + assert_evals_to!( + indoc!( + r#" + x : List Int + x = [1,2,3] + + id : List Int -> List Int + id = \y -> y + + id x + "# + ), + &[1, 2, 3], + &'static [i64], + |x| x, + true + ); + } + + #[test] + fn list_pass_to_set() { + assert_evals_to!( + indoc!( + r#" + x : List Int + x = [1,2,3] + + id : List Int -> List Int + id = \y -> List.set y 0 0 + + id x + "# + ), + &[0, 2, 3], + &'static [i64], + |x| x, + true + ); + } + + #[test] + fn list_wrap_in_tag() { + assert_evals_to!( + indoc!( + r#" + id : List Int -> [ Pair (List Int) Int ] + id = \y -> Pair y 4 + + when id [1,2,3] is + Pair v _ -> v + "# + ), + &[1, 2, 3], + &'static [i64], + |x| x, + true + ); + } } diff --git a/compiler/gen/tests/gen_num.rs b/compiler/gen/tests/gen_num.rs index cb0afe9ecd..e460e42340 100644 --- a/compiler/gen/tests/gen_num.rs +++ b/compiler/gen/tests/gen_num.rs @@ -13,19 +13,15 @@ mod helpers; #[cfg(test)] mod gen_num { - use crate::helpers::{can_expr, infer_expr, uniq_expr, CanExprOut}; - use bumpalo::Bump; - use inkwell::context::Context; - use inkwell::execution_engine::JitFunction; + /* use inkwell::passes::PassManager; use inkwell::types::BasicType; use inkwell::OptimizationLevel; - use roc_collections::all::ImMap; use roc_gen::llvm::build::{build_proc, build_proc_header}; use roc_gen::llvm::convert::basic_type_from_layout; - use roc_mono::expr::{Expr, Procs}; use roc_mono::layout::Layout; use roc_types::subs::Subs; + */ #[test] fn f64_sqrt() { @@ -44,7 +40,7 @@ mod gen_num { } #[test] - fn f64_round() { + fn f64_round_old() { assert_evals_to!("Num.round 3.6", 4, i64); } @@ -68,6 +64,26 @@ mod gen_num { #[test] fn gen_if_fn() { + assert_evals_to!( + indoc!( + r#" + limitedNegate = \num -> + x = + if num == 1 then + -1 + else if num == -1 then + 1 + else + num + x + + limitedNegate 1 + "# + ), + -1, + i64 + ); + assert_evals_to!( indoc!( r#" @@ -462,7 +478,7 @@ mod gen_num { } #[test] - fn if_guard_bind_variable() { + fn if_guard_bind_variable_false() { assert_evals_to!( indoc!( r#" @@ -474,7 +490,10 @@ mod gen_num { 42, i64 ); + } + #[test] + fn if_guard_bind_variable_true() { assert_evals_to!( indoc!( r#" diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 8b5b9037cd..e2914ff7c9 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -13,19 +13,6 @@ mod helpers; #[cfg(test)] mod gen_primitives { - use crate::helpers::{can_expr, infer_expr, uniq_expr, CanExprOut}; - use bumpalo::Bump; - use inkwell::context::Context; - use inkwell::execution_engine::JitFunction; - use inkwell::passes::PassManager; - use inkwell::types::BasicType; - use inkwell::OptimizationLevel; - use roc_collections::all::ImMap; - use roc_gen::llvm::build::{build_proc, build_proc_header}; - use roc_gen::llvm::convert::basic_type_from_layout; - use roc_mono::expr::{Expr, Procs}; - use roc_mono::layout::Layout; - use roc_types::subs::Subs; use std::ffi::{CStr, CString}; use std::os::raw::c_char; diff --git a/compiler/gen/tests/gen_records.rs b/compiler/gen/tests/gen_records.rs index 7066594b44..577538db01 100644 --- a/compiler/gen/tests/gen_records.rs +++ b/compiler/gen/tests/gen_records.rs @@ -13,20 +13,6 @@ mod helpers; #[cfg(test)] mod gen_records { - use crate::helpers::{can_expr, infer_expr, uniq_expr, CanExprOut}; - use bumpalo::Bump; - use inkwell::context::Context; - use inkwell::execution_engine::JitFunction; - use inkwell::passes::PassManager; - use inkwell::types::BasicType; - use inkwell::OptimizationLevel; - use roc_collections::all::ImMap; - use roc_gen::llvm::build::{build_proc, build_proc_header}; - use roc_gen::llvm::convert::basic_type_from_layout; - use roc_mono::expr::{Expr, Procs}; - use roc_mono::layout::Layout; - use roc_types::subs::Subs; - #[test] fn basic_record() { assert_evals_to!( @@ -201,7 +187,10 @@ mod gen_records { 5, i64 ); + } + #[test] + fn when_on_record_with_guard_pattern() { assert_evals_to!( indoc!( r#" @@ -212,7 +201,10 @@ mod gen_records { 5, i64 ); + } + #[test] + fn let_with_record_pattern() { assert_evals_to!( indoc!( r#" @@ -391,4 +383,20 @@ mod gen_records { bool ); } + + #[test] + fn return_record() { + assert_evals_to!( + indoc!( + r#" + x = 4 + y = 3 + + { x, y } + "# + ), + (4, 3), + (i64, i64) + ); + } } diff --git a/compiler/gen/tests/gen_tags.rs b/compiler/gen/tests/gen_tags.rs index 96ae1a3477..28d93a7668 100644 --- a/compiler/gen/tests/gen_tags.rs +++ b/compiler/gen/tests/gen_tags.rs @@ -13,19 +13,23 @@ mod helpers; #[cfg(test)] mod gen_tags { - use crate::helpers::{can_expr, infer_expr, uniq_expr, CanExprOut}; - use bumpalo::Bump; - use inkwell::context::Context; - use inkwell::execution_engine::JitFunction; - use inkwell::passes::PassManager; - use inkwell::types::BasicType; - use inkwell::OptimizationLevel; - use roc_collections::all::ImMap; - use roc_gen::llvm::build::{build_proc, build_proc_header}; - use roc_gen::llvm::convert::basic_type_from_layout; - use roc_mono::expr::{Expr, Procs}; - use roc_mono::layout::Layout; - use roc_types::subs::Subs; + #[test] + fn applied_tag_nothing_ir() { + assert_evals_to!( + indoc!( + r#" + Maybe a : [ Just a, Nothing ] + + x : Maybe Int + x = Nothing + + 0x1 + "# + ), + 1, + i64 + ); + } #[test] fn applied_tag_nothing() { @@ -63,6 +67,24 @@ mod gen_tags { ); } + #[test] + fn applied_tag_just_ir() { + assert_evals_to!( + indoc!( + r#" + Maybe a : [ Just a, Nothing ] + + y : Maybe Int + y = Just 0x4 + + 0x1 + "# + ), + 1, + i64 + ); + } + #[test] fn applied_tag_just_unit() { assert_evals_to!( @@ -621,4 +643,77 @@ mod gen_tags { i64 ); } + + #[test] + fn join_point_if() { + assert_evals_to!( + indoc!( + r#" + x = + if True then 1 else 2 + + x + "# + ), + 1, + i64 + ); + } + + #[test] + fn join_point_when() { + assert_evals_to!( + indoc!( + r#" + x : [ Red, White, Blue ] + x = Blue + + y = + when x is + Red -> 1 + White -> 2 + Blue -> 3.1 + + y + "# + ), + 3.1, + f64 + ); + } + + #[test] + fn join_point_with_cond_expr() { + assert_evals_to!( + indoc!( + r#" + y = + when 1 + 2 is + 3 -> 3 + 1 -> 1 + _ -> 0 + + y + "# + ), + 3, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + y = + if 1 + 2 > 0 then + 3 + else + 0 + + y + "# + ), + 3, + i64 + ); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 0885361f8a..839136a089 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -1,11 +1,34 @@ -#[macro_export] -macro_rules! assert_llvm_evals_to { - ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { - let target = target_lexicon::Triple::host(); - let ptr_bytes = target.pointer_width().unwrap().bytes() as u32; - let arena = Bump::new(); - let CanExprOut { loc_expr, var_store, var, constraint, home, interns, problems, .. } = can_expr($src); - let errors = problems.into_iter().filter(|problem| { +use roc_types::subs::Subs; + +pub fn helper_without_uniqueness<'a>( + arena: &'a bumpalo::Bump, + src: &str, + leak: bool, + context: &'a inkwell::context::Context, +) -> (&'static str, inkwell::execution_engine::ExecutionEngine<'a>) { + use crate::helpers::{can_expr, infer_expr, CanExprOut}; + use inkwell::types::BasicType; + use inkwell::OptimizationLevel; + use roc_gen::llvm::build::Scope; + use roc_gen::llvm::build::{build_proc, build_proc_header}; + use roc_gen::llvm::convert::basic_type_from_layout; + use roc_mono::layout::Layout; + + let target = target_lexicon::Triple::host(); + let ptr_bytes = target.pointer_width().unwrap().bytes() as u32; + let CanExprOut { + loc_expr, + var_store, + var, + constraint, + home, + interns, + problems, + .. + } = can_expr(src); + let errors = problems + .into_iter() + .filter(|problem| { use roc_problem::can::Problem::*; // Ignore "unused" problems @@ -13,159 +36,381 @@ macro_rules! assert_llvm_evals_to { UnusedDef(_, _) | UnusedArgument(_, _, _) | UnusedImport(_, _) => false, _ => true, } - }).collect::>(); + }) + .collect::>(); - assert_eq!(errors, Vec::new(), "Encountered errors: {:?}", errors); + assert_eq!(errors, Vec::new(), "Encountered errors: {:?}", errors); - let subs = Subs::new(var_store.into()); - let mut unify_problems = Vec::new(); - let (content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); + let subs = Subs::new(var_store.into()); + let mut unify_problems = Vec::new(); + let (content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); - assert_eq!(unify_problems, Vec::new(), "Encountered type mismatches: {:?}", unify_problems); + assert_eq!( + unify_problems, + Vec::new(), + "Encountered type mismatches: {:?}", + unify_problems + ); - let context = Context::create(); - let module = roc_gen::llvm::build::module_from_builtins(&context, "app"); - let builder = context.create_builder(); - let opt_level = if cfg!(debug_assertions) { - roc_gen::llvm::build::OptLevel::Normal + let module = roc_gen::llvm::build::module_from_builtins(context, "app"); + let builder = context.create_builder(); + let opt_level = if cfg!(debug_assertions) { + roc_gen::llvm::build::OptLevel::Normal + } else { + roc_gen::llvm::build::OptLevel::Optimize + }; + + let module = arena.alloc(module); + let (module_pass, function_pass) = + roc_gen::llvm::build::construct_optimization_passes(module, opt_level); + + // Compute main_fn_type before moving subs to Env + let layout = Layout::new(&arena, content, &subs).unwrap_or_else(|err| { + panic!( + "Code gen error in NON-OPTIMIZED test: could not convert to layout. Err was {:?}", + err + ) + }); + let execution_engine = module + .create_jit_execution_engine(OptimizationLevel::None) + .expect("Error creating JIT execution engine for test"); + + let main_fn_type = + basic_type_from_layout(&arena, context, &layout, ptr_bytes).fn_type(&[], false); + let main_fn_name = "$Test.main"; + + // Compile and add all the Procs before adding main + let mut env = roc_gen::llvm::build::Env { + arena: &arena, + builder: &builder, + context: context, + interns, + module, + ptr_bytes, + leak: leak, + }; + let mut procs = roc_mono::ir::Procs::default(); + let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); + let mut layout_ids = roc_gen::layout_id::LayoutIds::default(); + + // Populate Procs and get the low-level Expr from the canonical Expr + let mut mono_problems = Vec::new(); + let mut mono_env = roc_mono::ir::Env { + arena: &arena, + subs: &mut subs, + problems: &mut mono_problems, + home, + ident_ids: &mut ident_ids, + }; + + let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = + roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); + + let mut headers = { + let num_headers = match &procs.pending_specializations { + Some(map) => map.len(), + None => 0, + }; + + Vec::with_capacity(num_headers) + }; + let mut layout_cache = roc_mono::layout::LayoutCache::default(); + let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut layout_cache); + + assert_eq!( + procs.runtime_errors, + roc_collections::all::MutMap::default() + ); + + // Put this module's ident_ids back in the interns, so we can use them in env. + // This must happen *after* building the headers, because otherwise there's + // a conflicting mutable borrow on ident_ids. + env.interns.all_ident_ids.insert(home, ident_ids); + + // Add all the Proc headers to the module. + // We have to do this in a separate pass first, + // because their bodies may reference each other. + + for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() { + let (fn_val, arg_basic_types) = + build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); + + headers.push((proc, fn_val, arg_basic_types)); + } + + // Build each proc using its header info. + for (proc, fn_val, arg_basic_types) in headers { + build_proc(&env, &mut layout_ids, proc, fn_val, arg_basic_types); + + if fn_val.verify(true) { + function_pass.run_on(&fn_val); } else { - roc_gen::llvm::build::OptLevel::Optimize - }; - let fpm = PassManager::create(&module); - - roc_gen::llvm::build::add_passes(&fpm, opt_level); - - fpm.initialize(); - - // Compute main_fn_type before moving subs to Env - let layout = Layout::new(&arena, content, &subs) - .unwrap_or_else(|err| panic!("Code gen error in NON-OPTIMIZED test: could not convert to layout. Err was {:?}", err)); - let execution_engine = - module - .create_jit_execution_engine(OptimizationLevel::None) - .expect("Error creating JIT execution engine for test"); - - let main_fn_type = basic_type_from_layout(&arena, &context, &layout, ptr_bytes) - .fn_type(&[], false); - let main_fn_name = "$Test.main"; - - // Compile and add all the Procs before adding main - let mut env = roc_gen::llvm::build::Env { - arena: &arena, - builder: &builder, - context: &context, - interns, - module: arena.alloc(module), - ptr_bytes - }; - let mut procs = Procs::default(); - let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); - let mut layout_ids = roc_gen::layout_id::LayoutIds::default(); - - // Populate Procs and get the low-level Expr from the canonical Expr - let mut mono_problems = Vec::new(); - let mut mono_env = roc_mono::expr::Env { - arena: &arena, - subs: &mut subs, - problems: &mut mono_problems, - home, - ident_ids: &mut ident_ids, - }; - - let main_body = Expr::new(&mut mono_env, loc_expr.value, &mut procs); - let mut headers = { - let num_headers = match &procs.pending_specializations { - Some(map) => map.len(), - None => 0 - }; - - Vec::with_capacity(num_headers) - }; - let mut layout_cache = roc_mono::layout::LayoutCache::default(); - let mut procs = roc_mono::expr::specialize_all(&mut mono_env, procs, &mut layout_cache); - - assert_eq!(procs.runtime_errors, roc_collections::all::MutMap::default()); - - // Put this module's ident_ids back in the interns, so we can use them in env. - // This must happen *after* building the headers, because otherwise there's - // a conflicting mutable borrow on ident_ids. - env.interns.all_ident_ids.insert(home, ident_ids); - - // Add all the Proc headers to the module. - // We have to do this in a separate pass first, - // because their bodies may reference each other. - for ((symbol, layout), proc) in procs.specialized.drain() { - use roc_mono::expr::InProgressProc::*; - - match proc { - InProgress => { - panic!("A specialization was still marked InProgress after monomorphization had completed: {:?} with layout {:?}", symbol, layout); - } - Done(proc) => { - let (fn_val, arg_basic_types) = - build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); - - headers.push((proc, fn_val, arg_basic_types)); - } - } - } - - // Build each proc using its header info. - for (proc, fn_val, arg_basic_types) in headers { - build_proc(&env, &mut layout_ids, proc, fn_val, arg_basic_types); - - if fn_val.verify(true) { - fpm.run_on(&fn_val); - } else { - eprintln!( + eprintln!( "\n\nFunction {:?} failed LLVM verification in NON-OPTIMIZED build. Its content was:\n", fn_val.get_name().to_str().unwrap() ); - fn_val.print_to_stderr(); + fn_val.print_to_stderr(); - panic!( + panic!( "The preceding code was from {:?}, which failed LLVM verification in NON-OPTIMIZED build.", fn_val.get_name().to_str().unwrap() ); + } + } + + // Add main to the module. + let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); + let cc = + roc_gen::llvm::build::get_call_conventions(target.default_calling_convention().unwrap()); + + main_fn.set_call_conventions(cc); + + // Add main's body + let basic_block = context.append_basic_block(main_fn, "entry"); + + builder.position_at_end(basic_block); + + let ret = roc_gen::llvm::build::build_exp_stmt( + &env, + &mut layout_ids, + &mut Scope::default(), + main_fn, + &main_body, + ); + + builder.build_return(Some(&ret)); + + // Uncomment this to see the module's un-optimized LLVM instruction output: + // env.module.print_to_stderr(); + + if main_fn.verify(true) { + function_pass.run_on(&main_fn); + } else { + panic!("Main function {} failed LLVM verification in NON-OPTIMIZED build. Uncomment things nearby to see more details.", main_fn_name); + } + + module_pass.run_on(env.module); + + // Verify the module + if let Err(errors) = env.module.verify() { + panic!("Errors defining module: {:?}", errors); + } + + // Uncomment this to see the module's optimized LLVM instruction output: + // env.module.print_to_stderr(); + + (main_fn_name, execution_engine.clone()) +} + +pub fn helper_with_uniqueness<'a>( + arena: &'a bumpalo::Bump, + src: &str, + leak: bool, + context: &'a inkwell::context::Context, +) -> (&'static str, inkwell::execution_engine::ExecutionEngine<'a>) { + use crate::helpers::{infer_expr, uniq_expr}; + use inkwell::types::BasicType; + use inkwell::OptimizationLevel; + use roc_gen::llvm::build::Scope; + use roc_gen::llvm::build::{build_proc, build_proc_header}; + use roc_gen::llvm::convert::basic_type_from_layout; + use roc_mono::layout::Layout; + + let target = target_lexicon::Triple::host(); + let ptr_bytes = target.pointer_width().unwrap().bytes() as u32; + let (loc_expr, _output, problems, subs, var, constraint, home, interns) = uniq_expr(src); + + let errors = problems + .into_iter() + .filter(|problem| { + use roc_problem::can::Problem::*; + + // Ignore "unused" problems + match problem { + UnusedDef(_, _) | UnusedArgument(_, _, _) | UnusedImport(_, _) => false, + _ => true, } - } + }) + .collect::>(); - // Add main to the module. - let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); - let cc = roc_gen::llvm::build::get_call_conventions(target.default_calling_convention().unwrap()); + assert_eq!(errors, Vec::new(), "Encountered errors: {:?}", errors); - main_fn.set_call_conventions(cc); + let mut unify_problems = Vec::new(); + let (content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); - // Add main's body - let basic_block = context.append_basic_block(main_fn, "entry"); + assert_eq!( + unify_problems, + Vec::new(), + "Encountered one or more type mismatches: {:?}", + unify_problems + ); - builder.position_at_end(basic_block); + let module = arena.alloc(roc_gen::llvm::build::module_from_builtins(context, "app")); + let builder = context.create_builder(); + let opt_level = if cfg!(debug_assertions) { + roc_gen::llvm::build::OptLevel::Normal + } else { + roc_gen::llvm::build::OptLevel::Optimize + }; + let (mpm, fpm) = roc_gen::llvm::build::construct_optimization_passes(module, opt_level); - let ret = roc_gen::llvm::build::build_expr( - &env, - &mut layout_ids, - &ImMap::default(), - main_fn, - &main_body, - ); + // Compute main_fn_type before moving subs to Env + let layout = Layout::new(&arena, content, &subs).unwrap_or_else(|err| { + panic!( + "Code gen error in OPTIMIZED test: could not convert to layout. Err was {:?}", + err + ) + }); - builder.build_return(Some(&ret)); + let execution_engine = module + .create_jit_execution_engine(OptimizationLevel::None) + .expect("Error creating JIT execution engine for test"); - // Uncomment this to see the module's un-optimized LLVM instruction output: - // env.module.print_to_stderr(); + let main_fn_type = basic_type_from_layout(&arena, context, &layout, ptr_bytes) + .fn_type(&[], false) + .clone(); + let main_fn_name = "$Test.main"; - if main_fn.verify(true) { - fpm.run_on(&main_fn); + // Compile and add all the Procs before adding main + let mut env = roc_gen::llvm::build::Env { + arena: &arena, + builder: &builder, + context: context, + interns, + module, + ptr_bytes, + leak: leak, + }; + let mut procs = roc_mono::ir::Procs::default(); + let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); + let mut layout_ids = roc_gen::layout_id::LayoutIds::default(); + + // Populate Procs and get the low-level Expr from the canonical Expr + let mut mono_problems = Vec::new(); + let mut mono_env = roc_mono::ir::Env { + arena: &arena, + subs: &mut subs, + problems: &mut mono_problems, + home, + ident_ids: &mut ident_ids, + }; + + let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); + let main_body = + roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); + let mut headers = { + let num_headers = match &procs.pending_specializations { + Some(map) => map.len(), + None => 0, + }; + + Vec::with_capacity(num_headers) + }; + let mut layout_cache = roc_mono::layout::LayoutCache::default(); + let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut layout_cache); + + assert_eq!( + procs.runtime_errors, + roc_collections::all::MutMap::default() + ); + + // Put this module's ident_ids back in the interns, so we can use them in env. + // This must happen *after* building the headers, because otherwise there's + // a conflicting mutable borrow on ident_ids. + env.interns.all_ident_ids.insert(home, ident_ids); + + // Add all the Proc headers to the module. + // We have to do this in a separate pass first, + // because their bodies may reference each other. + for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() { + let (fn_val, arg_basic_types) = + build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); + + headers.push((proc, fn_val, arg_basic_types)); + } + + // Build each proc using its header info. + for (proc, fn_val, arg_basic_types) in headers { + build_proc(&env, &mut layout_ids, proc, fn_val, arg_basic_types); + + if fn_val.verify(true) { + fpm.run_on(&fn_val); } else { - panic!("Main function {} failed LLVM verification in NON-OPTIMIZED build. Uncomment things nearby to see more details.", main_fn_name); - } + eprintln!( + "\n\nFunction {:?} failed LLVM verification in OPTIMIZED build. Its content was:\n", + fn_val.get_name().to_str().unwrap() + ); - // Verify the module - if let Err(errors) = env.module.verify() { - panic!("Errors defining module: {:?}", errors); - } + fn_val.print_to_stderr(); - // Uncomment this to see the module's optimized LLVM instruction output: - // env.module.print_to_stderr(); + panic!( + "The preceding code was from {:?}, which failed LLVM verification in OPTIMIZED build.", fn_val.get_name().to_str().unwrap() + ); + } + } + + // Add main to the module. + let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); + let cc = + roc_gen::llvm::build::get_call_conventions(target.default_calling_convention().unwrap()); + + main_fn.set_call_conventions(cc); + + // Add main's body + let basic_block = context.append_basic_block(main_fn, "entry"); + + builder.position_at_end(basic_block); + + let ret = roc_gen::llvm::build::build_exp_stmt( + &env, + &mut layout_ids, + &mut Scope::default(), + main_fn, + &main_body, + ); + + builder.build_return(Some(&ret)); + + // you're in the version with uniqueness! + + // Uncomment this to see the module's un-optimized LLVM instruction output: + // env.module.print_to_stderr(); + + if main_fn.verify(true) { + fpm.run_on(&main_fn); + } else { + panic!("main function {} failed LLVM verification in OPTIMIZED build. Uncomment nearby statements to see more details.", main_fn_name); + } + + mpm.run_on(module); + + // Verify the module + if let Err(errors) = env.module.verify() { + panic!("Errors defining module: {:?}", errors); + } + + // Uncomment this to see the module's optimized LLVM instruction output: + // env.module.print_to_stderr(); + + (main_fn_name, execution_engine) +} + +// TODO this is almost all code duplication with assert_llvm_evals_to +// the only difference is that this calls uniq_expr instead of can_expr. +// Should extract the common logic into test helpers. +#[macro_export] +macro_rules! assert_opt_evals_to { + ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { + use bumpalo::Bump; + use inkwell::context::Context; + use inkwell::execution_engine::JitFunction; + + let arena = Bump::new(); + + let context = Context::create(); + + let (main_fn_name, execution_engine) = + $crate::helpers::eval::helper_with_uniqueness(&arena, $src, $leak, &context); unsafe { let main: JitFunction $ty> = execution_engine @@ -177,179 +422,25 @@ macro_rules! assert_llvm_evals_to { assert_eq!($transform(main.call()), $expected); } }; + + ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { + assert_opt_evals_to!($src, $expected, $ty, $transform, true) + }; } -// TODO this is almost all code duplication with assert_llvm_evals_to -// the only difference is that this calls uniq_expr instead of can_expr. -// Should extract the common logic into test helpers. #[macro_export] -macro_rules! assert_opt_evals_to { - ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { +macro_rules! assert_llvm_evals_to { + ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { + use bumpalo::Bump; + use inkwell::context::Context; + use inkwell::execution_engine::JitFunction; + let arena = Bump::new(); - let target = target_lexicon::Triple::host(); - let ptr_bytes = target.pointer_width().unwrap().bytes() as u32; - let (loc_expr, _output, problems, subs, var, constraint, home, interns) = uniq_expr($src); - let errors = problems.into_iter().filter(|problem| { - use roc_problem::can::Problem::*; - - // Ignore "unused" problems - match problem { - UnusedDef(_, _) | UnusedArgument(_, _, _) | UnusedImport(_, _) => false, - _ => true, - } - }).collect::>(); - - assert_eq!(errors, Vec::new(), "Encountered errors: {:?}", errors); - - let mut unify_problems = Vec::new(); - let (content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); - - assert_eq!(unify_problems, Vec::new(), "Encountered one or more type mismatches: {:?}", unify_problems); let context = Context::create(); - let module = roc_gen::llvm::build::module_from_builtins(&context, "app"); - let builder = context.create_builder(); - let opt_level = if cfg!(debug_assertions) { - roc_gen::llvm::build::OptLevel::Normal - } else { - roc_gen::llvm::build::OptLevel::Optimize - }; - let fpm = PassManager::create(&module); - roc_gen::llvm::build::add_passes(&fpm, opt_level); - - fpm.initialize(); - - // Compute main_fn_type before moving subs to Env - let layout = Layout::new(&arena, content, &subs) - .unwrap_or_else(|err| panic!("Code gen error in OPTIMIZED test: could not convert to layout. Err was {:?}", err)); - - let execution_engine = - module - .create_jit_execution_engine(OptimizationLevel::None) - .expect("Error creating JIT execution engine for test"); - - let main_fn_type = basic_type_from_layout(&arena, &context, &layout, ptr_bytes) - .fn_type(&[], false); - let main_fn_name = "$Test.main"; - - // Compile and add all the Procs before adding main - let mut env = roc_gen::llvm::build::Env { - arena: &arena, - builder: &builder, - context: &context, - interns, - module: arena.alloc(module), - ptr_bytes - }; - let mut procs = Procs::default(); - let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); - let mut layout_ids = roc_gen::layout_id::LayoutIds::default(); - - // Populate Procs and get the low-level Expr from the canonical Expr - let mut mono_problems = Vec::new(); - let mut mono_env = roc_mono::expr::Env { - arena: &arena, - subs: &mut subs, - problems: &mut mono_problems, - home, - ident_ids: &mut ident_ids, - }; - let main_body = Expr::new(&mut mono_env, loc_expr.value, &mut procs); - - let mut headers = { - let num_headers = match &procs.pending_specializations { - Some(map) => map.len(), - None => 0 - }; - - Vec::with_capacity(num_headers) - }; - let mut layout_cache = roc_mono::layout::LayoutCache::default(); - let mut procs = roc_mono::expr::specialize_all(&mut mono_env, procs, &mut layout_cache); - - assert_eq!(procs.runtime_errors, roc_collections::all::MutMap::default()); - - // Put this module's ident_ids back in the interns, so we can use them in env. - // This must happen *after* building the headers, because otherwise there's - // a conflicting mutable borrow on ident_ids. - env.interns.all_ident_ids.insert(home, ident_ids); - - // Add all the Proc headers to the module. - // We have to do this in a separate pass first, - // because their bodies may reference each other. - for ((symbol, layout), proc) in procs.specialized.drain() { - use roc_mono::expr::InProgressProc::*; - - match proc { - InProgress => { - panic!("A specialization was still marked InProgress after monomorphization had completed: {:?} with layout {:?}", symbol, layout); - } - Done(proc) => { - let (fn_val, arg_basic_types) = - build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); - - headers.push((proc, fn_val, arg_basic_types)); - } - } - } - - // Build each proc using its header info. - for (proc, fn_val, arg_basic_types) in headers { - build_proc(&env, &mut layout_ids, proc, fn_val, arg_basic_types); - - if fn_val.verify(true) { - fpm.run_on(&fn_val); - } else { - eprintln!( - "\n\nFunction {:?} failed LLVM verification in OPTIMIZED build. Its content was:\n", fn_val.get_name().to_str().unwrap() - ); - - fn_val.print_to_stderr(); - - panic!( - "The preceding code was from {:?}, which failed LLVM verification in OPTIMIZED build.", fn_val.get_name().to_str().unwrap() - ); - } - } - - // Add main to the module. - let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); - let cc = roc_gen::llvm::build::get_call_conventions(target.default_calling_convention().unwrap()); - - main_fn.set_call_conventions(cc); - - // Add main's body - let basic_block = context.append_basic_block(main_fn, "entry"); - - builder.position_at_end(basic_block); - - let ret = roc_gen::llvm::build::build_expr( - &env, - &mut layout_ids, - &ImMap::default(), - main_fn, - &main_body, - ); - - builder.build_return(Some(&ret)); - - // Uncomment this to see the module's un-optimized LLVM instruction output: - // env.module.print_to_stderr(); - - if main_fn.verify(true) { - fpm.run_on(&main_fn); - } else { - panic!("main function {} failed LLVM verification in OPTIMIZED build. Uncomment nearby statements to see more details.", main_fn_name); - } - - // Verify the module - if let Err(errors) = env.module.verify() { - panic!("Errors defining module: {:?}", errors); - } - - // Uncomment this to see the module's optimized LLVM instruction output: - // env.module.print_to_stderr(); + let (main_fn_name, execution_engine) = + $crate::helpers::eval::helper_without_uniqueness(&arena, $src, $leak, &context); unsafe { let main: JitFunction $ty> = execution_engine @@ -361,6 +452,10 @@ macro_rules! assert_opt_evals_to { assert_eq!($transform(main.call()), $expected); } }; + + ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { + assert_llvm_evals_to!($src, $expected, $ty, $transform, true); + }; } #[macro_export] @@ -386,4 +481,13 @@ macro_rules! assert_evals_to { assert_opt_evals_to!($src, $expected, $ty, $transform); } }; + ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { + // Same as above, except with an additional transformation argument. + { + assert_llvm_evals_to!($src, $expected, $ty, $transform, $leak); + } + { + assert_opt_evals_to!($src, $expected, $ty, $transform, $leak); + } + }; } diff --git a/compiler/load/tests/fixtures/build/app_with_deps/Quicksort.roc b/compiler/load/tests/fixtures/build/app_with_deps/Quicksort.roc index 1d363a1b5b..9cbf0c7e38 100644 --- a/compiler/load/tests/fixtures/build/app_with_deps/Quicksort.roc +++ b/compiler/load/tests/fixtures/build/app_with_deps/Quicksort.roc @@ -1,5 +1,5 @@ app Quicksort - provides [ swap, partition, quicksort ] + provides [ swap, partition, partitionHelp, quicksort ] imports [] quicksort : List (Num a), Int, Int -> List (Num a) @@ -27,23 +27,25 @@ partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] partition = \low, high, initialList -> when List.get initialList high is Ok pivot -> - go = \i, j, list -> - if j < high then - when List.get list j is - Ok value -> - if value <= pivot then - go (i + 1) (j + 1) (swap (i + 1) j list) - else - go i (j + 1) list - - Err _ -> - Pair i list - else - Pair i list - - when go (low - 1) low initialList is + when partitionHelp (low - 1) low initialList high pivot is Pair newI newList -> Pair (newI + 1) (swap (newI + 1) high newList) Err _ -> Pair (low - 1) initialList + + +partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] +partitionHelp = \i, j, list, high, pivot -> + if j < high then + when List.get list j is + Ok value -> + if value <= pivot then + partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + else + partitionHelp i (j + 1) list high pivot + + Err _ -> + Pair i list + else + Pair i list diff --git a/compiler/load/tests/fixtures/build/interface_with_deps/Quicksort.roc b/compiler/load/tests/fixtures/build/interface_with_deps/Quicksort.roc index 9f02ef5b3d..d4e9b79490 100644 --- a/compiler/load/tests/fixtures/build/interface_with_deps/Quicksort.roc +++ b/compiler/load/tests/fixtures/build/interface_with_deps/Quicksort.roc @@ -27,23 +27,25 @@ partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] partition = \low, high, initialList -> when List.get initialList high is Ok pivot -> - go = \i, j, list -> - if j < high then - when List.get list j is - Ok value -> - if value <= pivot then - go (i + 1) (j + 1) (swap (i + 1) j list) - else - go i (j + 1) list - - Err _ -> - Pair i list - else - Pair i list - - when go (low - 1) low initialList is + when partitionHelp (low - 1) low initialList high pivot is Pair newI newList -> Pair (newI + 1) (swap (newI + 1) high newList) Err _ -> Pair (low - 1) initialList + + +partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] +partitionHelp = \i, j, list, high, pivot -> + if j < high then + when List.get list j is + Ok value -> + if value <= pivot then + partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + else + partitionHelp i (j + 1) list high pivot + + Err _ -> + Pair i list + else + Pair i list diff --git a/compiler/load/tests/test_load.rs b/compiler/load/tests/test_load.rs index e98f05dfdc..69bcb0f483 100644 --- a/compiler/load/tests/test_load.rs +++ b/compiler/load/tests/test_load.rs @@ -214,6 +214,7 @@ mod test_load { hashmap! { "swap" => "Int, Int, List a -> List a", "partition" => "Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]", + "partitionHelp" => "Int, Int, List (Num a), Int, Num a -> [ Pair Int (List (Num a)) ]", "quicksort" => "List (Num a), Int, Int -> List (Num a)", }, ); @@ -229,6 +230,7 @@ mod test_load { hashmap! { "swap" => "Int, Int, List a -> List a", "partition" => "Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]", + "partitionHelp" => "Int, Int, List (Num a), Int, Num a -> [ Pair Int (List (Num a)) ]", "quicksort" => "List (Num a), Int, Int -> List (Num a)", }, ); diff --git a/compiler/load/tests/test_uniq_load.rs b/compiler/load/tests/test_uniq_load.rs index 4ee214f80f..052cfdc66f 100644 --- a/compiler/load/tests/test_uniq_load.rs +++ b/compiler/load/tests/test_uniq_load.rs @@ -233,6 +233,8 @@ mod test_uniq_load { hashmap! { "swap" => "Attr * (Attr * Int, Attr * Int, Attr * (List (Attr Shared a)) -> Attr * (List (Attr Shared a)))", "partition" => "Attr * (Attr Shared Int, Attr Shared Int, Attr b (List (Attr Shared (Num (Attr Shared a)))) -> Attr * [ Pair (Attr * Int) (Attr b (List (Attr Shared (Num (Attr Shared a))))) ])", + + "partitionHelp" => "Attr Shared (Attr b Int, Attr Shared Int, Attr c (List (Attr Shared (Num (Attr Shared a)))), Attr Shared Int, Attr Shared (Num (Attr Shared a)) -> Attr * [ Pair (Attr b Int) (Attr c (List (Attr Shared (Num (Attr Shared a))))) ])", "quicksort" => "Attr Shared (Attr b (List (Attr Shared (Num (Attr Shared a)))), Attr Shared Int, Attr Shared Int -> Attr b (List (Attr Shared (Num (Attr Shared a)))))", }, ); diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 15a8477c65..bb09d9aff7 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -146,6 +146,17 @@ impl fmt::Debug for Symbol { } } +impl fmt::Display for Symbol { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let module_id = self.module_id(); + let ident_id = self.ident_id(); + + match ident_id { + IdentId(value) => write!(f, "{:?}.{:?}", module_id, value), + } + } +} + fn fallback_debug_fmt(symbol: Symbol, f: &mut fmt::Formatter) -> fmt::Result { let module_id = symbol.module_id(); let ident_id = symbol.ident_id(); diff --git a/compiler/mono/Cargo.toml b/compiler/mono/Cargo.toml index e75d8afa11..5301d0bdc3 100644 --- a/compiler/mono/Cargo.toml +++ b/compiler/mono/Cargo.toml @@ -13,6 +13,7 @@ roc_types = { path = "../types" } roc_can = { path = "../can" } roc_unify = { path = "../unify" } roc_problem = { path = "../problem" } +ven_pretty = { path = "../../vendor/pretty" } bumpalo = { version = "3.2", features = ["collections"] } [dev-dependencies] diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index ce135e5c45..536fc8f937 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -1,7 +1,6 @@ -use crate::expr::{DestructType, Env, Expr, Pattern}; -use crate::layout::{Builtin, Layout}; -use crate::pattern::{Ctor, RenderAs, TagId, Union}; -use bumpalo::Bump; +use crate::exhaustive::{Ctor, RenderAs, TagId, Union}; +use crate::ir::{DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt}; +use crate::layout::{Builtin, Layout, LayoutCache}; use roc_collections::all::{MutMap, MutSet}; use roc_module::ident::TagName; use roc_module::low_level::LowLevel; @@ -31,8 +30,12 @@ pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> Decision pub enum Guard<'a> { NoGuard, Guard { - stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], - expr: Expr<'a>, + /// Symbol that stores a boolean + /// when true this branch is picked, otherwise skipped + symbol: Symbol, + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, }, } @@ -57,7 +60,7 @@ pub enum Test<'a> { IsCtor { tag_id: u8, tag_name: TagName, - union: crate::pattern::Union, + union: crate::exhaustive::Union, arguments: Vec<(Pattern<'a>, Layout<'a>)>, }, IsInt(i64), @@ -72,8 +75,12 @@ pub enum Test<'a> { // A pattern that always succeeds (like `_`) can still have a guard Guarded { opt_test: Option>>, - stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], - expr: Expr<'a>, + /// Symbol that stores a boolean + /// when true this branch is picked, otherwise skipped + symbol: Symbol, + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, }, } use std::hash::{Hash, Hasher}; @@ -355,11 +362,12 @@ fn test_at_path<'a>(selected_path: &Path, branch: &Branch<'a>, all_tests: &mut V None => {} Some((_, guard, pattern)) => { let guarded = |test| { - if let Guard::Guard { stores, expr } = guard { + if let Guard::Guard { symbol, id, stmt } = guard { Guarded { opt_test: Some(Box::new(test)), - stores, - expr: expr.clone(), + stmt: stmt.clone(), + symbol: *symbol, + id: *id, } } else { test @@ -369,11 +377,12 @@ fn test_at_path<'a>(selected_path: &Path, branch: &Branch<'a>, all_tests: &mut V match pattern { // TODO use guard! Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => { - if let Guard::Guard { stores, expr } = guard { + if let Guard::Guard { symbol, id, stmt } = guard { all_tests.push(Guarded { opt_test: None, - stores, - expr: expr.clone(), + stmt: stmt.clone(), + symbol: *symbol, + id: *id, }); } } @@ -577,6 +586,8 @@ fn to_relevant_branch_help<'a>( start.push((Path::Unbox(Box::new(path.clone())), guard, arg.0)); start.extend(end); } + } else if union.alternatives.len() == 1 { + todo!("this should need a special index, right?") } else { let sub_positions = arguments @@ -855,31 +866,30 @@ enum Decider<'a, T> { #[derive(Clone, Debug, PartialEq)] enum Choice<'a> { - Inline(Stores<'a>, Expr<'a>), + Inline(Stmt<'a>), Jump(Label), } -type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; +type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>; pub fn optimize_when<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, cond_symbol: Symbol, cond_layout: Layout<'a>, ret_layout: Layout<'a>, - opt_branches: Vec<(Pattern<'a>, Guard<'a>, Stores<'a>, Expr<'a>)>, -) -> Expr<'a> { + opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>, +) -> Stmt<'a> { let (patterns, _indexed_branches) = opt_branches .into_iter() .enumerate() - .map(|(index, (pattern, guard, stores, branch))| { - ( - (guard, pattern, index as u64), - (index as u64, stores, branch), - ) + .map(|(index, (pattern, guard, branch))| { + ((guard, pattern, index as u64), (index as u64, branch)) }) .unzip(); - let indexed_branches: Vec<(u64, Stores<'a>, Expr<'a>)> = _indexed_branches; + let indexed_branches: Vec<(u64, Stmt<'a>)> = _indexed_branches; let decision_tree = compile(patterns); let decider = tree_to_decider(decision_tree); @@ -888,9 +898,8 @@ pub fn optimize_when<'a>( let mut choices = MutMap::default(); let mut jumps = Vec::new(); - for (index, stores, branch) in indexed_branches.into_iter() { - let ((branch_index, choice), opt_jump) = - create_choices(&target_counts, index, stores, branch); + for (index, branch) in indexed_branches.into_iter() { + let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch); if let Some(jump) = opt_jump { jumps.push(jump); @@ -901,16 +910,16 @@ pub fn optimize_when<'a>( let choice_decider = insert_choices(&choices, decider); - let (stores, expr) = decide_to_branching( + decide_to_branching( env, + procs, + layout_cache, cond_symbol, cond_layout, ret_layout, choice_decider, &jumps, - ); - - Expr::Store(stores, env.arena.alloc(expr)) + ) } fn path_to_expr<'a>( @@ -918,47 +927,61 @@ fn path_to_expr<'a>( symbol: Symbol, path: &Path, layout: &Layout<'a>, -) -> Expr<'a> { - path_to_expr_help(env, symbol, path, layout.clone()).0 +) -> (StoresVec<'a>, Symbol) { + let (symbol, stores, _) = path_to_expr_help2(env, symbol, path, layout.clone()); + + (stores, symbol) } -fn path_to_expr_help<'a>( +fn path_to_expr_help2<'a>( env: &mut Env<'a, '_>, - symbol: Symbol, - path: &Path, - layout: Layout<'a>, -) -> (Expr<'a>, Layout<'a>) { - match path { - Path::Unbox(unboxed) => path_to_expr_help(env, symbol, unboxed, layout), - Path::Empty => (Expr::Load(symbol), layout), + mut symbol: Symbol, + mut path: &Path, + mut layout: Layout<'a>, +) -> (Symbol, StoresVec<'a>, Layout<'a>) { + let mut stores = bumpalo::collections::Vec::new_in(env.arena); - Path::Index { - index, - tag_id, - path: nested, - } => { - let (outer_expr, outer_layout) = path_to_expr_help(env, symbol, nested, layout); + loop { + match path { + Path::Unbox(unboxed) => { + path = unboxed; + } + Path::Empty => break, - let (is_unwrapped, field_layouts) = match outer_layout { - Layout::Union(layouts) => (layouts.is_empty(), layouts[*tag_id as usize].to_vec()), - Layout::Struct(layouts) => (true, layouts.to_vec()), - other => (true, vec![other]), - }; + Path::Index { + index, + tag_id, + path: nested, + } => { + let (is_unwrapped, field_layouts) = match layout.clone() { + Layout::Union(layouts) => { + (layouts.is_empty(), layouts[*tag_id as usize].to_vec()) + } + Layout::Struct(layouts) => (true, layouts.to_vec()), + other => (true, vec![other]), + }; - debug_assert!(*index < field_layouts.len() as u64); + debug_assert!(*index < field_layouts.len() as u64); - let inner_layout = field_layouts[*index as usize].clone(); + let inner_layout = field_layouts[*index as usize].clone(); - let inner_expr = Expr::AccessAtIndex { - index: *index, - field_layouts: env.arena.alloc(field_layouts), - expr: env.arena.alloc(outer_expr), - is_unwrapped, - }; + let inner_expr = Expr::AccessAtIndex { + index: *index, + field_layouts: env.arena.alloc(field_layouts), + structure: symbol, + is_unwrapped, + }; - (inner_expr, inner_layout) + symbol = env.unique_symbol(); + stores.push((symbol, inner_layout.clone(), inner_expr)); + + layout = inner_layout; + path = nested; + } } } + + (symbol, stores, layout) } fn test_to_equality<'a>( @@ -967,8 +990,7 @@ fn test_to_equality<'a>( cond_layout: &Layout<'a>, path: &Path, test: Test<'a>, - tests: &mut Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>, -) { +) -> (StoresVec<'a>, Symbol, Symbol, Layout<'a>) { match test { Test::IsCtor { tag_id, @@ -980,7 +1002,7 @@ fn test_to_equality<'a>( // (e.g. record pattern guard matches) debug_assert!(union.alternatives.len() > 1); - let lhs = Expr::Int(tag_id as i64); + let lhs = Expr::Literal(Literal::Int(tag_id as i64)); let mut field_layouts = bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena); @@ -995,89 +1017,131 @@ fn test_to_equality<'a>( let rhs = Expr::AccessAtIndex { index: 0, field_layouts: field_layouts.into_bump_slice(), - expr: env.arena.alloc(Expr::Load(cond_symbol)), + structure: cond_symbol, is_unwrapped: union.alternatives.len() == 1, }; - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64))); + let lhs_symbol = env.unique_symbol(); + let rhs_symbol = env.unique_symbol(); + + let mut stores = bumpalo::collections::Vec::with_capacity_in(2, env.arena); + + stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); + stores.push((rhs_symbol, Layout::Builtin(Builtin::Int64), rhs)); + + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Int64), + ) } Test::IsInt(test_int) => { - let lhs = Expr::Int(test_int); - let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + let lhs = Expr::Literal(Literal::Int(test_int)); + let lhs_symbol = env.unique_symbol(); + let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64))); + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Int64), + ) } Test::IsFloat(test_int) => { // TODO maybe we can actually use i64 comparison here? let test_float = f64::from_bits(test_int as u64); - let lhs = Expr::Float(test_float); - let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + let lhs = Expr::Literal(Literal::Float(test_float)); + let lhs_symbol = env.unique_symbol(); + let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + stores.push((lhs_symbol, Layout::Builtin(Builtin::Float64), lhs)); - tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64))); + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Float64), + ) } Test::IsByte { tag_id: test_byte, .. } => { - let lhs = Expr::Byte(test_byte); - let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + let lhs = Expr::Literal(Literal::Byte(test_byte)); + let lhs_symbol = env.unique_symbol(); + let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + stores.push((lhs_symbol, Layout::Builtin(Builtin::Int8), lhs)); - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int8))); + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Int8), + ) } Test::IsBit(test_bit) => { - let lhs = Expr::Bool(test_bit); - let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + let lhs = Expr::Literal(Literal::Bool(test_bit)); + let lhs_symbol = env.unique_symbol(); + let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + stores.push((lhs_symbol, Layout::Builtin(Builtin::Int1), lhs)); - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int1))); + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Int1), + ) } Test::IsStr(test_str) => { - let lhs = Expr::Str(env.arena.alloc(test_str)); - let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + let lhs = Expr::Literal(Literal::Str(env.arena.alloc(test_str))); + let lhs_symbol = env.unique_symbol(); + let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); - tests.push((lhs, rhs, Layout::Builtin(Builtin::Str))); + stores.push((lhs_symbol, Layout::Builtin(Builtin::Str), lhs)); + + ( + stores, + lhs_symbol, + rhs_symbol, + Layout::Builtin(Builtin::Str), + ) } - Test::Guarded { - opt_test, - stores, - expr, - } => { - if let Some(nested) = opt_test { - test_to_equality(env, cond_symbol, cond_layout, path, *nested, tests); - } - - let lhs = Expr::Bool(true); - let rhs = Expr::Store(stores, env.arena.alloc(expr)); - - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int1))); - } + Test::Guarded { .. } => unreachable!("should be handled elsewhere"), } } +// TODO procs and layout are currently unused, but potentially required +// for defining optional fields? +// if not, do remove +#[allow(clippy::too_many_arguments)] fn decide_to_branching<'a>( env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, cond_symbol: Symbol, cond_layout: Layout<'a>, ret_layout: Layout<'a>, decider: Decider<'a, Choice<'a>>, - jumps: &Vec<(u64, Stores<'a>, Expr<'a>)>, -) -> (Stores<'a>, Expr<'a>) { + jumps: &Vec<(u64, Stmt<'a>)>, +) -> Stmt<'a> { use Choice::*; use Decider::*; match decider { Leaf(Jump(label)) => { // we currently inline the jumps: does fewer jumps but produces a larger artifact - let (_, stores, expr) = jumps + let (_, expr) = jumps .iter() - .find(|(l, _, _)| l == &label) + .find(|(l, _)| l == &label) .expect("jump not in list of jumps"); - (stores, expr.clone()) + expr.clone() } - Leaf(Inline(stores, expr)) => (stores, expr), + Leaf(Inline(expr)) => expr, Chain { test_chain, success, @@ -1085,14 +1149,10 @@ fn decide_to_branching<'a>( } => { // generate a switch based on the test chain - let mut tests = Vec::with_capacity(test_chain.len()); - - for (path, test) in test_chain { - test_to_equality(env, cond_symbol, &cond_layout, &path, test, &mut tests); - } - - let (pass_stores, pass_expr) = decide_to_branching( + let pass_expr = decide_to_branching( env, + procs, + layout_cache, cond_symbol, cond_layout.clone(), ret_layout.clone(), @@ -1100,8 +1160,10 @@ fn decide_to_branching<'a>( jumps, ); - let (fail_stores, fail_expr) = decide_to_branching( + let fail_expr = decide_to_branching( env, + procs, + layout_cache, cond_symbol, cond_layout.clone(), ret_layout.clone(), @@ -1109,30 +1171,148 @@ fn decide_to_branching<'a>( jumps, ); - let fail = (fail_stores, &*env.arena.alloc(fail_expr)); - let pass = (pass_stores, &*env.arena.alloc(pass_expr)); + let fail = &*env.arena.alloc(fail_expr); + let pass = &*env.arena.alloc(pass_expr); - let condition = boolean_all(env.arena, tests); + let branching_symbol = env.unique_symbol(); + let branching_layout = Layout::Builtin(Builtin::Int1); - let branch_symbol = env.unique_symbol(); - let stores = [(branch_symbol, Layout::Builtin(Builtin::Int1), condition)]; + let mut cond = Stmt::Cond { + cond_symbol, + cond_layout: cond_layout.clone(), + branching_symbol, + branching_layout, + pass, + fail, + ret_layout, + }; - let cond_layout = Layout::Builtin(Builtin::Int1); + let true_symbol = env.unique_symbol(); - ( - env.arena.alloc(stores), - Expr::Store( - &[], - env.arena.alloc(Expr::Cond { + let mut tests = Vec::with_capacity(test_chain.len()); + + let mut guard = None; + + // Assumption: there is at most 1 guard, and it is the outer layer. + for (path, test) in test_chain { + match test { + Test::Guarded { + opt_test, + id, + symbol, + stmt, + } => { + if let Some(nested) = opt_test { + tests.push(test_to_equality( + env, + cond_symbol, + &cond_layout, + &path, + *nested, + )); + } + + // let (stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + + guard = Some((symbol, id, stmt)); + } + + _ => tests.push(test_to_equality( + env, cond_symbol, - branch_symbol, - cond_layout, - pass, - fail, - ret_layout, - }), - ), - ) + &cond_layout, + &path, + test, + )), + } + } + + let mut current_symbol = branching_symbol; + + // TODO There must be some way to remove this iterator/loop + let nr = (tests.len() as i64) - 1 + (guard.is_some() as i64); + let accum_symbols = std::iter::once(true_symbol) + .chain((0..nr).map(|_| env.unique_symbol())) + .rev() + .collect::>(); + + let mut accum_it = accum_symbols.into_iter(); + + // the guard is the final thing that we check, so needs to be layered on first! + if let Some((_, id, stmt)) = guard { + let accum = accum_it.next().unwrap(); + let test_symbol = env.unique_symbol(); + + let and_expr = + Expr::RunLowLevel(LowLevel::And, env.arena.alloc([test_symbol, accum])); + + // write to the branching symbol + cond = Stmt::Let( + current_symbol, + and_expr, + Layout::Builtin(Builtin::Int1), + env.arena.alloc(cond), + ); + + // calculate the guard value + let param = Param { + symbol: test_symbol, + layout: Layout::Builtin(Builtin::Int1), + borrow: false, + }; + cond = Stmt::Join { + id, + parameters: env.arena.alloc([param]), + remainder: env.arena.alloc(stmt), + continuation: env.arena.alloc(cond), + }; + + // load all the variables (the guard might need them); + + current_symbol = accum; + } + + for ((new_stores, lhs, rhs, _layout), accum) in tests.into_iter().rev().zip(accum_it) { + let test_symbol = env.unique_symbol(); + let test = Expr::RunLowLevel( + LowLevel::Eq, + bumpalo::vec![in env.arena; lhs, rhs].into_bump_slice(), + ); + + let and_expr = + Expr::RunLowLevel(LowLevel::And, env.arena.alloc([test_symbol, accum])); + + // write to the branching symbol + cond = Stmt::Let( + current_symbol, + and_expr, + Layout::Builtin(Builtin::Int1), + env.arena.alloc(cond), + ); + + // write to the test symbol + cond = Stmt::Let( + test_symbol, + test, + Layout::Builtin(Builtin::Int1), + env.arena.alloc(cond), + ); + + for (symbol, layout, expr) in new_stores.into_iter() { + cond = Stmt::Let(symbol, expr, layout, env.arena.alloc(cond)); + } + + current_symbol = accum; + } + + cond = Stmt::Let( + true_symbol, + Expr::Literal(Literal::Bool(true)), + Layout::Builtin(Builtin::Int1), + env.arena.alloc(cond), + ); + + cond } FanOut { path, @@ -1141,23 +1321,28 @@ fn decide_to_branching<'a>( } => { // the cond_layout can change in the process. E.g. if the cond is a Tag, we actually // switch on the tag discriminant (currently an i64 value) - let (cond, cond_layout) = path_to_expr_help(env, cond_symbol, &path, cond_layout); + // NOTE the tag discriminant is not actually loaded, `cond` can point to a tag + let (cond, cond_stores_vec, cond_layout) = + path_to_expr_help2(env, cond_symbol, &path, cond_layout); - let (default_stores, default_expr) = decide_to_branching( + let default_branch = decide_to_branching( env, + procs, + layout_cache, cond_symbol, cond_layout.clone(), ret_layout.clone(), *fallback, jumps, ); - let default_branch = (default_stores, &*env.arena.alloc(default_expr)); let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena); for (test, decider) in tests { - let (stores, branch) = decide_to_branching( + let branch = decide_to_branching( env, + procs, + layout_cache, cond_symbol, cond_layout.clone(), ret_layout.clone(), @@ -1174,24 +1359,28 @@ fn decide_to_branching<'a>( other => todo!("other {:?}", other), }; - branches.push((tag, stores, branch)); + branches.push((tag, branch)); + } + + let mut switch = Stmt::Switch { + cond_layout, + cond_symbol: cond, + branches: branches.into_bump_slice(), + default_branch: env.arena.alloc(default_branch), + ret_layout, + }; + + for (symbol, layout, expr) in cond_stores_vec.into_iter() { + switch = Stmt::Let(symbol, expr, layout, env.arena.alloc(switch)); } // make a jump table based on the tests - ( - &[], - Expr::Switch { - cond: env.arena.alloc(cond), - cond_layout, - branches: branches.into_bump_slice(), - default_branch, - ret_layout, - }, - ) + switch } } } +/* fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>) -> Expr<'a> { let mut expr = Expr::Bool(true); @@ -1212,6 +1401,7 @@ fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)> expr } +*/ /// TREE TO DECIDER /// @@ -1385,19 +1575,15 @@ fn count_targets_help(decision_tree: &Decider, targets: &mut MutMap( target_counts: &MutMap, target: u64, - stores: Stores<'a>, - branch: Expr<'a>, -) -> ((u64, Choice<'a>), Option<(u64, Stores<'a>, Expr<'a>)>) { + branch: Stmt<'a>, +) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) { match target_counts.get(&target) { None => unreachable!( "this should never happen: {:?} not in {:?}", target, target_counts ), - Some(1) => ((target, Choice::Inline(stores, branch)), None), - Some(_) => ( - (target, Choice::Jump(target)), - Some((target, stores, branch)), - ), + Some(1) => ((target, Choice::Inline(branch)), None), + Some(_) => ((target, Choice::Jump(target)), Some((target, branch))), } } diff --git a/compiler/mono/src/pattern.rs b/compiler/mono/src/exhaustive.rs similarity index 98% rename from compiler/mono/src/pattern.rs rename to compiler/mono/src/exhaustive.rs index bb1cbddc81..70077c2199 100644 --- a/compiler/mono/src/pattern.rs +++ b/compiler/mono/src/exhaustive.rs @@ -1,4 +1,4 @@ -use crate::expr::DestructType; +use crate::ir::DestructType; use roc_collections::all::{Index, MutMap}; use roc_module::ident::{Lowercase, TagName}; use roc_region::all::{Located, Region}; @@ -44,8 +44,8 @@ pub enum Literal { Str(Box), } -fn simplify<'a>(pattern: &crate::expr::Pattern<'a>) -> Pattern { - use crate::expr::Pattern::*; +fn simplify<'a>(pattern: &crate::ir::Pattern<'a>) -> Pattern { + use crate::ir::Pattern::*; match pattern { IntLiteral(v) => Literal(Literal::Int(*v)), @@ -137,7 +137,7 @@ pub enum Guard { pub fn check<'a>( region: Region, - patterns: &[(Located>, Guard)], + patterns: &[(Located>, Guard)], context: Context, ) -> Result<(), Vec> { let mut errors = Vec::new(); @@ -153,7 +153,7 @@ pub fn check<'a>( pub fn check_patterns<'a>( region: Region, context: Context, - patterns: &[(Located>, Guard)], + patterns: &[(Located>, Guard)], errors: &mut Vec, ) { match to_nonredundant_rows(region, patterns) { @@ -286,7 +286,7 @@ fn recover_ctor( /// INVARIANT: Produces a list of rows where (forall row. length row == 1) fn to_nonredundant_rows<'a>( overall_region: Region, - patterns: &[(Located>, Guard)], + patterns: &[(Located>, Guard)], ) -> Result>, Error> { let mut checked_rows = Vec::with_capacity(patterns.len()); diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs new file mode 100644 index 0000000000..d830a63506 --- /dev/null +++ b/compiler/mono/src/inc_dec.rs @@ -0,0 +1,901 @@ +use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt}; +use crate::layout::Layout; +use bumpalo::collections::Vec; +use bumpalo::Bump; +use roc_collections::all::{MutMap, MutSet}; +use roc_module::symbol::Symbol; + +pub fn free_variables(stmt: &Stmt<'_>) -> MutSet { + let (mut occuring, bound) = occuring_variables(stmt); + + for ref s in bound { + occuring.remove(s); + } + + occuring +} + +pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet, MutSet) { + let mut stack = std::vec![stmt]; + let mut result = MutSet::default(); + let mut bound_variables = MutSet::default(); + + while let Some(stmt) = stack.pop() { + use Stmt::*; + + match stmt { + Let(symbol, expr, _, cont) => { + occuring_variables_expr(expr, &mut result); + result.insert(*symbol); + bound_variables.insert(*symbol); + stack.push(cont); + } + Ret(symbol) => { + result.insert(*symbol); + } + + Inc(symbol, cont) | Dec(symbol, cont) => { + result.insert(*symbol); + stack.push(cont); + } + + Jump(_, arguments) => { + result.extend(arguments.iter().copied()); + } + + Join { + parameters, + continuation, + remainder, + .. + } => { + result.extend(parameters.iter().map(|p| p.symbol)); + + stack.push(continuation); + stack.push(remainder); + } + + Switch { + cond_symbol, + branches, + default_branch, + .. + } => { + result.insert(*cond_symbol); + + stack.extend(branches.iter().map(|(_, s)| s)); + stack.push(default_branch); + } + + Cond { + cond_symbol, + branching_symbol, + pass, + fail, + .. + } => { + result.insert(*cond_symbol); + result.insert(*branching_symbol); + + stack.push(pass); + stack.push(fail); + } + + RuntimeError(_) => {} + } + } + + (result, bound_variables) +} + +pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet) { + use Expr::*; + + match expr { + FunctionPointer(symbol, _) + | AccessAtIndex { + structure: symbol, .. + } => { + result.insert(*symbol); + } + + FunctionCall { args, .. } => { + // NOTE thouth the function name does occur, it is a static constant in the program + // for liveness, it should not be included here. + result.extend(args.iter().copied()); + } + + Tag { arguments, .. } + | Struct(arguments) + | Array { + elems: arguments, .. + } => { + result.extend(arguments.iter().copied()); + } + + RunLowLevel(_, _) | EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {} + } +} + +/* Insert explicit RC instructions. So, it assumes the input code does not contain `inc` nor `dec` instructions. + This transformation is applied before lower level optimizations + that introduce the instructions `release` and `set` +*/ + +#[derive(Clone, Debug, Copy)] +struct VarInfo { + reference: bool, // true if the variable may be a reference (aka pointer) at runtime + persistent: bool, // true if the variable is statically known to be marked a Persistent at runtime + consume: bool, // true if the variable RC must be "consumed" +} + +type VarMap = MutMap; +type LiveVarSet = MutSet; +type JPLiveVarMap = MutMap; + +#[derive(Clone, Debug)] +pub struct Context<'a> { + arena: &'a Bump, + vars: VarMap, + jp_live_vars: JPLiveVarMap, // map: join point => live variables + local_context: LocalContext<'a>, // we use it to store the join point declarations + function_params: MutMap]>, +} + +fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet { + let mut v = v.clone(); + + occuring_variables_expr(expr, &mut v); + + v +} + +fn is_first_occurence(xs: &[Symbol], i: usize) -> bool { + match xs.get(i) { + None => unreachable!(), + Some(s) => i == xs.iter().position(|v| s == v).unwrap(), + } +} + +fn get_num_consumptions(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> usize +where + F: Fn(usize) -> bool, +{ + let mut n = 0; + + for (i, y) in ys.iter().enumerate() { + if x == *y && consume_param_pred(i) { + n += 1; + } + } + n +} + +fn is_borrow_param_help(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> bool +where + F: Fn(usize) -> bool, +{ + ys.iter() + .enumerate() + .any(|(i, y)| x == *y && !consume_param_pred(i)) +} + +fn is_borrow_param(x: Symbol, ys: &[Symbol], ps: &[Param]) -> bool { + // default to owned arguments + let pred = |i: usize| match ps.get(i) { + Some(param) => !param.borrow, + None => true, + }; + is_borrow_param_help(x, ys, pred) +} + +// We do not need to consume the projection of a variable that is not consumed +fn consume_expr(m: &VarMap, e: &Expr<'_>) -> bool { + match e { + Expr::AccessAtIndex { structure: x, .. } => match m.get(x) { + Some(info) => info.consume, + None => true, + }, + _ => true, + } +} + +impl<'a> Context<'a> { + pub fn new(arena: &'a Bump) -> Self { + Self { + arena, + vars: MutMap::default(), + jp_live_vars: MutMap::default(), + local_context: LocalContext::default(), + function_params: MutMap::default(), + } + } + + fn get_var_info(&self, symbol: Symbol) -> VarInfo { + match self.vars.get(&symbol) { + Some(info) => *info, + None => panic!( + "Symbol {:?} {} has no info in {:?}", + symbol, symbol, self.vars + ), + } + } + + fn add_inc(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + let info = self.get_var_info(symbol); + + if info.persistent { + // persistent values are never reference counted + return stmt; + } + + // if this symbol is never a reference, don't emit + if !info.reference { + return stmt; + } + + self.arena.alloc(Stmt::Inc(symbol, stmt)) + } + + fn add_dec(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + let info = self.get_var_info(symbol); + + if info.persistent { + // persistent values are never reference counted + return stmt; + } + + // if this symbol is never a reference, don't emit + if !info.reference { + return stmt; + } + + self.arena.alloc(Stmt::Dec(symbol, stmt)) + } + + fn add_inc_before_consume_all_help( + &self, + xs: &[Symbol], + consume_param_pred: F, + mut b: &'a Stmt<'a>, + live_vars_after: &LiveVarSet, + ) -> &'a Stmt<'a> + where + F: Fn(usize) -> bool + Clone, + { + for (i, x) in xs.iter().enumerate() { + let info = self.get_var_info(*x); + if !info.reference || !is_first_occurence(xs, i) { + // do nothing + } else { + // number of times the argument is used (in the body?) + let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); + + // `x` is not a variable that must be consumed by the current procedure + // `x` is live after executing instruction + // `x` is used in a position that is passed as a borrow reference + let lives_on = !info.consume + || live_vars_after.contains(x) + || is_borrow_param_help(*x, xs, consume_param_pred.clone()); + + let num_incs = if lives_on { + num_consumptions + } else { + num_consumptions - 1 + }; + + // Lean can increment by more than 1 at once. Is that needed? + debug_assert!(num_incs <= 1); + + if num_incs == 1 { + b = self.add_inc(*x, b); + } + } + } + + b + } + + fn add_inc_before_consume_all( + &self, + xs: &[Symbol], + b: &'a Stmt<'a>, + live_vars_after: &LiveVarSet, + ) -> &'a Stmt<'a> { + self.add_inc_before_consume_all_help(xs, |_: usize| true, b, live_vars_after) + } + + fn add_inc_before_help( + &self, + xs: &[Symbol], + consume_param_pred: F, + mut b: &'a Stmt<'a>, + live_vars_after: &LiveVarSet, + ) -> &'a Stmt<'a> + where + F: Fn(usize) -> bool + Clone, + { + for (i, x) in xs.iter().enumerate() { + let info = self.get_var_info(*x); + if !info.reference || !is_first_occurence(xs, i) { + // do nothing + } else { + let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); // number of times the argument is used + let num_incs = if !info.consume || // `x` is not a variable that must be consumed by the current procedure + live_vars_after.contains(x) || // `x` is live after executing instruction + is_borrow_param_help( *x ,xs, consume_param_pred.clone()) + // `x` is used in a position that is passed as a borrow reference + { + num_consumptions + } else { + num_consumptions - 1 + }; + + // verify that this is indeed always 1 + debug_assert!(num_incs <= 1); + + if num_incs == 1 { + b = self.add_inc(*x, b) + } + } + } + b + } + + fn add_inc_before( + &self, + xs: &[Symbol], + ps: &[Param], + b: &'a Stmt<'a>, + live_vars_after: &LiveVarSet, + ) -> &'a Stmt<'a> { + // default to owned arguments + let pred = |i: usize| match ps.get(i) { + Some(param) => !param.borrow, + None => true, + }; + self.add_inc_before_help(xs, pred, b, live_vars_after) + } + + fn add_dec_if_needed( + &self, + x: Symbol, + b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> &'a Stmt<'a> { + if self.must_consume(x) && !b_live_vars.contains(&x) { + self.add_dec(x, b) + } else { + b + } + } + + fn must_consume(&self, x: Symbol) -> bool { + let info = self.get_var_info(x); + info.reference && info.consume + } + + fn add_dec_after_application( + &self, + xs: &[Symbol], + ps: &[Param], + mut b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> &'a Stmt<'a> { + for (i, x) in xs.iter().enumerate() { + /* We must add a `dec` if `x` must be consumed, it is alive after the application, + and it has been borrowed by the application. + Remark: `x` may occur multiple times in the application (e.g., `f x y x`). + This is why we check whether it is the first occurrence. */ + if self.must_consume(*x) + && is_first_occurence(xs, i) + && is_borrow_param(*x, xs, ps) + && !b_live_vars.contains(x) + { + b = self.add_dec(*x, b) + } + } + + b + } + + #[allow(clippy::many_single_char_names)] + fn visit_variable_declaration( + &self, + z: Symbol, + v: Expr<'a>, + l: Layout<'a>, + b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> (&'a Stmt<'a>, LiveVarSet) { + use Expr::*; + + let mut live_vars = update_live_vars(&v, &b_live_vars); + live_vars.remove(&z); + + let new_b = match v { + Tag { arguments: ys, .. } | Struct(ys) | Array { elems: ys, .. } => self + .add_inc_before_consume_all( + ys, + self.arena.alloc(Stmt::Let(z, v, l, b)), + &b_live_vars, + ), + AccessAtIndex { structure: x, .. } => { + let b = self.add_dec_if_needed(x, b, b_live_vars); + let info_x = self.get_var_info(x); + let b = if info_x.consume { + self.add_inc(z, b) + } else { + b + }; + + self.arena.alloc(Stmt::Let(z, v, l, b)) + } + + RunLowLevel(_, _) => { + // THEORY: runlowlevel only occurs + // + // - in a custom hard-coded function + // - when we insert them as compiler authors + // + // if we're carefule to only use RunLowLevel for non-rc'd types + // (e.g. when building a cond/switch, we check equality on integers, and to boolean and) + // then RunLowLevel should not change in any way the refcounts. + + // let b = self.add_dec_after_application(ys, ps, b, b_live_vars); + self.arena.alloc(Stmt::Let(z, v, l, b)) + } + + FunctionCall { + args: ys, + call_type, + arg_layouts, + .. + } => { + // this is where the borrow signature would come in + //let ps := (getDecl ctx f).params; + use crate::ir::CallType; + use crate::layout::Builtin; + let symbol = match call_type { + CallType::ByName(s) => s, + CallType::ByPointer(s) => s, + }; + + let ps = Vec::from_iter_in( + arg_layouts.iter().map(|layout| { + let borrow = match layout { + Layout::Builtin(Builtin::List(_, _)) => true, + _ => false, + }; + + Param { + symbol, + borrow, + layout: layout.clone(), + } + }), + self.arena, + ) + .into_bump_slice(); + + let b = self.add_dec_after_application(ys, ps, b, b_live_vars); + self.arena.alloc(Stmt::Let(z, v, l, b)) + } + + EmptyArray | FunctionPointer(_, _) | Literal(_) | RuntimeErrorFunction(_) => { + // EmptyArray is always stack-allocated + // function pointers are persistent + self.arena.alloc(Stmt::Let(z, v, l, b)) + } + }; + + (new_b, live_vars) + } + + fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { + let mut ctx = self.clone(); + + // TODO actually make these non-constant + + // can this type be reference-counted at runtime? + let reference = layout.contains_refcounted(); + + // is this value a constant? + let persistent = false; + + // must this value be consumed? + let consume = consume_expr(&ctx.vars, expr); + + let info = VarInfo { + reference, + persistent, + consume, + }; + + ctx.vars.insert(symbol, info); + + ctx + } + + fn update_var_info_with_params(&self, ps: &[Param]) -> Self { + //def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context := + //let m := ps.foldl (fun (m : VarMap) p => m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }) ctx.varMap; + //{ ctx with varMap := m } + let mut ctx = self.clone(); + + for p in ps.iter() { + let info = VarInfo { + reference: p.layout.contains_refcounted(), + consume: !p.borrow, + persistent: false, + }; + ctx.vars.insert(p.symbol, info); + } + + ctx + } + + /* Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow. + That is, we must make sure these parameters are consumed. */ + fn add_dec_for_dead_params( + &self, + ps: &[Param<'a>], + mut b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> &'a Stmt<'a> { + for p in ps.iter() { + if !p.borrow && p.layout.contains_refcounted() && !b_live_vars.contains(&p.symbol) { + b = self.add_dec(p.symbol, b) + } + } + + b + } + + fn add_dec_for_alt( + &self, + case_live_vars: &LiveVarSet, + alt_live_vars: &LiveVarSet, + mut b: &'a Stmt<'a>, + ) -> &'a Stmt<'a> { + for x in case_live_vars.iter() { + if !alt_live_vars.contains(x) && self.must_consume(*x) { + b = self.add_dec(*x, b); + } + } + + b + } + + fn visit_stmt(&self, stmt: &'a Stmt<'a>) -> (&'a Stmt<'a>, LiveVarSet) { + use Stmt::*; + + // let-chains can be very long, especially for large (list) literals + // in (rust) debug mode, this function can overflow the stack for such values + // so we have to write an explicit loop. + { + let mut cont = stmt; + let mut triples = Vec::new_in(self.arena); + while let Stmt::Let(symbol, expr, layout, new_cont) = cont { + triples.push((symbol, expr, layout)); + cont = new_cont; + } + + if !triples.is_empty() { + let mut ctx = self.clone(); + for (symbol, expr, layout) in triples.iter() { + ctx = ctx.update_var_info(**symbol, layout, expr); + } + let (mut b, mut b_live_vars) = ctx.visit_stmt(cont); + for (symbol, expr, layout) in triples.into_iter().rev() { + let pair = ctx.visit_variable_declaration( + *symbol, + (*expr).clone(), + (*layout).clone(), + b, + &b_live_vars, + ); + + b = pair.0; + b_live_vars = pair.1; + } + + return (b, b_live_vars); + } + } + + match stmt { + Let(symbol, expr, layout, cont) => { + let ctx = self.update_var_info(*symbol, layout, expr); + let (b, b_live_vars) = ctx.visit_stmt(cont); + ctx.visit_variable_declaration( + *symbol, + expr.clone(), + layout.clone(), + b, + &b_live_vars, + ) + } + + Join { + id: j, + parameters: xs, + remainder: b, + continuation: v, + } => { + let xs = *xs; + + let v_orig = v; + + let (v, v_live_vars) = { + let ctx = self.update_var_info_with_params(xs); + ctx.visit_stmt(v) + }; + + let v = self.add_dec_for_dead_params(xs, v, &v_live_vars); + let mut ctx = self.clone(); + + // NOTE deviation from lean, insert into local context + ctx.local_context.join_points.insert(*j, (xs, v_orig)); + + update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars); + + let (b, b_live_vars) = ctx.visit_stmt(b); + + ( + ctx.arena.alloc(Join { + id: *j, + parameters: xs, + remainder: b, + continuation: v, + }), + b_live_vars, + ) + } + + Ret(x) => { + let info = self.get_var_info(*x); + + let mut live_vars = MutSet::default(); + live_vars.insert(*x); + + if info.reference && !info.consume { + (self.add_inc(*x, stmt), live_vars) + } else { + (stmt, live_vars) + } + } + + Jump(j, xs) => { + let empty = MutSet::default(); + let j_live_vars = match self.jp_live_vars.get(j) { + Some(vars) => vars, + None => &empty, + }; + let ps = self.local_context.join_points.get(j).unwrap().0; + let b = self.add_inc_before(xs, ps, stmt, j_live_vars); + + let b_live_vars = collect_stmt(b, &self.jp_live_vars, MutSet::default()); + + (b, b_live_vars) + } + + Cond { + pass, + fail, + cond_symbol, + cond_layout, + branching_symbol, + branching_layout, + ret_layout, + } => { + let case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default()); + + let pass = { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let (b, alt_live_vars) = ctx.visit_stmt(pass); + ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b) + }; + + let fail = { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let (b, alt_live_vars) = ctx.visit_stmt(fail); + ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b) + }; + + let cond = self.arena.alloc(Cond { + cond_symbol: *cond_symbol, + cond_layout: cond_layout.clone(), + branching_symbol: *branching_symbol, + branching_layout: branching_layout.clone(), + pass, + fail, + ret_layout: ret_layout.clone(), + }); + + (cond, case_live_vars) + } + + Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout, + } => { + let case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default()); + + let branches = Vec::from_iter_in( + branches.iter().map(|(label, branch)| { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let (b, alt_live_vars) = ctx.visit_stmt(branch); + let b = ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b); + + (*label, b.clone()) + }), + self.arena, + ) + .into_bump_slice(); + + let default_branch = { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let (b, alt_live_vars) = ctx.visit_stmt(default_branch); + ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b) + }; + + let switch = self.arena.alloc(Switch { + cond_symbol: *cond_symbol, + branches, + default_branch, + cond_layout: cond_layout.clone(), + ret_layout: ret_layout.clone(), + }); + + (switch, case_live_vars) + } + + RuntimeError(_) | Inc(_, _) | Dec(_, _) => (stmt, MutSet::default()), + } + } +} + +#[derive(Clone, Debug, Default)] +struct LocalContext<'a> { + join_points: MutMap], &'a Stmt<'a>)>, +} + +pub fn collect_stmt( + stmt: &Stmt<'_>, + jp_live_vars: &JPLiveVarMap, + mut vars: LiveVarSet, +) -> LiveVarSet { + use Stmt::*; + + match stmt { + Let(symbol, expr, _, cont) => { + vars = collect_stmt(cont, jp_live_vars, vars); + vars.remove(symbol); + let mut result = MutSet::default(); + occuring_variables_expr(expr, &mut result); + vars.extend(result); + + vars + } + Ret(symbol) => { + vars.insert(*symbol); + vars + } + + Inc(symbol, cont) | Dec(symbol, cont) => { + vars.insert(*symbol); + collect_stmt(cont, jp_live_vars, vars) + } + + Jump(_, arguments) => { + vars.extend(arguments.iter().copied()); + vars + } + + Join { + id: j, + parameters, + remainder: b, + continuation: v, + } => { + let mut j_live_vars = collect_stmt(v, jp_live_vars, MutSet::default()); + for param in parameters.iter() { + j_live_vars.remove(¶m.symbol); + } + + let mut jp_live_vars = jp_live_vars.clone(); + jp_live_vars.insert(*j, j_live_vars); + + collect_stmt(b, &jp_live_vars, vars) + } + + Switch { + cond_symbol, + branches, + default_branch, + .. + } => { + vars.insert(*cond_symbol); + + for (_, branch) in branches.iter() { + vars.extend(collect_stmt(branch, jp_live_vars, vars.clone())); + } + + vars.extend(collect_stmt(default_branch, jp_live_vars, vars.clone())); + + vars + } + + Cond { + cond_symbol, + branching_symbol, + pass, + fail, + .. + } => { + vars.insert(*cond_symbol); + vars.insert(*branching_symbol); + + vars.extend(collect_stmt(pass, jp_live_vars, vars.clone())); + vars.extend(collect_stmt(fail, jp_live_vars, vars.clone())); + + vars + } + + RuntimeError(_) => vars, + } +} + +fn update_jp_live_vars(j: JoinPointId, ys: &[Param], v: &Stmt<'_>, m: &mut JPLiveVarMap) { + let j_live_vars = MutSet::default(); + let mut j_live_vars = collect_stmt(v, m, j_live_vars); + + for param in ys { + j_live_vars.remove(¶m.symbol); + } + + m.insert(j, j_live_vars); +} + +pub fn visit_declaration<'a>(arena: &'a Bump, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + let ctx = Context::new(arena); + + let params = &[] as &[_]; + let ctx = ctx.update_var_info_with_params(params); + let (b, b_live_vars) = ctx.visit_stmt(stmt); + ctx.add_dec_for_dead_params(params, b, &b_live_vars) +} + +pub fn visit_proc<'a>(arena: &'a Bump, proc: &mut Proc<'a>) { + let ctx = Context::new(arena); + + if proc.name.is_builtin() { + // we must take care of our own refcounting in builtins + return; + } + + let params = Vec::from_iter_in( + proc.args.iter().map(|(layout, symbol)| Param { + symbol: *symbol, + layout: layout.clone(), + borrow: layout.contains_refcounted(), + }), + arena, + ) + .into_bump_slice(); + + let stmt = arena.alloc(proc.body.clone()); + let ctx = ctx.update_var_info_with_params(params); + let (b, b_live_vars) = ctx.visit_stmt(stmt); + let b = ctx.add_dec_for_dead_params(params, b, &b_live_vars); + + proc.body = b.clone(); +} diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs new file mode 100644 index 0000000000..1bbd90bafc --- /dev/null +++ b/compiler/mono/src/ir.rs @@ -0,0 +1,3223 @@ +use self::InProgressProc::*; +use crate::exhaustive::{Ctor, Guard, RenderAs, TagId}; +use crate::layout::{Builtin, Layout, LayoutCache, LayoutProblem}; +use bumpalo::collections::Vec; +use bumpalo::Bump; +use roc_collections::all::{default_hasher, MutMap, MutSet}; +use roc_module::ident::{Ident, Lowercase, TagName}; +use roc_module::low_level::LowLevel; +use roc_module::symbol::{IdentIds, ModuleId, Symbol}; +use roc_problem::can::RuntimeError; +use roc_region::all::{Located, Region}; +use roc_types::subs::{Content, FlatType, Subs, Variable}; +use std::collections::HashMap; +use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder}; + +#[derive(Clone, Debug)] +pub enum MonoProblem { + PatternProblem(crate::exhaustive::Error), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PartialProc<'a> { + pub annotation: Variable, + pub pattern_symbols: Vec<'a, Symbol>, + pub body: roc_can::expr::Expr, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PendingSpecialization<'a> { + pub fn_var: Variable, + pub ret_var: Variable, + pub pattern_vars: Vec<'a, Variable>, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Proc<'a> { + pub name: Symbol, + pub args: &'a [(Layout<'a>, Symbol)], + pub body: Stmt<'a>, + pub closes_over: Layout<'a>, + pub ret_layout: Layout<'a>, +} + +impl<'a> Proc<'a> { + pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, _parens: bool) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + let args_doc = self + .args + .iter() + .map(|(_, symbol)| alloc.text(format!("{}", symbol))); + + alloc + .text(format!("procedure {} (", self.name)) + .append(alloc.intersperse(args_doc, ", ")) + .append("):") + .append(alloc.hardline()) + .append(self.body.to_doc(alloc).indent(4)) + } + + pub fn to_pretty(&self, width: usize) -> String { + let allocator = BoxAllocator; + let mut w = std::vec::Vec::new(); + self.to_doc::<_, ()>(&allocator, false) + .1 + .render(width, &mut w) + .unwrap(); + w.push(b'\n'); + String::from_utf8(w).unwrap() + } +} + +#[derive(Clone, Debug, PartialEq, Default)] +pub struct Procs<'a> { + pub partial_procs: MutMap>, + pub module_thunks: MutSet, + pub pending_specializations: + Option, PendingSpecialization<'a>>>>, + pub specialized: MutMap<(Symbol, Layout<'a>), InProgressProc<'a>>, + pub runtime_errors: MutMap, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum InProgressProc<'a> { + InProgress, + Done(Proc<'a>), +} + +impl<'a> Procs<'a> { + // TODO investigate make this an iterator? + pub fn get_specialized_procs(self, arena: &'a Bump) -> MutMap<(Symbol, Layout<'a>), Proc<'a>> { + let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher()); + + for (key, in_prog_proc) in self.specialized.into_iter() { + match in_prog_proc { + InProgress => unreachable!("The procedure {:?} should have be done by now", key), + Done(mut proc) => { + crate::inc_dec::visit_proc(arena, &mut proc); + result.insert(key, proc); + } + } + } + result + } + + // TODO trim down these arguments! + #[allow(clippy::too_many_arguments)] + pub fn insert_named( + &mut self, + env: &mut Env<'a, '_>, + layout_cache: &mut LayoutCache<'a>, + name: Symbol, + annotation: Variable, + loc_args: std::vec::Vec<(Variable, Located)>, + loc_body: Located, + ret_var: Variable, + ) { + match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { + Ok((_, pattern_symbols, body)) => { + // a named closure. Since these aren't specialized by the surrounding + // context, we can't add pending specializations for them yet. + // (If we did, all named polymorphic functions would immediately error + // on trying to convert a flex var to a Layout.) + self.partial_procs.insert( + name, + PartialProc { + annotation, + pattern_symbols, + body: body.value, + }, + ); + } + + Err(error) => { + // If the function has invalid patterns in its arguments, + // its call sites will code gen to runtime errors. This happens + // at the call site so we don't have to try to define the + // function LLVM, which would be difficult considering LLVM + // wants to know what symbols each argument corresponds to, + // and in this case the patterns were invalid, so we don't know + // what the symbols ought to be. + + let error_msg = format!("TODO generate a RuntimeError message for {:?}", error); + + self.runtime_errors.insert(name, env.arena.alloc(error_msg)); + } + } + } + + // TODO trim these down + #[allow(clippy::too_many_arguments)] + pub fn insert_anonymous( + &mut self, + env: &mut Env<'a, '_>, + symbol: Symbol, + annotation: Variable, + loc_args: std::vec::Vec<(Variable, Located)>, + loc_body: Located, + ret_var: Variable, + layout_cache: &mut LayoutCache<'a>, + ) -> Result, RuntimeError> { + match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { + Ok((pattern_vars, pattern_symbols, body)) => { + // an anonymous closure. These will always be specialized already + // by the surrounding context, so we can add pending specializations + // for them immediately. + let layout = layout_cache + .from_var(env.arena, annotation, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + // if we've already specialized this one, no further work is needed. + // + // NOTE: this #[allow(clippy::map_entry)] here is for correctness! + // Changing it to use .entry() would necessarily make it incorrect. + #[allow(clippy::map_entry)] + if !self.specialized.contains_key(&(symbol, layout.clone())) { + let pending = PendingSpecialization { + ret_var, + fn_var: annotation, + pattern_vars, + }; + + match &mut self.pending_specializations { + Some(pending_specializations) => { + // register the pending specialization, so this gets code genned later + add_pending(pending_specializations, symbol, layout.clone(), pending); + + debug_assert!(!self.partial_procs.contains_key(&symbol), "Procs was told to insert a value for symbol {:?}, but there was already an entry for that key! Procs should never attempt to insert duplicates.", symbol); + + self.partial_procs.insert( + symbol, + PartialProc { + annotation, + pattern_symbols, + body: body.value, + }, + ); + } + None => { + // TODO should pending_procs hold a Rc? + let partial_proc = PartialProc { + annotation, + pattern_symbols, + body: body.value, + }; + + // Mark this proc as in-progress, so if we're dealing with + // mutually recursive functions, we don't loop forever. + // (We had a bug around this before this system existed!) + self.specialized + .insert((symbol, layout.clone()), InProgress); + + match specialize(env, self, symbol, layout_cache, pending, partial_proc) + { + Ok(proc) => { + self.specialized + .insert((symbol, layout.clone()), Done(proc)); + } + Err(error) => { + let error_msg = format!( + "TODO generate a RuntimeError message for {:?}", + error + ); + self.runtime_errors + .insert(symbol, env.arena.alloc(error_msg)); + } + } + } + } + } + + Ok(layout) + } + Err(loc_error) => Err(loc_error.value), + } + } +} + +fn add_pending<'a>( + pending_specializations: &mut MutMap, PendingSpecialization<'a>>>, + symbol: Symbol, + layout: Layout<'a>, + pending: PendingSpecialization<'a>, +) { + let all_pending = pending_specializations + .entry(symbol) + .or_insert_with(|| HashMap::with_capacity_and_hasher(1, default_hasher())); + + all_pending.insert(layout, pending); +} + +#[derive(Default)] +pub struct Specializations<'a> { + by_symbol: MutMap, Proc<'a>>>, + runtime_errors: MutSet, +} + +impl<'a> Specializations<'a> { + pub fn insert(&mut self, symbol: Symbol, layout: Layout<'a>, proc: Proc<'a>) { + let procs_by_layout = self + .by_symbol + .entry(symbol) + .or_insert_with(|| HashMap::with_capacity_and_hasher(1, default_hasher())); + + // If we already have an entry for this, it should be no different + // from what we're about to insert. + debug_assert!( + !procs_by_layout.contains_key(&layout) || procs_by_layout.get(&layout) == Some(&proc) + ); + + // We shouldn't already have a runtime error recorded for this symbol + debug_assert!(!self.runtime_errors.contains(&symbol)); + + procs_by_layout.insert(layout, proc); + } + + pub fn runtime_error(&mut self, symbol: Symbol) { + // We shouldn't already have a normal proc recorded for this symbol + debug_assert!(!self.by_symbol.contains_key(&symbol)); + + self.runtime_errors.insert(symbol); + } + + pub fn into_owned(self) -> (MutMap, Proc<'a>>>, MutSet) { + (self.by_symbol, self.runtime_errors) + } + + pub fn len(&self) -> usize { + let runtime_errors: usize = self.runtime_errors.len(); + let specializations: usize = self.by_symbol.len(); + + runtime_errors + specializations + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +pub struct Env<'a, 'i> { + pub arena: &'a Bump, + pub subs: &'a mut Subs, + pub problems: &'i mut std::vec::Vec, + pub home: ModuleId, + pub ident_ids: &'i mut IdentIds, +} + +impl<'a, 'i> Env<'a, 'i> { + pub fn unique_symbol(&mut self) -> Symbol { + let ident_id = self.ident_ids.gen_unique(); + + self.home.register_debug_idents(&self.ident_ids); + + Symbol::new(self.home, ident_id) + } +} + +#[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)] +pub struct JoinPointId(Symbol); + +#[derive(Clone, Debug, PartialEq)] +pub struct Param<'a> { + pub symbol: Symbol, + pub borrow: bool, + pub layout: Layout<'a>, +} + +pub type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; +#[derive(Clone, Debug, PartialEq)] +pub enum Stmt<'a> { + Let(Symbol, Expr<'a>, Layout<'a>, &'a Stmt<'a>), + Switch { + /// This *must* stand for an integer, because Switch potentially compiles to a jump table. + cond_symbol: Symbol, + cond_layout: Layout<'a>, + /// The u64 in the tuple will be compared directly to the condition Expr. + /// If they are equal, this branch will be taken. + branches: &'a [(u64, Stmt<'a>)], + /// If no other branches pass, this default branch will be taken. + default_branch: &'a Stmt<'a>, + /// Each branch must return a value of this type. + ret_layout: Layout<'a>, + }, + Cond { + // The left-hand side of the conditional comparison and the right-hand side. + // These are stored separately because there are different machine instructions + // for e.g. "compare float and jump" vs. "compare integer and jump" + + // symbol storing the original expression that we branch on, e.g. `Ok 42` + // required for RC logic + cond_symbol: Symbol, + cond_layout: Layout<'a>, + + // symbol storing the value that we branch on, e.g. `1` representing the `Ok` tag + branching_symbol: Symbol, + branching_layout: Layout<'a>, + + // What to do if the condition either passes or fails + pass: &'a Stmt<'a>, + fail: &'a Stmt<'a>, + ret_layout: Layout<'a>, + }, + Ret(Symbol), + Inc(Symbol, &'a Stmt<'a>), + Dec(Symbol, &'a Stmt<'a>), + Join { + id: JoinPointId, + parameters: &'a [Param<'a>], + /// does not contain jumps to this id + continuation: &'a Stmt<'a>, + /// contains the jumps to this id + remainder: &'a Stmt<'a>, + }, + Jump(JoinPointId, &'a [Symbol]), + RuntimeError(&'a str), +} +#[derive(Clone, Debug, PartialEq)] +pub enum Literal<'a> { + // Literals + Int(i64), + Float(f64), + Str(&'a str), + /// Closed tag unions containing exactly two (0-arity) tags compile to Expr::Bool, + /// so they can (at least potentially) be emitted as 1-bit machine bools. + /// + /// So [ True, False ] compiles to this, and so do [ A, B ] and [ Foo, Bar ]. + /// However, a union like [ True, False, Other Int ] would not. + Bool(bool), + /// Closed tag unions containing between 3 and 256 tags (all of 0 arity) + /// compile to bytes, e.g. [ Blue, Black, Red, Green, White ] + Byte(u8), +} +#[derive(Clone, Debug, PartialEq, Copy)] +pub enum CallType { + ByName(Symbol), + ByPointer(Symbol), +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Expr<'a> { + Literal(Literal<'a>), + + // Functions + FunctionPointer(Symbol, Layout<'a>), + FunctionCall { + call_type: CallType, + layout: Layout<'a>, + arg_layouts: &'a [Layout<'a>], + args: &'a [Symbol], + }, + RunLowLevel(LowLevel, &'a [Symbol]), + + Tag { + tag_layout: Layout<'a>, + tag_name: TagName, + tag_id: u8, + union_size: u8, + arguments: &'a [Symbol], + }, + Struct(&'a [Symbol]), + AccessAtIndex { + index: u64, + field_layouts: &'a [Layout<'a>], + structure: Symbol, + is_unwrapped: bool, + }, + + Array { + elem_layout: Layout<'a>, + elems: &'a [Symbol], + }, + EmptyArray, + + RuntimeErrorFunction(&'a str), +} + +impl<'a> Literal<'a> { + pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + use Literal::*; + + match self { + Int(lit) => alloc.text(format!("{}i64", lit)), + Float(lit) => alloc.text(format!("{}f64", lit)), + Bool(lit) => alloc.text(format!("{}", lit)), + Byte(lit) => alloc.text(format!("{}u8", lit)), + Str(lit) => alloc.text(format!("{:?}", lit)), + } + } +} + +fn symbol_to_doc<'b, D, A>(alloc: &'b D, symbol: Symbol) -> DocBuilder<'b, D, A> +where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, +{ + alloc.text(format!("{}", symbol)) +} + +fn join_point_to_doc<'b, D, A>(alloc: &'b D, symbol: JoinPointId) -> DocBuilder<'b, D, A> +where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, +{ + alloc.text(format!("{}", symbol.0)) +} + +impl<'a> Expr<'a> { + pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + use Expr::*; + + match self { + Literal(lit) => lit.to_doc(alloc), + + FunctionPointer(symbol, _) => symbol_to_doc(alloc, *symbol), + + FunctionCall { + call_type, args, .. + } => match call_type { + CallType::ByName(name) => { + let it = std::iter::once(name) + .chain(args.iter()) + .map(|s| symbol_to_doc(alloc, *s)); + + alloc.text("CallByName ").append(alloc.intersperse(it, " ")) + } + CallType::ByPointer(name) => { + let it = std::iter::once(name) + .chain(args.iter()) + .map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text("CallByPointer ") + .append(alloc.intersperse(it, " ")) + } + }, + RunLowLevel(lowlevel, args) => { + let it = args.iter().map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text(format!("lowlevel {:?} ", lowlevel)) + .append(alloc.intersperse(it, " ")) + } + Tag { + tag_name, + arguments, + .. + } => { + let doc_tag = match tag_name { + TagName::Global(s) => alloc.text(s.as_str()), + TagName::Private(s) => alloc.text(format!("{}", s)), + }; + + let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s)); + + doc_tag + .append(alloc.space()) + .append(alloc.intersperse(it, " ")) + } + Struct(args) => { + let it = args.iter().map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text("Struct {") + .append(alloc.intersperse(it, ", ")) + .append(alloc.text("}")) + } + Array { elems, .. } => { + let it = elems.iter().map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text("Array [") + .append(alloc.intersperse(it, ", ")) + .append(alloc.text("]")) + } + EmptyArray => alloc.text("Array []"), + + AccessAtIndex { + index, structure, .. + } => alloc + .text(format!("Index {} ", index)) + .append(symbol_to_doc(alloc, *structure)), + + RuntimeErrorFunction(s) => alloc.text(format!("ErrorFunction {}", s)), + } + } +} + +impl<'a> Stmt<'a> { + pub fn new( + env: &mut Env<'a, '_>, + can_expr: roc_can::expr::Expr, + procs: &mut Procs<'a>, + ) -> Self { + let mut layout_cache = LayoutCache::default(); + + from_can(env, can_expr, procs, &mut layout_cache) + } + pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A> + where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, + { + use Stmt::*; + + match self { + Let(symbol, expr, _, cont) => alloc + .text("let ") + .append(symbol_to_doc(alloc, *symbol)) + .append(" = ") + .append(expr.to_doc(alloc)) + .append(";") + .append(alloc.hardline()) + .append(cont.to_doc(alloc)), + + Ret(symbol) => alloc + .text("ret ") + .append(symbol_to_doc(alloc, *symbol)) + .append(";"), + + Switch { + cond_symbol, + branches, + default_branch, + .. + } => { + let default_doc = alloc + .text("default:") + .append(alloc.hardline()) + .append(default_branch.to_doc(alloc).indent(4)) + .indent(4); + + let branches_docs = branches + .iter() + .map(|(tag, expr)| { + alloc + .text(format!("case {}:", tag)) + .append(alloc.hardline()) + .append(expr.to_doc(alloc).indent(4)) + .indent(4) + }) + .chain(std::iter::once(default_doc)); + // + alloc + .text(format!("switch {}:", cond_symbol)) + .append(alloc.hardline()) + .append( + alloc.intersperse(branches_docs, alloc.hardline().append(alloc.hardline())), + ) + .append(alloc.hardline()) + } + + Cond { + branching_symbol, + pass, + fail, + .. + } => alloc + .text(format!("if {} then", branching_symbol)) + .append(alloc.hardline()) + .append(pass.to_doc(alloc).indent(4)) + .append(alloc.hardline()) + .append(alloc.text("else")) + .append(alloc.hardline()) + .append(fail.to_doc(alloc).indent(4)), + RuntimeError(s) => alloc.text(format!("Error {}", s)), + + Join { + id, + parameters, + continuation, + remainder, + } => { + let it = parameters.iter().map(|p| symbol_to_doc(alloc, p.symbol)); + + alloc.intersperse( + vec![ + remainder.to_doc(alloc), + alloc + .text("joinpoint ") + .append(join_point_to_doc(alloc, *id)) + .append(" ".repeat(parameters.len().min(1))) + .append(alloc.intersperse(it, alloc.space())) + .append(":"), + continuation.to_doc(alloc).indent(4), + ], + alloc.hardline(), + ) + } + Jump(id, arguments) => { + let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text("jump ") + .append(join_point_to_doc(alloc, *id)) + .append(" ".repeat(arguments.len().min(1))) + .append(alloc.intersperse(it, alloc.space())) + .append(";") + } + Inc(symbol, cont) => alloc + .text("inc ") + .append(symbol_to_doc(alloc, *symbol)) + .append(";") + .append(alloc.hardline()) + .append(cont.to_doc(alloc)), + Dec(symbol, cont) => alloc + .text("dec ") + .append(symbol_to_doc(alloc, *symbol)) + .append(";") + .append(alloc.hardline()) + .append(cont.to_doc(alloc)), + } + } + + pub fn to_pretty(&self, width: usize) -> String { + let allocator = BoxAllocator; + let mut w = std::vec::Vec::new(); + self.to_doc::<_, ()>(&allocator) + .1 + .render(width, &mut w) + .unwrap(); + w.push(b'\n'); + String::from_utf8(w).unwrap() + } + + pub fn is_terminal(&self) -> bool { + use Stmt::*; + + match self { + Cond { .. } | Switch { .. } => { + // TODO is this the reason Lean only looks at the outermost `when`? + true + } + Ret(_) => true, + Jump(_, _) => true, + _ => false, + } + } +} + +/// turn record/tag patterns into a when expression, e.g. +/// +/// foo = \{ x } -> body +/// +/// becomes +/// +/// foo = \r -> when r is { x } -> body +/// +/// conversion of one-pattern when expressions will do the most optimal thing +#[allow(clippy::type_complexity)] +fn patterns_to_when<'a>( + env: &mut Env<'a, '_>, + layout_cache: &mut LayoutCache<'a>, + patterns: std::vec::Vec<(Variable, Located)>, + body_var: Variable, + body: Located, +) -> Result< + ( + Vec<'a, Variable>, + Vec<'a, Symbol>, + Located, + ), + Located, +> { + let mut arg_vars = Vec::with_capacity_in(patterns.len(), env.arena); + let mut symbols = Vec::with_capacity_in(patterns.len(), env.arena); + let mut body = Ok(body); + + // patterns that are not yet in a when (e.g. in let or function arguments) must be irrefutable + // to pass type checking. So the order in which we add them to the body does not matter: there + // are only stores anyway, no branches. + for (pattern_var, pattern) in patterns.into_iter() { + let context = crate::exhaustive::Context::BadArg; + let mono_pattern = from_can_pattern(env, layout_cache, &pattern.value); + + match crate::exhaustive::check( + pattern.region, + &[( + Located::at(pattern.region, mono_pattern.clone()), + crate::exhaustive::Guard::NoGuard, + )], + context, + ) { + Ok(_) => { + // Replace the body with a new one, but only if it was Ok. + if let Ok(unwrapped_body) = body { + let (new_symbol, new_body) = + pattern_to_when(env, pattern_var, pattern, body_var, unwrapped_body); + + symbols.push(new_symbol); + arg_vars.push(pattern_var); + + body = Ok(new_body) + } + } + Err(errors) => { + for error in errors { + env.problems.push(MonoProblem::PatternProblem(error)) + } + + let value = RuntimeError::UnsupportedPattern(pattern.region); + + // Even if the body was Ok, replace it with this Err. + // If it was already an Err, leave it at that Err, so the first + // RuntimeError we encountered remains the first. + body = body.and_then(|_| { + Err(Located { + region: pattern.region, + value, + }) + }); + } + } + } + + match body { + Ok(body) => Ok((arg_vars, symbols, body)), + Err(loc_error) => Err(loc_error), + } +} + +/// turn irrefutable patterns into when. For example +/// +/// foo = \{ x } -> body +/// +/// Assuming the above program typechecks, the pattern match cannot fail +/// (it is irrefutable). It becomes +/// +/// foo = \r -> +/// when r is +/// { x } -> body +/// +/// conversion of one-pattern when expressions will do the most optimal thing +fn pattern_to_when<'a>( + env: &mut Env<'a, '_>, + pattern_var: Variable, + pattern: Located, + body_var: Variable, + body: Located, +) -> (Symbol, Located) { + use roc_can::expr::Expr::*; + use roc_can::expr::WhenBranch; + use roc_can::pattern::Pattern::*; + + match &pattern.value { + Identifier(symbol) => (*symbol, body), + Underscore => { + // for underscore we generate a dummy Symbol + (env.unique_symbol(), body) + } + Shadowed(region, loc_ident) => { + let error = roc_problem::can::RuntimeError::Shadowing { + original_region: *region, + shadow: loc_ident.clone(), + }; + (env.unique_symbol(), Located::at_zero(RuntimeError(error))) + } + + UnsupportedPattern(region) => { + // create the runtime error here, instead of delegating to When. + // UnsupportedPattern should then never occcur in When + let error = roc_problem::can::RuntimeError::UnsupportedPattern(*region); + (env.unique_symbol(), Located::at_zero(RuntimeError(error))) + } + + MalformedPattern(problem, region) => { + // create the runtime error here, instead of delegating to When. + let error = roc_problem::can::RuntimeError::MalformedPattern(*problem, *region); + (env.unique_symbol(), Located::at_zero(RuntimeError(error))) + } + + AppliedTag { .. } | RecordDestructure { .. } => { + let symbol = env.unique_symbol(); + + let wrapped_body = When { + cond_var: pattern_var, + expr_var: body_var, + region: Region::zero(), + loc_cond: Box::new(Located::at_zero(Var(symbol))), + branches: vec![WhenBranch { + patterns: vec![pattern], + value: body, + guard: None, + }], + }; + + (symbol, Located::at_zero(wrapped_body)) + } + + IntLiteral(_) | NumLiteral(_, _) | FloatLiteral(_) | StrLiteral(_) => { + // These patters are refutable, and thus should never occur outside a `when` expression + // They should have been replaced with `UnsupportedPattern` during canonicalization + unreachable!("refutable pattern {:?} where irrefutable pattern is expected. This should never happen!", pattern.value) + } + } +} + +pub fn specialize_all<'a>( + env: &mut Env<'a, '_>, + mut procs: Procs<'a>, + layout_cache: &mut LayoutCache<'a>, +) -> Procs<'a> { + let mut pending_specializations = procs.pending_specializations.unwrap_or_default(); + + // When calling from_can, pending_specializations should be unavailable. + // This must be a single pass, and we must not add any more entries to it! + procs.pending_specializations = None; + + for (name, mut by_layout) in pending_specializations.drain() { + // Use the function's symbol's home module as the home module + // when doing canonicalization. This will be important to determine + // whether or not it's safe to defer specialization. + env.home = name.module_id(); + + for (layout, pending) in by_layout.drain() { + // If we've already seen this (Symbol, Layout) combination before, + // don't try to specialize it again. If we do, we'll loop forever! + // + // NOTE: this #[allow(clippy::map_entry)] here is for correctness! + // Changing it to use .entry() would necessarily make it incorrect. + #[allow(clippy::map_entry)] + if !procs.specialized.contains_key(&(name, layout.clone())) { + // TODO should pending_procs hold a Rc? + let partial_proc = procs + .partial_procs + .get(&name) + .unwrap_or_else(|| panic!("Could not find partial_proc for {:?}", name)) + .clone(); + + // Mark this proc as in-progress, so if we're dealing with + // mutually recursive functions, we don't loop forever. + // (We had a bug around this before this system existed!) + procs.specialized.insert((name, layout.clone()), InProgress); + + match specialize(env, &mut procs, name, layout_cache, pending, partial_proc) { + Ok(proc) => { + procs.specialized.insert((name, layout), Done(proc)); + } + Err(error) => { + let error_msg = env.arena.alloc(format!( + "TODO generate a RuntimeError message for {:?}", + error + )); + + procs.runtime_errors.insert(name, error_msg); + } + } + } + } + } + + procs +} + +fn specialize<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + proc_name: Symbol, + layout_cache: &mut LayoutCache<'a>, + pending: PendingSpecialization<'a>, + partial_proc: PartialProc<'a>, +) -> Result, LayoutProblem> { + let PendingSpecialization { + ret_var, + fn_var, + pattern_vars, + } = pending; + + let PartialProc { + annotation, + pattern_symbols, + body, + } = partial_proc; + + // unify the called function with the specialized signature, then specialize the function body + let snapshot = env.subs.snapshot(); + let unified = roc_unify::unify::unify(env.subs, annotation, fn_var); + + debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_))); + + let ret_symbol = env.unique_symbol(); + let hole = env.arena.alloc(Stmt::Ret(ret_symbol)); + let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole); + + // reset subs, so we don't get type errors when specializing for a different signature + env.subs.rollback_to(snapshot); + + let mut proc_args = Vec::with_capacity_in(pattern_vars.len(), &env.arena); + + debug_assert_eq!( + &pattern_vars.len(), + &pattern_symbols.len(), + "Tried to zip two vecs with different lengths!" + ); + + for (arg_var, arg_name) in pattern_vars.iter().zip(pattern_symbols.iter()) { + let layout = layout_cache.from_var(&env.arena, *arg_var, env.subs)?; + + proc_args.push((layout, *arg_name)); + } + + let ret_layout = layout_cache + .from_var(&env.arena, ret_var, env.subs) + .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); + + // TODO WRONG + let closes_over_layout = Layout::Struct(&[]); + + let proc = Proc { + name: proc_name, + args: proc_args.into_bump_slice(), + body: specialized_body, + closes_over: closes_over_layout, + ret_layout, + }; + + Ok(proc) +} + +pub fn with_hole<'a>( + env: &mut Env<'a, '_>, + can_expr: roc_can::expr::Expr, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + assigned: Symbol, + hole: &'a Stmt<'a>, +) -> Stmt<'a> { + use roc_can::expr::Expr::*; + + let arena = env.arena; + + match can_expr { + Int(_, num) => Stmt::Let( + assigned, + Expr::Literal(Literal::Int(num)), + Layout::Builtin(Builtin::Int64), + hole, + ), + + Float(_, num) => Stmt::Let( + assigned, + Expr::Literal(Literal::Float(num)), + Layout::Builtin(Builtin::Float64), + hole, + ), + + Str(string) | BlockStr(string) => Stmt::Let( + assigned, + Expr::Literal(Literal::Str(arena.alloc(string))), + Layout::Builtin(Builtin::Str), + hole, + ), + + Num(var, num) => match num_argument_to_int_or_float(env.subs, var) { + IntOrFloat::IntType => Stmt::Let( + assigned, + Expr::Literal(Literal::Int(num)), + Layout::Builtin(Builtin::Int64), + hole, + ), + IntOrFloat::FloatType => Stmt::Let( + assigned, + Expr::Literal(Literal::Float(num as f64)), + Layout::Builtin(Builtin::Float64), + hole, + ), + }, + LetNonRec(def, cont, _, _) => { + // WRONG! this is introduces new control flow, and should call `from_can` again + if let roc_can::pattern::Pattern::Identifier(symbol) = def.loc_pattern.value { + let mut stmt = with_hole(env, cont.value, procs, layout_cache, assigned, hole); + + // this is an alias of a variable + if let roc_can::expr::Expr::Var(original) = def.loc_expr.value { + substitute_in_exprs(env.arena, &mut stmt, symbol, original); + } + + with_hole( + env, + def.loc_expr.value, + procs, + layout_cache, + symbol, + env.arena.alloc(stmt), + ) + } else { + // this may be a destructure pattern + let mono_pattern = from_can_pattern(env, layout_cache, &def.loc_pattern.value); + + if let Pattern::Identifier(symbol) = mono_pattern { + let hole = env + .arena + .alloc(from_can(env, cont.value, procs, layout_cache)); + with_hole(env, def.loc_expr.value, procs, layout_cache, symbol, hole) + } else { + let context = crate::exhaustive::Context::BadDestruct; + match crate::exhaustive::check( + def.loc_pattern.region, + &[( + Located::at(def.loc_pattern.region, mono_pattern.clone()), + crate::exhaustive::Guard::NoGuard, + )], + context, + ) { + Ok(_) => {} + Err(errors) => { + for error in errors { + env.problems.push(MonoProblem::PatternProblem(error)) + } + } // TODO make all variables bound in the pattern evaluate to a runtime error + // return Stmt::RuntimeError("TODO non-exhaustive pattern"); + } + + // convert the continuation + let mut stmt = from_can(env, cont.value, procs, layout_cache); + + let outer_symbol = env.unique_symbol(); + stmt = + store_pattern(env, procs, layout_cache, &mono_pattern, outer_symbol, stmt) + .unwrap(); + + // convert the def body, store in outer_symbol + with_hole( + env, + def.loc_expr.value, + procs, + layout_cache, + outer_symbol, + env.arena.alloc(stmt), + ) + } + } + } + Var(symbol) => { + if procs.module_thunks.contains(&symbol) { + let partial_proc = procs.partial_procs.get(&symbol).unwrap(); + let fn_var = partial_proc.annotation; + let ret_var = fn_var; // These are the same for a thunk. + + // This is a top-level declaration, which will code gen to a 0-arity thunk. + let result = call_by_name( + env, + procs, + fn_var, + ret_var, + symbol, + std::vec::Vec::new(), + layout_cache, + assigned, + env.arena.alloc(Stmt::Ret(assigned)), + ); + + return result; + } + + // A bit ugly, but it does the job + match hole { + Stmt::Jump(id, _) => Stmt::Jump(*id, env.arena.alloc([symbol])), + _ => { + // if you see this, there is variable aliasing going on + Stmt::Ret(symbol) + } + } + } + // Var(symbol) => panic!("reached Var {}", symbol), + // assigned, + // Stmt::Ret(symbol), + Tag { + variant_var, + name: tag_name, + arguments: args, + .. + } => { + use crate::layout::UnionVariant::*; + let arena = env.arena; + + let variant = crate::layout::union_sorted_tags(env.arena, variant_var, env.subs); + + match variant { + Never => unreachable!("The `[]` type has no constructors"), + Unit => Stmt::Let(assigned, Expr::Struct(&[]), Layout::Struct(&[]), hole), + BoolUnion { ttrue, .. } => Stmt::Let( + assigned, + Expr::Literal(Literal::Bool(tag_name == ttrue)), + Layout::Builtin(Builtin::Int1), + hole, + ), + ByteUnion(tag_names) => { + let tag_id = tag_names + .iter() + .position(|key| key == &tag_name) + .expect("tag must be in its own type"); + + Stmt::Let( + assigned, + Expr::Literal(Literal::Byte(tag_id as u8)), + Layout::Builtin(Builtin::Int8), + hole, + ) + } + + Unwrapped(field_layouts) => { + let mut field_symbols = Vec::with_capacity_in(field_layouts.len(), env.arena); + + for (_, arg) in args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } + } + + // Layout will unpack this unwrapped tack if it only has one (non-zero-sized) field + let layout = layout_cache + .from_var(env.arena, variant_var, env.subs) + .unwrap_or_else(|err| { + panic!("TODO turn fn_var into a RuntimeError {:?}", err) + }); + + // even though this was originally a Tag, we treat it as a Struct from now on + let mut stmt = Stmt::Let( + assigned, + Expr::Struct(field_symbols.clone().into_bump_slice()), + layout, + hole, + ); + + for ((_, arg), symbol) in args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg.value { + continue; + } + stmt = with_hole( + env, + arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(stmt), + ); + } + + stmt + } + Wrapped(sorted_tag_layouts) => { + let union_size = sorted_tag_layouts.len() as u8; + let (tag_id, (_, _)) = sorted_tag_layouts + .iter() + .enumerate() + .find(|(_, (key, _))| key == &tag_name) + .expect("tag must be in its own type"); + + let mut field_symbols: Vec = Vec::with_capacity_in(args.len(), arena); + let tag_id_symbol = env.unique_symbol(); + field_symbols.push(tag_id_symbol); + + for (_, arg) in args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } + } + + let mut layouts: Vec<&'a [Layout<'a>]> = + Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); + + for (_, arg_layouts) in sorted_tag_layouts.into_iter() { + layouts.push(arg_layouts); + } + + let field_symbols = field_symbols.into_bump_slice(); + let layout = Layout::Union(layouts.into_bump_slice()); + let tag = Expr::Tag { + tag_layout: layout.clone(), + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + let mut stmt = Stmt::Let(assigned, tag, layout, hole); + + for ((_, arg), symbol) in args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg.value { + continue; + } + + stmt = with_hole( + env, + arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(stmt), + ); + } + + // define the tag id + stmt = Stmt::Let( + tag_id_symbol, + Expr::Literal(Literal::Int(tag_id as i64)), + Layout::Builtin(Builtin::Int64), + arena.alloc(stmt), + ); + + stmt + } + } + } + + Record { + record_var, + mut fields, + .. + } => { + let sorted_fields = crate::layout::sort_record_fields(env.arena, record_var, env.subs); + + let mut field_symbols = Vec::with_capacity_in(fields.len(), env.arena); + let mut field_layouts = Vec::with_capacity_in(fields.len(), env.arena); + let mut can_fields = Vec::with_capacity_in(fields.len(), env.arena); + + for (label, layout) in sorted_fields.into_iter() { + field_layouts.push(layout); + + let field = fields.remove(&label).unwrap(); + if let roc_can::expr::Expr::Var(symbol) = field.loc_expr.value { + field_symbols.push(symbol); + can_fields.push(None); + } else { + field_symbols.push(env.unique_symbol()); + can_fields.push(Some(field)); + } + } + + // creating a record from the var will unpack it if it's just a single field. + let layout = layout_cache + .from_var(env.arena, record_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let field_symbols = field_symbols.into_bump_slice(); + let mut stmt = Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole); + + for (opt_field, symbol) in can_fields.into_iter().rev().zip(field_symbols.iter().rev()) + { + if let Some(field) = opt_field { + stmt = with_hole( + env, + field.loc_expr.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(stmt), + ); + } + } + + stmt + } + + EmptyRecord => Stmt::Let(assigned, Expr::Struct(&[]), Layout::Struct(&[]), hole), + + If { + cond_var, + branch_var, + branches, + final_else, + } => { + let ret_layout = layout_cache + .from_var(env.arena, branch_var, env.subs) + .expect("invalid ret_layout"); + let cond_layout = layout_cache + .from_var(env.arena, cond_var, env.subs) + .expect("invalid cond_layout"); + + let assigned_in_jump = env.unique_symbol(); + let id = JoinPointId(env.unique_symbol()); + let jump = env + .arena + .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); + + let mut stmt = with_hole( + env, + final_else.value, + procs, + layout_cache, + assigned_in_jump, + jump, + ); + + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); + let then = with_hole( + env, + loc_then.value, + procs, + layout_cache, + assigned_in_jump, + jump, + ); + + stmt = Stmt::Cond { + cond_symbol: branching_symbol, + branching_symbol, + cond_layout: cond_layout.clone(), + branching_layout: cond_layout.clone(), + pass: env.arena.alloc(then), + fail: env.arena.alloc(stmt), + ret_layout: ret_layout.clone(), + }; + + // add condition + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + + let layout = layout_cache + .from_var(env.arena, branch_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let param = Param { + symbol: assigned, + layout, + borrow: false, + }; + + Stmt::Join { + id, + parameters: env.arena.alloc([param]), + remainder: env.arena.alloc(stmt), + continuation: hole, + } + } + + When { + cond_var, + expr_var, + region, + loc_cond, + branches, + } => { + let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value { + symbol + } else { + env.unique_symbol() + }; + + let id = JoinPointId(env.unique_symbol()); + + let mut stmt = from_can_when( + env, + cond_var, + expr_var, + region, + cond_symbol, + branches, + layout_cache, + procs, + Some(id), + ); + + // define the `when` condition + if let roc_can::expr::Expr::Var(_) = loc_cond.value { + // do nothing + } else { + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + cond_symbol, + env.arena.alloc(stmt), + ); + }; + + let layout = layout_cache + .from_var(env.arena, expr_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let param = Param { + symbol: assigned, + layout, + borrow: false, + }; + + Stmt::Join { + id, + parameters: env.arena.alloc([param]), + remainder: env.arena.alloc(stmt), + continuation: env.arena.alloc(hole), + } + } + + List { loc_elems, .. } if loc_elems.is_empty() => { + // because an empty list has an unknown element type, it is handled differently + let expr = Expr::EmptyArray; + Stmt::Let(assigned, expr, Layout::Builtin(Builtin::EmptyList), hole) + } + + List { + list_var, + elem_var, + loc_elems, + } => { + let mut arg_symbols = Vec::with_capacity_in(loc_elems.len(), env.arena); + for arg_expr in loc_elems.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg_expr.value { + arg_symbols.push(symbol); + } else { + arg_symbols.push(env.unique_symbol()); + } + } + let arg_symbols = arg_symbols.into_bump_slice(); + + let elem_layout = layout_cache + .from_var(env.arena, elem_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let expr = Expr::Array { + elem_layout: elem_layout.clone(), + elems: arg_symbols, + }; + + let mode = crate::layout::mode_from_var(list_var, env.subs); + + let mut stmt = Stmt::Let( + assigned, + expr, + Layout::Builtin(Builtin::List(mode, env.arena.alloc(elem_layout))), + hole, + ); + + for (arg_expr, symbol) in loc_elems.into_iter().rev().zip(arg_symbols.iter().rev()) { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg_expr.value { + continue; + } + + stmt = with_hole( + env, + arg_expr.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(stmt), + ); + } + + stmt + } + LetRec(_, _, _, _) => todo!("lets"), + + Access { + record_var, + field_var, + field, + loc_expr, + .. + } => { + let sorted_fields = crate::layout::sort_record_fields(env.arena, record_var, env.subs); + + let mut index = None; + let mut field_layouts = Vec::with_capacity_in(sorted_fields.len(), env.arena); + + for (current, (label, field_layout)) in sorted_fields.into_iter().enumerate() { + field_layouts.push(field_layout); + + if label == field { + index = Some(current); + } + } + + let record_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_expr.value { + symbol + } else { + env.unique_symbol() + }; + + let expr = Expr::AccessAtIndex { + index: index.expect("field not in its own type") as u64, + field_layouts: field_layouts.into_bump_slice(), + structure: record_symbol, + is_unwrapped: true, + }; + + let layout = layout_cache + .from_var(env.arena, field_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let mut stmt = Stmt::Let(assigned, expr, layout, hole); + + if let roc_can::expr::Expr::Var(_) = loc_expr.value { + // do nothing + } else { + stmt = with_hole( + env, + loc_expr.value, + procs, + layout_cache, + record_symbol, + env.arena.alloc(stmt), + ); + }; + + stmt + } + + Accessor { .. } | Update { .. } => todo!("record access/accessor/update"), + + Closure(ann, name, _, loc_args, boxed_body) => { + let (loc_body, ret_var) = *boxed_body; + + match procs.insert_anonymous(env, name, ann, loc_args, loc_body, ret_var, layout_cache) + { + Ok(layout) => { + // TODO should the let have layout Pointer? + Stmt::Let( + assigned, + Expr::FunctionPointer(name, layout.clone()), + layout, + hole, + ) + } + + Err(_error) => Stmt::RuntimeError( + "TODO convert anonymous function error to a RuntimeError string", + ), + } + } + + Call(boxed, loc_args, _) => { + let (fn_var, loc_expr, ret_var) = *boxed; + + /* + Var(symbol) => { + if procs.module_thunks.contains(&symbol) { + let partial_proc = procs.partial_procs.get(&symbol).unwrap(); + let fn_var = partial_proc.annotation; + let ret_var = fn_var; // These are the same for a thunk. + + // This is a top-level declaration, which will code gen to a 0-arity thunk. + call_by_name( + env, + procs, + fn_var, + ret_var, + symbol, + std::vec::Vec::new(), + layout_cache, + ) + } else { + // NOTE Load will always increment the refcount + Expr::Load(symbol) + } + } + */ + + // match from_can(env, loc_expr.value, procs, layout_cache) { + match loc_expr.value { + roc_can::expr::Expr::Var(proc_name) if procs.module_thunks.contains(&proc_name) => { + todo!() + } + roc_can::expr::Expr::Var(proc_name) => call_by_name( + env, + procs, + fn_var, + ret_var, + proc_name, + loc_args, + layout_cache, + assigned, + hole, + ), + _ => { + // Call by pointer - the closure was anonymous, e.g. + // + // ((\a -> a) 5) + // + // It might even be the anonymous result of a conditional: + // + // ((if x > 0 then \a -> a else \_ -> 0) 5) + // + // It could be named too: + // + // ((if x > 0 then foo else bar) 5) + let mut arg_symbols = Vec::with_capacity_in(loc_args.len(), env.arena); + + for _ in 0..loc_args.len() { + arg_symbols.push(env.unique_symbol()); + } + + let layout = layout_cache + .from_var(env.arena, fn_var, env.subs) + .unwrap_or_else(|err| { + panic!("TODO turn fn_var into a RuntimeError {:?}", err) + }); + + let arg_layouts = match layout { + Layout::FunctionPointer(args, _) => args, + _ => unreachable!("function has layout that is not function pointer"), + }; + + let ret_layout = layout_cache + .from_var(env.arena, ret_var, env.subs) + .unwrap_or_else(|err| { + panic!("TODO turn fn_var into a RuntimeError {:?}", err) + }); + + let function_symbol = env.unique_symbol(); + let arg_symbols = arg_symbols.into_bump_slice(); + let mut result = Stmt::Let( + assigned, + Expr::FunctionCall { + call_type: CallType::ByPointer(function_symbol), + layout, + args: arg_symbols, + arg_layouts, + }, + ret_layout, + arena.alloc(hole), + ); + + // let ptr = with_hole(env, loc_expr.value, procs, layout_cache, function_symbol); + result = with_hole( + env, + loc_expr.value, + procs, + layout_cache, + function_symbol, + env.arena.alloc(result), + ); + + for ((_, loc_arg), symbol) in + loc_args.into_iter().rev().zip(arg_symbols.iter().rev()) + { + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result + } + } + } + + RunLowLevel { op, args, ret_var } => { + let op = optimize_low_level(env.subs, op, &args); + + let mut arg_symbols = Vec::with_capacity_in(args.len(), env.arena); + + for (_, arg_expr) in args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg_expr { + arg_symbols.push(*symbol); + } else { + arg_symbols.push(env.unique_symbol()); + } + } + let arg_symbols = arg_symbols.into_bump_slice(); + + // layout of the return type + let layout = layout_cache + .from_var(env.arena, ret_var, env.subs) + .unwrap_or_else(|err| todo!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let mut result = Stmt::Let(assigned, Expr::RunLowLevel(op, arg_symbols), layout, hole); + + for ((_arg_var, arg_expr), symbol) in + args.into_iter().rev().zip(arg_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg_expr { + continue; + } + + result = with_hole( + env, + arg_expr, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result + } + RuntimeError(e) => Stmt::RuntimeError(env.arena.alloc(format!("{:?}", e))), + } +} + +pub fn from_can<'a>( + env: &mut Env<'a, '_>, + can_expr: roc_can::expr::Expr, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, +) -> Stmt<'a> { + use roc_can::expr::Expr::*; + + match can_expr { + LetRec(defs, cont, _, _) => { + // because Roc is strict, only functions can be recursive! + for def in defs.into_iter() { + if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { + // Now that we know for sure it's a closure, get an owned + // version of these variant args so we can use them properly. + match def.loc_expr.value { + Closure(ann, _, _, loc_args, boxed_body) => { + // Extract Procs, but discard the resulting Expr::Load. + // That Load looks up the pointer, which we won't use here! + + let (loc_body, ret_var) = *boxed_body; + + procs.insert_named( + env, + layout_cache, + *symbol, + ann, + loc_args, + loc_body, + ret_var, + ); + + continue; + } + _ => unreachable!("recursive value is not a function"), + } + } + unreachable!("recursive value does not have Identifier pattern") + } + + from_can(env, cont.value, procs, layout_cache) + } + LetNonRec(def, cont, _, _) => { + if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { + if let Closure(_, _, _, _, _) = &def.loc_expr.value { + // Now that we know for sure it's a closure, get an owned + // version of these variant args so we can use them properly. + match def.loc_expr.value { + Closure(ann, _, _, loc_args, boxed_body) => { + // Extract Procs, but discard the resulting Expr::Load. + // That Load looks up the pointer, which we won't use here! + + let (loc_body, ret_var) = *boxed_body; + + procs.insert_named( + env, + layout_cache, + *symbol, + ann, + loc_args, + loc_body, + ret_var, + ); + + return from_can(env, cont.value, procs, layout_cache); + } + _ => unreachable!(), + } + } + let rest = from_can(env, cont.value, procs, layout_cache); + return with_hole( + env, + def.loc_expr.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(rest), + ); + } + + // this may be a destructure pattern + let mono_pattern = from_can_pattern(env, layout_cache, &def.loc_pattern.value); + + if let Pattern::Identifier(symbol) = mono_pattern { + let hole = env + .arena + .alloc(from_can(env, cont.value, procs, layout_cache)); + with_hole(env, def.loc_expr.value, procs, layout_cache, symbol, hole) + } else { + let context = crate::exhaustive::Context::BadDestruct; + match crate::exhaustive::check( + def.loc_pattern.region, + &[( + Located::at(def.loc_pattern.region, mono_pattern.clone()), + crate::exhaustive::Guard::NoGuard, + )], + context, + ) { + Ok(_) => {} + Err(errors) => { + for error in errors { + env.problems.push(MonoProblem::PatternProblem(error)) + } + } // TODO make all variables bound in the pattern evaluate to a runtime error + // return Stmt::RuntimeError("TODO non-exhaustive pattern"); + } + + // convert the continuation + let mut stmt = from_can(env, cont.value, procs, layout_cache); + + let outer_symbol = env.unique_symbol(); + stmt = store_pattern(env, procs, layout_cache, &mono_pattern, outer_symbol, stmt) + .unwrap(); + + // convert the def body, store in outer_symbol + with_hole( + env, + def.loc_expr.value, + procs, + layout_cache, + outer_symbol, + env.arena.alloc(stmt), + ) + } + } + + _ => { + let symbol = env.unique_symbol(); + let hole = env.arena.alloc(Stmt::Ret(symbol)); + with_hole(env, can_expr, procs, layout_cache, symbol, hole) + } + } +} + +fn to_opt_branches<'a>( + env: &mut Env<'a, '_>, + region: Region, + branches: std::vec::Vec, + layout_cache: &mut LayoutCache<'a>, +) -> std::vec::Vec<( + Pattern<'a>, + Option>, + roc_can::expr::Expr, +)> { + debug_assert!(!branches.is_empty()); + + let mut loc_branches = std::vec::Vec::new(); + let mut opt_branches = std::vec::Vec::new(); + + for when_branch in branches { + let exhaustive_guard = if when_branch.guard.is_some() { + Guard::HasGuard + } else { + Guard::NoGuard + }; + + for loc_pattern in when_branch.patterns { + let mono_pattern = from_can_pattern(env, layout_cache, &loc_pattern.value); + + loc_branches.push(( + Located::at(loc_pattern.region, mono_pattern.clone()), + exhaustive_guard.clone(), + )); + + // TODO remove clone? + opt_branches.push(( + mono_pattern, + when_branch.guard.clone(), + when_branch.value.value.clone(), + )); + } + } + + // NOTE exhaustiveness is checked after the construction of all the branches + // In contrast to elm (currently), we still do codegen even if a pattern is non-exhaustive. + // So we not only report exhaustiveness errors, but also correct them + let context = crate::exhaustive::Context::BadCase; + match crate::exhaustive::check(region, &loc_branches, context) { + Ok(_) => {} + Err(errors) => { + use crate::exhaustive::Error::*; + let mut is_not_exhaustive = false; + let mut overlapping_branches = std::vec::Vec::new(); + + for error in errors { + match &error { + Incomplete(_, _, _) => { + is_not_exhaustive = true; + } + Redundant { index, .. } => { + overlapping_branches.push(index.to_zero_based()); + } + } + env.problems.push(MonoProblem::PatternProblem(error)) + } + + overlapping_branches.sort(); + + for i in overlapping_branches.into_iter().rev() { + opt_branches.remove(i); + } + + if is_not_exhaustive { + opt_branches.push(( + Pattern::Underscore, + None, + roc_can::expr::Expr::RuntimeError( + roc_problem::can::RuntimeError::NonExhaustivePattern, + ), + )); + } + } + } + + opt_branches +} + +#[allow(clippy::too_many_arguments)] +fn from_can_when<'a>( + env: &mut Env<'a, '_>, + cond_var: Variable, + expr_var: Variable, + region: Region, + cond_symbol: Symbol, + branches: std::vec::Vec, + layout_cache: &mut LayoutCache<'a>, + procs: &mut Procs<'a>, + join_point: Option, +) -> Stmt<'a> { + if branches.is_empty() { + // A when-expression with no branches is a runtime error. + // We can't know what to return! + return Stmt::RuntimeError("Hit a 0-branch when expression"); + } + let opt_branches = to_opt_branches(env, region, branches, layout_cache); + + let cond_layout = layout_cache + .from_var(env.arena, cond_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); + + let ret_layout = layout_cache + .from_var(env.arena, expr_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); + + let arena = env.arena; + let it = opt_branches + .into_iter() + .map(|(pattern, opt_guard, can_expr)| { + let branch_stmt = match join_point { + None => from_can(env, can_expr, procs, layout_cache), + Some(id) => { + let symbol = env.unique_symbol(); + let arguments = bumpalo::vec![in env.arena; symbol].into_bump_slice(); + let jump = env.arena.alloc(Stmt::Jump(id, arguments)); + + with_hole(env, can_expr, procs, layout_cache, symbol, jump) + } + }; + + use crate::decision_tree::Guard; + if let Some(loc_expr) = opt_guard { + let id = JoinPointId(env.unique_symbol()); + let symbol = env.unique_symbol(); + let jump = env.arena.alloc(Stmt::Jump(id, env.arena.alloc([symbol]))); + + let guard_stmt = with_hole(env, loc_expr.value, procs, layout_cache, symbol, jump); + + match store_pattern(env, procs, layout_cache, &pattern, cond_symbol, guard_stmt) { + Ok(new_guard_stmt) => ( + pattern, + Guard::Guard { + id, + symbol, + stmt: new_guard_stmt, + }, + branch_stmt, + ), + Err(msg) => ( + Pattern::Underscore, + Guard::NoGuard, + Stmt::RuntimeError(env.arena.alloc(msg)), + ), + } + } else { + match store_pattern(env, procs, layout_cache, &pattern, cond_symbol, branch_stmt) { + Ok(new_branch_stmt) => (pattern, Guard::NoGuard, new_branch_stmt), + Err(msg) => ( + Pattern::Underscore, + Guard::NoGuard, + Stmt::RuntimeError(env.arena.alloc(msg)), + ), + } + } + }); + let mono_branches = Vec::from_iter_in(it, arena); + + crate::decision_tree::optimize_when( + env, + procs, + layout_cache, + cond_symbol, + cond_layout.clone(), + ret_layout, + mono_branches, + ) +} + +fn substitute(substitutions: &MutMap, s: Symbol) -> Option { + match substitutions.get(&s) { + Some(new) => { + debug_assert!(!substitutions.contains_key(new)); + Some(*new) + } + None => None, + } +} + +fn substitute_in_exprs<'a>(arena: &'a Bump, stmt: &mut Stmt<'a>, from: Symbol, to: Symbol) { + let mut subs = MutMap::default(); + subs.insert(from, to); + + // TODO clean this up + let ref_stmt = arena.alloc(stmt.clone()); + if let Some(new) = substitute_in_stmt_help(arena, ref_stmt, &subs) { + *stmt = new.clone(); + } +} + +fn substitute_in_stmt_help<'a>( + arena: &'a Bump, + stmt: &'a Stmt<'a>, + subs: &MutMap, +) -> Option<&'a Stmt<'a>> { + use Stmt::*; + + match stmt { + Let(symbol, expr, layout, cont) => { + let opt_cont = substitute_in_stmt_help(arena, cont, subs); + let opt_expr = substitute_in_expr(arena, expr, subs); + + if opt_expr.is_some() || opt_cont.is_some() { + let cont = opt_cont.unwrap_or(cont); + let expr = opt_expr.unwrap_or_else(|| expr.clone()); + + Some(arena.alloc(Let(*symbol, expr, layout.clone(), cont))) + } else { + None + } + } + Join { + id, + parameters, + remainder, + continuation, + } => { + let opt_remainder = substitute_in_stmt_help(arena, remainder, subs); + let opt_continuation = substitute_in_stmt_help(arena, continuation, subs); + + if opt_remainder.is_some() || opt_continuation.is_some() { + let remainder = opt_remainder.unwrap_or(remainder); + let continuation = opt_continuation.unwrap_or_else(|| *continuation); + + Some(arena.alloc(Join { + id: *id, + parameters, + remainder, + continuation, + })) + } else { + None + } + } + Cond { + cond_symbol, + cond_layout, + branching_symbol, + branching_layout, + pass, + fail, + ret_layout, + } => { + let opt_pass = substitute_in_stmt_help(arena, pass, subs); + let opt_fail = substitute_in_stmt_help(arena, fail, subs); + + if opt_pass.is_some() || opt_fail.is_some() { + let pass = opt_pass.unwrap_or(pass); + let fail = opt_fail.unwrap_or_else(|| *fail); + + Some(arena.alloc(Cond { + cond_symbol: *cond_symbol, + cond_layout: cond_layout.clone(), + branching_symbol: *branching_symbol, + branching_layout: branching_layout.clone(), + pass, + fail, + ret_layout: ret_layout.clone(), + })) + } else { + None + } + } + Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout, + } => { + let opt_default = substitute_in_stmt_help(arena, default_branch, subs); + + let mut did_change = false; + + let opt_branches = Vec::from_iter_in( + branches.iter().map(|(label, branch)| { + match substitute_in_stmt_help(arena, branch, subs) { + None => None, + Some(branch) => { + did_change = true; + Some((*label, branch.clone())) + } + } + }), + arena, + ); + + if opt_default.is_some() || did_change { + let default_branch = opt_default.unwrap_or(default_branch); + + let branches = if did_change { + let new = Vec::from_iter_in( + opt_branches.into_iter().zip(branches.iter()).map( + |(opt_branch, branch)| match opt_branch { + None => branch.clone(), + Some(new_branch) => new_branch, + }, + ), + arena, + ); + + new.into_bump_slice() + } else { + branches + }; + + Some(arena.alloc(Switch { + cond_symbol: *cond_symbol, + cond_layout: cond_layout.clone(), + default_branch, + branches, + ret_layout: ret_layout.clone(), + })) + } else { + None + } + } + Ret(s) => match substitute(subs, *s) { + Some(s) => Some(arena.alloc(Ret(s))), + None => None, + }, + Inc(symbol, cont) => match substitute_in_stmt_help(arena, cont, subs) { + Some(cont) => Some(arena.alloc(Inc(*symbol, cont))), + None => None, + }, + Dec(symbol, cont) => match substitute_in_stmt_help(arena, cont, subs) { + Some(cont) => Some(arena.alloc(Dec(*symbol, cont))), + None => None, + }, + + Jump(id, args) => { + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change { + let args = new_args.into_bump_slice(); + + Some(arena.alloc(Jump(*id, args))) + } else { + None + } + } + + RuntimeError(_) => None, + } +} + +fn substitute_in_expr<'a>( + arena: &'a Bump, + expr: &'a Expr<'a>, + subs: &MutMap, +) -> Option> { + use Expr::*; + + match expr { + Literal(_) | FunctionPointer(_, _) | EmptyArray | RuntimeErrorFunction(_) => None, + + FunctionCall { + call_type, + args, + arg_layouts, + layout, + } => { + let opt_call_type = match call_type { + CallType::ByName(s) => substitute(subs, *s).map(CallType::ByName), + CallType::ByPointer(s) => substitute(subs, *s).map(CallType::ByPointer), + }; + + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change || opt_call_type.is_some() { + let call_type = opt_call_type.unwrap_or(*call_type); + + let args = new_args.into_bump_slice(); + + Some(FunctionCall { + call_type, + args, + arg_layouts: *arg_layouts, + layout: layout.clone(), + }) + } else { + None + } + } + RunLowLevel(op, args) => { + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change { + let args = new_args.into_bump_slice(); + + Some(RunLowLevel(*op, args)) + } else { + None + } + } + + Tag { + tag_layout, + tag_name, + tag_id, + union_size, + arguments: args, + } => { + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change { + let arguments = new_args.into_bump_slice(); + + Some(Tag { + tag_layout: tag_layout.clone(), + tag_name: tag_name.clone(), + tag_id: *tag_id, + union_size: *union_size, + arguments, + }) + } else { + None + } + } + Struct(args) => { + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change { + let args = new_args.into_bump_slice(); + + Some(Struct(args)) + } else { + None + } + } + + Array { + elems: args, + elem_layout, + } => { + let mut did_change = false; + let new_args = Vec::from_iter_in( + args.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change { + let args = new_args.into_bump_slice(); + + Some(Array { + elem_layout: elem_layout.clone(), + elems: args, + }) + } else { + None + } + } + + AccessAtIndex { + index, + structure, + field_layouts, + is_unwrapped, + } => match substitute(subs, *structure) { + Some(structure) => Some(AccessAtIndex { + index: *index, + field_layouts: *field_layouts, + is_unwrapped: *is_unwrapped, + structure, + }), + None => None, + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn store_pattern<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + can_pat: &Pattern<'a>, + outer_symbol: Symbol, + mut stmt: Stmt<'a>, +) -> Result, &'a str> { + use Pattern::*; + + match can_pat { + Identifier(symbol) => { + substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol); + } + Underscore => { + // do nothing + } + IntLiteral(_) + | FloatLiteral(_) + | EnumLiteral { .. } + | BitLiteral { .. } + | StrLiteral(_) => {} + AppliedTag { + union, arguments, .. + } => { + let is_unwrapped = union.alternatives.len() == 1; + + let mut arg_layouts = Vec::with_capacity_in(arguments.len(), env.arena); + + if !is_unwrapped { + // add an element for the tag discriminant + arg_layouts.push(Layout::Builtin(Builtin::Int64)); + } + + for (_, layout) in arguments { + arg_layouts.push(layout.clone()); + } + + for (index, (argument, arg_layout)) in arguments.iter().enumerate().rev() { + let load = Expr::AccessAtIndex { + is_unwrapped, + index: (!is_unwrapped as usize + index) as u64, + field_layouts: arg_layouts.clone().into_bump_slice(), + structure: outer_symbol, + }; + match argument { + Identifier(symbol) => { + // store immediately in the given symbol + stmt = Stmt::Let(*symbol, load, arg_layout.clone(), env.arena.alloc(stmt)); + } + Underscore => { + // ignore + } + IntLiteral(_) + | FloatLiteral(_) + | EnumLiteral { .. } + | BitLiteral { .. } + | StrLiteral(_) => {} + _ => { + // store the field in a symbol, and continue matching on it + let symbol = env.unique_symbol(); + + // first recurse, continuing to unpack symbol + stmt = store_pattern(env, procs, layout_cache, argument, symbol, stmt)?; + + // then store the symbol + stmt = Stmt::Let(symbol, load, arg_layout.clone(), env.arena.alloc(stmt)); + } + } + } + } + RecordDestructure(destructs, Layout::Struct(sorted_fields)) => { + for (index, destruct) in destructs.iter().enumerate().rev() { + stmt = store_record_destruct( + env, + procs, + layout_cache, + destruct, + index as u64, + outer_symbol, + sorted_fields, + stmt, + )?; + } + } + + RecordDestructure(_, _) => { + unreachable!("a record destructure must always occur on a struct layout"); + } + + Shadowed(_region, _ident) => { + return Err(&"TODO"); + } + + UnsupportedPattern(_region) => { + return Err(&"TODO"); + } + } + + Ok(stmt) +} + +#[allow(clippy::too_many_arguments)] +fn store_record_destruct<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + destruct: &RecordDestruct<'a>, + index: u64, + outer_symbol: Symbol, + sorted_fields: &'a [Layout<'a>], + mut stmt: Stmt<'a>, +) -> Result, &'a str> { + use Pattern::*; + + let load = Expr::AccessAtIndex { + index, + field_layouts: sorted_fields, + structure: outer_symbol, + is_unwrapped: true, + }; + + match &destruct.typ { + DestructType::Required => { + stmt = Stmt::Let( + destruct.symbol, + load, + destruct.layout.clone(), + env.arena.alloc(stmt), + ); + } + DestructType::Optional(_expr) => { + todo!("TODO monomorphize optional field destructure's default expr"); + } + DestructType::Guard(guard_pattern) => match &guard_pattern { + Identifier(symbol) => { + stmt = Stmt::Let( + *symbol, + load, + destruct.layout.clone(), + env.arena.alloc(stmt), + ); + } + Underscore => { + // important that this is special-cased to do nothing: mono record patterns will extract all the + // fields, but those not bound in the source code are guarded with the underscore + // pattern. So given some record `{ x : a, y : b }`, a match + // + // { x } -> ... + // + // is actually + // + // { x, y: _ } -> ... + // + // internally. But `y` is never used, so we must make sure it't not stored/loaded. + } + IntLiteral(_) + | FloatLiteral(_) + | EnumLiteral { .. } + | BitLiteral { .. } + | StrLiteral(_) => {} + + _ => { + let symbol = env.unique_symbol(); + + stmt = store_pattern(env, procs, layout_cache, guard_pattern, symbol, stmt)?; + + stmt = Stmt::Let(symbol, load, destruct.layout.clone(), env.arena.alloc(stmt)); + } + }, + } + + Ok(stmt) +} + +#[allow(clippy::too_many_arguments)] +fn call_by_name<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + fn_var: Variable, + ret_var: Variable, + proc_name: Symbol, + loc_args: std::vec::Vec<(Variable, Located)>, + layout_cache: &mut LayoutCache<'a>, + assigned: Symbol, + hole: &'a Stmt<'a>, +) -> Stmt<'a> { + // Register a pending_specialization for this function + match layout_cache.from_var(env.arena, fn_var, env.subs) { + Ok(layout) => { + // Build the CallByName node + let arena = env.arena; + let mut pattern_vars = Vec::with_capacity_in(loc_args.len(), arena); + + let mut field_symbols = Vec::with_capacity_in(loc_args.len(), env.arena); + + for (_, arg_expr) in loc_args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg_expr.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } + } + let field_symbols = field_symbols.into_bump_slice(); + + for (var, _) in &loc_args { + match layout_cache.from_var(&env.arena, *var, &env.subs) { + Ok(_) => { + pattern_vars.push(*var); + } + Err(_) => { + // One of this function's arguments code gens to a runtime error, + // so attempting to call it will immediately crash. + return Stmt::RuntimeError("TODO runtime error for invalid layout"); + } + } + } + + // TODO does this work? + let empty = &[] as &[_]; + let (arg_layouts, layout) = if let Layout::FunctionPointer(args, rlayout) = layout { + (args, rlayout) + } else { + (empty, &layout) + }; + + // If we've already specialized this one, no further work is needed. + if procs.specialized.contains_key(&(proc_name, layout.clone())) { + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + layout: layout.clone(), + arg_layouts, + args: field_symbols, + }; + + let mut result = Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in + loc_args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result + } else { + let pending = PendingSpecialization { + pattern_vars, + ret_var, + fn_var, + }; + + // When requested (that is, when procs.pending_specializations is `Some`), + // store a pending specialization rather than specializing immediately. + // + // We do this so that we can do specialization in two passes: first, + // build the mono_expr with all the specialized calls in place (but + // no specializations performed yet), and then second, *after* + // de-duplicating requested specializations (since multiple modules + // which could be getting monomorphized in parallel might request + // the same specialization independently), we work through the + // queue of pending specializations to complete each specialization + // exactly once. + match &mut procs.pending_specializations { + Some(pending_specializations) => { + // register the pending specialization, so this gets code genned later + add_pending(pending_specializations, proc_name, layout.clone(), pending); + + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + layout: layout.clone(), + arg_layouts, + args: field_symbols, + }; + + let mut result = Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in + loc_args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result + } + None => { + let opt_partial_proc = procs.partial_procs.get(&proc_name); + + match opt_partial_proc { + Some(partial_proc) => { + // TODO should pending_procs hold a Rc to avoid this .clone()? + let partial_proc = partial_proc.clone(); + + // Mark this proc as in-progress, so if we're dealing with + // mutually recursive functions, we don't loop forever. + // (We had a bug around this before this system existed!) + procs + .specialized + .insert((proc_name, layout.clone()), InProgress); + + match specialize( + env, + procs, + proc_name, + layout_cache, + pending, + partial_proc, + ) { + Ok(proc) => { + procs + .specialized + .insert((proc_name, layout.clone()), Done(proc)); + + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + layout: layout.clone(), + arg_layouts, + args: field_symbols, + }; + + let mut result = + Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in loc_args + .into_iter() + .rev() + .zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result + } + Err(error) => { + let error_msg = env.arena.alloc(format!( + "TODO generate a RuntimeError message for {:?}", + error + )); + + procs.runtime_errors.insert(proc_name, error_msg); + + Stmt::RuntimeError(error_msg) + } + } + } + + None => { + // This must have been a runtime error. + match procs.runtime_errors.get(&proc_name) { + Some(error) => Stmt::RuntimeError(error), + None => unreachable!("Proc name {:?} is invalid", proc_name), + } + } + } + } + } + } + } + Err(_) => { + // This function code gens to a runtime error, + // so attempting to call it will immediately crash. + Stmt::RuntimeError("") + } + } +} + +/// A pattern, including possible problems (e.g. shadowing) so that +/// codegen can generate a runtime error if this pattern is reached. +#[derive(Clone, Debug, PartialEq)] +pub enum Pattern<'a> { + Identifier(Symbol), + Underscore, + + IntLiteral(i64), + FloatLiteral(u64), + BitLiteral { + value: bool, + tag_name: TagName, + union: crate::exhaustive::Union, + }, + EnumLiteral { + tag_id: u8, + tag_name: TagName, + union: crate::exhaustive::Union, + }, + StrLiteral(Box), + + RecordDestructure(Vec<'a, RecordDestruct<'a>>, Layout<'a>), + AppliedTag { + tag_name: TagName, + tag_id: u8, + arguments: Vec<'a, (Pattern<'a>, Layout<'a>)>, + layout: Layout<'a>, + union: crate::exhaustive::Union, + }, + + // Runtime Exceptions + Shadowed(Region, Located), + // Example: (5 = 1 + 2) is an unsupported pattern in an assignment; Int patterns aren't allowed in assignments! + UnsupportedPattern(Region), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct RecordDestruct<'a> { + pub label: Lowercase, + pub layout: Layout<'a>, + pub symbol: Symbol, + pub typ: DestructType<'a>, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum DestructType<'a> { + Required, + Optional(roc_can::expr::Expr), + Guard(Pattern<'a>), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct WhenBranch<'a> { + pub patterns: Vec<'a, Pattern<'a>>, + pub value: Expr<'a>, + pub guard: Option>, +} + +pub fn from_can_pattern<'a>( + env: &mut Env<'a, '_>, + layout_cache: &mut LayoutCache<'a>, + can_pattern: &roc_can::pattern::Pattern, +) -> Pattern<'a> { + use roc_can::pattern::Pattern::*; + match can_pattern { + Underscore => Pattern::Underscore, + Identifier(symbol) => Pattern::Identifier(*symbol), + IntLiteral(v) => Pattern::IntLiteral(*v), + FloatLiteral(v) => Pattern::FloatLiteral(f64::to_bits(*v)), + StrLiteral(v) => Pattern::StrLiteral(v.clone()), + Shadowed(region, ident) => Pattern::Shadowed(*region, ident.clone()), + UnsupportedPattern(region) => Pattern::UnsupportedPattern(*region), + MalformedPattern(_problem, region) => { + // TODO preserve malformed problem information here? + Pattern::UnsupportedPattern(*region) + } + NumLiteral(var, num) => match num_argument_to_int_or_float(env.subs, *var) { + IntOrFloat::IntType => Pattern::IntLiteral(*num), + IntOrFloat::FloatType => Pattern::FloatLiteral(*num as u64), + }, + + AppliedTag { + whole_var, + tag_name, + arguments, + .. + } => { + use crate::exhaustive::Union; + use crate::layout::UnionVariant::*; + + let variant = crate::layout::union_sorted_tags(env.arena, *whole_var, env.subs); + + match variant { + Never => unreachable!("there is no pattern of type `[]`"), + Unit => Pattern::EnumLiteral { + tag_id: 0, + tag_name: tag_name.clone(), + union: Union { + render_as: RenderAs::Tag, + alternatives: vec![Ctor { + tag_id: TagId(0), + name: tag_name.clone(), + arity: 0, + }], + }, + }, + BoolUnion { ttrue, ffalse } => Pattern::BitLiteral { + value: tag_name == &ttrue, + tag_name: tag_name.clone(), + union: Union { + render_as: RenderAs::Tag, + alternatives: vec![ + Ctor { + tag_id: TagId(0), + name: ffalse, + arity: 0, + }, + Ctor { + tag_id: TagId(1), + name: ttrue, + arity: 0, + }, + ], + }, + }, + ByteUnion(tag_names) => { + let tag_id = tag_names + .iter() + .position(|key| key == tag_name) + .expect("tag must be in its own type"); + + let mut ctors = std::vec::Vec::with_capacity(tag_names.len()); + for (i, tag_name) in tag_names.iter().enumerate() { + ctors.push(Ctor { + tag_id: TagId(i as u8), + name: tag_name.clone(), + arity: 0, + }) + } + + let union = crate::exhaustive::Union { + render_as: RenderAs::Tag, + alternatives: ctors, + }; + + Pattern::EnumLiteral { + tag_id: tag_id as u8, + tag_name: tag_name.clone(), + union, + } + } + Unwrapped(field_layouts) => { + let union = crate::exhaustive::Union { + render_as: RenderAs::Tag, + alternatives: vec![Ctor { + tag_id: TagId(0), + name: tag_name.clone(), + arity: field_layouts.len(), + }], + }; + + let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); + for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) { + mono_args.push(( + from_can_pattern(env, layout_cache, &loc_pat.value), + layout.clone(), + )); + } + + let layout = Layout::Struct(field_layouts.into_bump_slice()); + + Pattern::AppliedTag { + tag_name: tag_name.clone(), + tag_id: 0, + arguments: mono_args, + union, + layout, + } + } + Wrapped(tags) => { + let mut ctors = std::vec::Vec::with_capacity(tags.len()); + for (i, (tag_name, args)) in tags.iter().enumerate() { + ctors.push(Ctor { + tag_id: TagId(i as u8), + name: tag_name.clone(), + // don't include tag discriminant in arity + arity: args.len() - 1, + }) + } + + let union = crate::exhaustive::Union { + render_as: RenderAs::Tag, + alternatives: ctors, + }; + + let (tag_id, (_, argument_layouts)) = tags + .iter() + .enumerate() + .find(|(_, (key, _))| key == tag_name) + .expect("tag must be in its own type"); + + let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); + // disregard the tag discriminant layout + let it = argument_layouts[1..].iter(); + for ((_, loc_pat), layout) in arguments.iter().zip(it) { + mono_args.push(( + from_can_pattern(env, layout_cache, &loc_pat.value), + layout.clone(), + )); + } + + let mut layouts: Vec<&'a [Layout<'a>]> = + Vec::with_capacity_in(tags.len(), env.arena); + + for (_, arg_layouts) in tags.into_iter() { + layouts.push(arg_layouts); + } + + let layout = Layout::Union(layouts.into_bump_slice()); + + Pattern::AppliedTag { + tag_name: tag_name.clone(), + tag_id: tag_id as u8, + arguments: mono_args, + union, + layout, + } + } + } + } + + RecordDestructure { + whole_var, + destructs, + .. + } => { + let mut mono_destructs = Vec::with_capacity_in(destructs.len(), env.arena); + let mut destructs = destructs.clone(); + destructs.sort_by(|a, b| a.value.label.cmp(&b.value.label)); + + let mut it = destructs.iter(); + let mut opt_destruct = it.next(); + + let sorted_fields = crate::layout::sort_record_fields(env.arena, *whole_var, env.subs); + + let mut field_layouts = Vec::with_capacity_in(sorted_fields.len(), env.arena); + + for (label, field_layout) in sorted_fields.into_iter() { + if let Some(destruct) = opt_destruct { + if destruct.value.label == label { + opt_destruct = it.next(); + + mono_destructs.push(from_can_record_destruct( + env, + layout_cache, + &destruct.value, + field_layout.clone(), + )); + } else { + // insert underscore pattern + mono_destructs.push(RecordDestruct { + label: label.clone(), + symbol: env.unique_symbol(), + layout: field_layout.clone(), + typ: DestructType::Guard(Pattern::Underscore), + }); + } + } else { + // insert underscore pattern + mono_destructs.push(RecordDestruct { + label: label.clone(), + symbol: env.unique_symbol(), + layout: field_layout.clone(), + typ: DestructType::Guard(Pattern::Underscore), + }); + } + field_layouts.push(field_layout); + } + + Pattern::RecordDestructure( + mono_destructs, + Layout::Struct(field_layouts.into_bump_slice()), + ) + } + } +} + +fn from_can_record_destruct<'a>( + env: &mut Env<'a, '_>, + layout_cache: &mut LayoutCache<'a>, + can_rd: &roc_can::pattern::RecordDestruct, + field_layout: Layout<'a>, +) -> RecordDestruct<'a> { + RecordDestruct { + label: can_rd.label.clone(), + symbol: can_rd.symbol, + layout: field_layout, + typ: match &can_rd.typ { + roc_can::pattern::DestructType::Required => DestructType::Required, + roc_can::pattern::DestructType::Optional(_, loc_expr) => { + DestructType::Optional(loc_expr.value.clone()) + } + roc_can::pattern::DestructType::Guard(_, loc_pattern) => { + DestructType::Guard(from_can_pattern(env, layout_cache, &loc_pattern.value)) + } + }, + } +} + +/// Potentially translate LowLevel operations into more efficient ones based on +/// uniqueness type info. +/// +/// For example, turning LowLevel::ListSet to LowLevel::ListSetInPlace if the +/// list is Unique. +fn optimize_low_level( + subs: &Subs, + op: LowLevel, + args: &[(Variable, roc_can::expr::Expr)], +) -> LowLevel { + match op { + LowLevel::ListSet => { + // The first arg is the one with the List in it. + // List.set : List elem, Int, elem -> List elem + let list_arg_var = args[0].0; + let content = subs.get_without_compacting(list_arg_var).content; + + match content { + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, attr_args)) => { + debug_assert_eq!(attr_args.len(), 2); + + // If the first argument (the List) is unique, + // then we can safely upgrade to List.set_in_place + let attr_arg_content = subs.get_without_compacting(attr_args[0]).content; + + if attr_arg_content.is_unique(subs) { + LowLevel::ListSetInPlace + } else { + LowLevel::ListSet + } + } + _ => op, + } + } + _ => op, + } +} + +pub enum IntOrFloat { + IntType, + FloatType, +} + +/// Given the `a` in `Num a`, determines whether it's an int or a float +pub fn num_argument_to_int_or_float(subs: &Subs, var: Variable) -> IntOrFloat { + match subs.get_without_compacting(var).content { + Content::Alias(Symbol::NUM_INTEGER, args, _) => { + debug_assert!(args.is_empty()); + IntOrFloat::IntType + } + Content::FlexVar(_) => { + // If this was still a (Num *), assume compiling it to an Int + IntOrFloat::IntType + } + Content::Alias(Symbol::NUM_FLOATINGPOINT, args, _) => { + debug_assert!(args.is_empty()); + IntOrFloat::FloatType + } + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, attr_args)) => { + debug_assert!(attr_args.len() == 2); + + // Recurse on the second argument + num_argument_to_int_or_float(subs, attr_args[1]) + } + other => { + panic!( + "Unrecognized Num type argument for var {:?} with Content: {:?}", + var, other + ); + } + } +} diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 6c997f2219..8ca6734b8f 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -27,6 +27,12 @@ pub enum Layout<'a> { Pointer(&'a Layout<'a>), } +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] +pub enum MemoryMode { + Unique, + Refcounted, +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Builtin<'a> { Int128, @@ -42,7 +48,7 @@ pub enum Builtin<'a> { Str, Map(&'a Layout<'a>, &'a Layout<'a>), Set(&'a Layout<'a>), - List(&'a Layout<'a>), + List(MemoryMode, &'a Layout<'a>), EmptyStr, EmptyList, EmptyMap, @@ -136,6 +142,31 @@ impl<'a> Layout<'a> { Pointer(_) => pointer_size, } } + + pub fn is_refcounted(&self) -> bool { + match self { + Layout::Builtin(Builtin::List(_, _)) => true, + _ => false, + } + } + + /// Even if a value (say, a record) is not itself reference counted, + /// it may contains values/fields that are. Therefore when this record + /// goes out of scope, the refcount on those values/fields must be decremented. + pub fn contains_refcounted(&self) -> bool { + use Layout::*; + + match self { + Builtin(builtin) => builtin.is_refcounted(), + Struct(fields) => fields.iter().any(|f| f.is_refcounted()), + Union(fields) => fields + .iter() + .map(|ls| ls.iter()) + .flatten() + .any(|f| f.is_refcounted()), + FunctionPointer(_, _) | Pointer(_) => false, + } + } } /// Avoid recomputing Layout from Variable multiple times. @@ -210,7 +241,7 @@ impl<'a> Builtin<'a> { Str | EmptyStr => Builtin::STR_WORDS * pointer_size, Map(_, _) | EmptyMap => Builtin::MAP_WORDS * pointer_size, Set(_) | EmptySet => Builtin::SET_WORDS * pointer_size, - List(_) | EmptyList => Builtin::LIST_WORDS * pointer_size, + List(_, _) | EmptyList => Builtin::LIST_WORDS * pointer_size, } } @@ -220,7 +251,23 @@ impl<'a> Builtin<'a> { match self { Int128 | Int64 | Int32 | Int16 | Int8 | Int1 | Float128 | Float64 | Float32 | Float16 | EmptyStr | EmptyMap | EmptyList | EmptySet => true, - Str | Map(_, _) | Set(_) | List(_) => false, + Str | Map(_, _) | Set(_) | List(_, _) => false, + } + } + + // Question: does is_refcounted exactly correspond with the "safe to memcpy" property? + pub fn is_refcounted(&self) -> bool { + use Builtin::*; + + match self { + Int128 | Int64 | Int32 | Int16 | Int8 | Int1 | Float128 | Float64 | Float32 + | Float16 | EmptyStr | EmptyMap | EmptyList | EmptySet => false, + List(mode, element_layout) => match mode { + MemoryMode::Refcounted => true, + MemoryMode::Unique => element_layout.contains_refcounted(), + }, + + Str | Map(_, _) | Set(_) => true, } } } @@ -258,13 +305,26 @@ fn layout_from_flat_type<'a>( debug_assert_eq!(args.len(), 2); // The first argument is the uniqueness info; - // that doesn't affect layout, so we don't need it here. + // second is the base type let wrapped_var = args[1]; - // For now, layout is unaffected by uniqueness. - // (Incorporating refcounting may change this.) - // Unwrap and continue - Layout::from_var(arena, wrapped_var, subs) + // correct the memory mode of unique lists + match Layout::from_var(arena, wrapped_var, subs)? { + Layout::Builtin(Builtin::List(_, elem_layout)) => { + let uniqueness_var = args[0]; + let uniqueness_content = + subs.get_without_compacting(uniqueness_var).content; + + let mode = if uniqueness_content.is_unique(subs) { + MemoryMode::Unique + } else { + MemoryMode::Refcounted + }; + + Ok(Layout::Builtin(Builtin::List(mode, elem_layout))) + } + other => Ok(other), + } } _ => { panic!("TODO layout_from_flat_type for {:?}", Apply(symbol, args)); @@ -336,8 +396,8 @@ fn layout_from_flat_type<'a>( Ok(layout_from_tag_union(arena, tags, subs)) } - RecursiveTagUnion(_, _, _) => { - panic!("TODO make Layout for non-empty Tag Union"); + RecursiveTagUnion(_rec_var, _tags, _ext_var) => { + panic!("TODO make Layout for empty RecursiveTagUnion"); } EmptyTagUnion => { panic!("TODO make Layout for empty Tag Union"); @@ -656,15 +716,15 @@ fn unwrap_num_tag<'a>(subs: &Subs, var: Variable) -> Result, LayoutPr pub fn list_layout_from_elem<'a>( arena: &'a Bump, subs: &Subs, - var: Variable, + elem_var: Variable, ) -> Result, LayoutProblem> { - match subs.get_without_compacting(var).content { + match subs.get_without_compacting(elem_var).content { Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => { debug_assert_eq!(args.len(), 2); - let arg_var = args.get(1).unwrap(); + let var = *args.get(1).unwrap(); - list_layout_from_elem(arena, subs, *arg_var) + list_layout_from_elem(arena, subs, var) } Content::FlexVar(_) | Content::RigidVar(_) => { // If this was still a (List *) then it must have been an empty list @@ -674,7 +734,28 @@ pub fn list_layout_from_elem<'a>( let elem_layout = Layout::new(arena, content, subs)?; // This is a normal list. - Ok(Layout::Builtin(Builtin::List(arena.alloc(elem_layout)))) + Ok(Layout::Builtin(Builtin::List( + MemoryMode::Refcounted, + arena.alloc(elem_layout), + ))) } } } + +pub fn mode_from_var(var: Variable, subs: &Subs) -> MemoryMode { + match subs.get_without_compacting(var).content { + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => { + debug_assert_eq!(args.len(), 2); + + let uvar = *args.get(0).unwrap(); + let content = subs.get_without_compacting(uvar).content; + + if content.is_unique(subs) { + MemoryMode::Unique + } else { + MemoryMode::Refcounted + } + } + _ => MemoryMode::Refcounted, + } +} diff --git a/compiler/mono/src/lib.rs b/compiler/mono/src/lib.rs index 8b78927d46..22abce575d 100644 --- a/compiler/mono/src/lib.rs +++ b/compiler/mono/src/lib.rs @@ -11,12 +11,15 @@ // re-enable this when working on performance optimizations than have it block PRs. #![allow(clippy::large_enum_variant)] -pub mod expr; +pub mod inc_dec; +pub mod ir; pub mod layout; // Temporary, while we can build up test cases and optimize the exhaustiveness checking. // For now, following this warning's advice will lead to nasty type inference errors. +//#[allow(clippy::ptr_arg)] +//pub mod decision_tree; #[allow(clippy::ptr_arg)] pub mod decision_tree; #[allow(clippy::ptr_arg)] -pub mod pattern; +pub mod exhaustive; diff --git a/compiler/mono/src/reset_reuse.rs b/compiler/mono/src/reset_reuse.rs new file mode 100644 index 0000000000..ad2c2ef45d --- /dev/null +++ b/compiler/mono/src/reset_reuse.rs @@ -0,0 +1,648 @@ +use crate::expr::Env; +use crate::expr::Expr; + +use bumpalo::collections::Vec; +use roc_collections::all::MutSet; +use roc_module::symbol::Symbol; + +pub fn function_r<'a>(env: &mut Env<'a, '_>, body: &'a Expr<'a>) -> Expr<'a> { + use Expr::*; + + match body { + Switch { + cond_symbol, + branches, + cond, + cond_layout, + default_branch, + ret_layout, + } => { + let stack_size = cond_layout.stack_size(env.pointer_size); + let mut new_branches = Vec::with_capacity_in(branches.len(), env.arena); + + for (tag, stores, branch) in branches.iter() { + let new_branch = function_d(env, *cond_symbol, stack_size as _, branch); + + new_branches.push((*tag, *stores, new_branch)); + } + + let new_default_branch = ( + default_branch.0, + &*env.arena.alloc(function_d( + env, + *cond_symbol, + stack_size as _, + default_branch.1, + )), + ); + + Switch { + cond_symbol: *cond_symbol, + branches: new_branches.into_bump_slice(), + default_branch: new_default_branch, + ret_layout: ret_layout.clone(), + cond: *cond, + cond_layout: cond_layout.clone(), + } + } + Cond { + cond_symbol, + cond_layout, + branching_symbol, + branching_layout, + pass, + fail, + ret_layout, + } => { + let stack_size = cond_layout.stack_size(env.pointer_size); + + let new_pass = ( + pass.0, + &*env + .arena + .alloc(function_d(env, *cond_symbol, stack_size as _, pass.1)), + ); + + let new_fail = ( + fail.0, + &*env + .arena + .alloc(function_d(env, *cond_symbol, stack_size as _, fail.1)), + ); + + Cond { + cond_symbol: *cond_symbol, + cond_layout: cond_layout.clone(), + branching_symbol: *branching_symbol, + branching_layout: branching_layout.clone(), + ret_layout: ret_layout.clone(), + pass: new_pass, + fail: new_fail, + } + } + + Store(stores, body) => { + let new_body = function_r(env, body); + + Store(stores, env.arena.alloc(new_body)) + } + + DecAfter(symbol, body) => { + let new_body = function_r(env, body); + + DecAfter(*symbol, env.arena.alloc(new_body)) + } + + CallByName { .. } + | CallByPointer(_, _, _) + | RunLowLevel(_, _) + | Tag { .. } + | Struct(_) + | Array { .. } + | AccessAtIndex { .. } => { + // TODO + // how often are `when` expressions in one of the above? + body.clone() + } + + Int(_) + | Float(_) + | Str(_) + | Bool(_) + | Byte(_) + | Load(_) + | EmptyArray + | Inc(_, _) + | FunctionPointer(_, _) + | RuntimeError(_) + | RuntimeErrorFunction(_) => body.clone(), + + Reset(_, _) | Reuse(_, _) => unreachable!("reset/reuse should not have been inserted yet!"), + } +} + +fn function_d<'a>( + env: &mut Env<'a, '_>, + z: Symbol, + stack_size: usize, + body: &'a Expr<'a>, +) -> Expr<'a> { + let symbols = symbols_in_expr(body); + if symbols.contains(&z) { + return body.clone(); + } + + if let Ok(reused) = function_s(env, z, stack_size, body) { + Expr::Reset(z, env.arena.alloc(reused)) + } else { + body.clone() + } + /* + match body { + Expr::Tag { .. } => Some(env.arena.alloc(Expr::Reuse(w, body))), + _ => None, + } + */ +} + +fn function_s<'a>( + env: &mut Env<'a, '_>, + w: Symbol, + stack_size: usize, + body: &'a Expr<'a>, +) -> Result<&'a Expr<'a>, &'a Expr<'a>> { + use Expr::*; + + match body { + Tag { tag_layout, .. } => { + if tag_layout.stack_size(env.pointer_size) as usize <= stack_size { + Ok(env.arena.alloc(Expr::Reuse(w, body))) + } else { + Err(body) + } + } + + Array { .. } | Struct(_) => { + // TODO + Err(body) + } + + Switch { + cond_symbol, + branches, + cond, + cond_layout, + default_branch, + ret_layout, + } => { + // we can re-use `w` in each branch + let mut has_been_reused = false; + let mut new_branches = Vec::with_capacity_in(branches.len(), env.arena); + for (tag, stores, branch) in branches.iter() { + match function_s(env, *cond_symbol, stack_size as _, branch) { + Ok(new_branch) => { + has_been_reused = true; + new_branches.push((*tag, *stores, new_branch.clone())); + } + Err(new_branch) => { + new_branches.push((*tag, *stores, new_branch.clone())); + } + } + } + + let new_default_branch = ( + default_branch.0, + match function_s(env, *cond_symbol, stack_size, default_branch.1) { + Ok(new) => { + has_been_reused = true; + new + } + Err(new) => new, + }, + ); + let result = env.arena.alloc(Switch { + cond_symbol: *cond_symbol, + branches: new_branches.into_bump_slice(), + default_branch: new_default_branch, + ret_layout: ret_layout.clone(), + cond: *cond, + cond_layout: cond_layout.clone(), + }); + + if has_been_reused { + Ok(result) + } else { + Err(result) + } + } + + Cond { + cond_symbol, + cond_layout, + branching_symbol, + branching_layout, + pass, + fail, + ret_layout, + } => { + let mut has_been_reused = false; + let new_pass = ( + pass.0, + match function_s(env, *cond_symbol, stack_size, pass.1) { + Ok(new) => { + has_been_reused = true; + new + } + Err(new) => new, + }, + ); + + let new_fail = ( + fail.0, + match function_s(env, *cond_symbol, stack_size, fail.1) { + Ok(new) => { + has_been_reused = true; + new + } + Err(new) => new, + }, + ); + + let result = env.arena.alloc(Cond { + cond_symbol: *cond_symbol, + cond_layout: cond_layout.clone(), + branching_symbol: *branching_symbol, + branching_layout: branching_layout.clone(), + ret_layout: ret_layout.clone(), + pass: new_pass, + fail: new_fail, + }); + + if has_been_reused { + Ok(result) + } else { + Err(result) + } + } + + Store(stores, expr) => { + let new_expr = function_s(env, w, stack_size, expr)?; + + Ok(env.arena.alloc(Store(*stores, new_expr))) + } + + DecAfter(symbol, expr) => { + let new_expr = function_s(env, w, stack_size, expr)?; + + Ok(env.arena.alloc(DecAfter(*symbol, new_expr))) + } + + CallByName { .. } | CallByPointer(_, _, _) | RunLowLevel(_, _) | AccessAtIndex { .. } => { + // TODO + // how often are `Tag` expressions in one of the above? + Err(body) + } + + Int(_) + | Float(_) + | Str(_) + | Bool(_) + | Byte(_) + | Load(_) + | EmptyArray + | Inc(_, _) + | FunctionPointer(_, _) + | RuntimeError(_) + | RuntimeErrorFunction(_) => Err(body), + + Reset(_, _) | Reuse(_, _) => { + unreachable!("reset/reuse should not have been introduced yet") + } + } +} + +fn free_variables<'a>(initial: &Expr<'a>) -> MutSet { + use Expr::*; + let mut seen = MutSet::default(); + let mut bound = MutSet::default(); + let mut stack = vec![initial]; + + // in other words, variables that are referenced, but not stored + + while let Some(expr) = stack.pop() { + match expr { + FunctionPointer(symbol, _) | Load(symbol) => { + seen.insert(*symbol); + } + Reset(symbol, expr) | Reuse(symbol, expr) => { + seen.insert(*symbol); + stack.push(expr) + } + + Cond { + cond_symbol, + branching_symbol, + pass, + fail, + .. + } => { + seen.insert(*cond_symbol); + seen.insert(*branching_symbol); + + for (symbol, _, expr) in pass.0.iter() { + seen.insert(*symbol); + stack.push(expr) + } + + for (symbol, _, expr) in fail.0.iter() { + seen.insert(*symbol); + stack.push(expr) + } + } + + Switch { + cond, + cond_symbol, + branches, + default_branch, + .. + } => { + stack.push(cond); + seen.insert(*cond_symbol); + + for (_, stores, expr) in branches.iter() { + stack.push(expr); + + for (symbol, _, expr) in stores.iter() { + bound.insert(*symbol); + stack.push(expr) + } + } + + stack.push(default_branch.1); + for (symbol, _, expr) in default_branch.0.iter() { + seen.insert(*symbol); + stack.push(expr) + } + } + + Store(stores, body) => { + for (symbol, _, expr) in stores.iter() { + bound.insert(*symbol); + stack.push(&expr) + } + + stack.push(body) + } + + DecAfter(symbol, body) | Inc(symbol, body) => { + seen.insert(*symbol); + stack.push(body); + } + + CallByName { name, args, .. } => { + seen.insert(*name); + for (expr, _) in args.iter() { + stack.push(expr); + } + } + + CallByPointer(function, args, _) => { + stack.push(function); + stack.extend(args.iter()); + } + + RunLowLevel(_, args) => { + for (expr, _) in args.iter() { + stack.push(expr); + } + } + + Tag { arguments, .. } => { + for (symbol, _) in arguments.iter() { + seen.insert(*symbol); + } + } + + Struct(arguments) => { + for (expr, _) in arguments.iter() { + stack.push(expr); + } + } + + Array { elems, .. } => { + for expr in elems.iter() { + stack.push(expr); + } + } + + AccessAtIndex { expr, .. } => { + stack.push(expr); + } + + Int(_) + | Float(_) + | Str(_) + | Bool(_) + | Byte(_) + | EmptyArray + | RuntimeError(_) + | RuntimeErrorFunction(_) => {} + } + } + + for symbol in bound.iter() { + seen.remove(symbol); + } + + seen +} + +fn symbols_in_expr<'a>(initial: &Expr<'a>) -> MutSet { + use Expr::*; + let mut result = MutSet::default(); + let mut stack = vec![initial]; + + while let Some(expr) = stack.pop() { + match expr { + FunctionPointer(symbol, _) | Load(symbol) => { + result.insert(*symbol); + } + + Reset(symbol, expr) | Reuse(symbol, expr) => { + result.insert(*symbol); + stack.push(expr) + } + + Cond { + cond_symbol, + branching_symbol, + pass, + fail, + .. + } => { + result.insert(*cond_symbol); + result.insert(*branching_symbol); + + for (symbol, _, expr) in pass.0.iter() { + result.insert(*symbol); + stack.push(expr) + } + + for (symbol, _, expr) in fail.0.iter() { + result.insert(*symbol); + stack.push(expr) + } + } + + Switch { + cond, + cond_symbol, + branches, + default_branch, + .. + } => { + stack.push(cond); + result.insert(*cond_symbol); + + for (_, stores, expr) in branches.iter() { + stack.push(expr); + + for (symbol, _, expr) in stores.iter() { + result.insert(*symbol); + stack.push(expr) + } + } + + stack.push(default_branch.1); + for (symbol, _, expr) in default_branch.0.iter() { + result.insert(*symbol); + stack.push(expr) + } + } + + Store(stores, body) => { + for (symbol, _, expr) in stores.iter() { + result.insert(*symbol); + stack.push(&expr) + } + + stack.push(body) + } + + DecAfter(symbol, body) | Inc(symbol, body) => { + result.insert(*symbol); + stack.push(body); + } + + CallByName { name, args, .. } => { + result.insert(*name); + for (expr, _) in args.iter() { + stack.push(expr); + } + } + + CallByPointer(function, args, _) => { + stack.push(function); + stack.extend(args.iter()); + } + + RunLowLevel(_, args) => { + for (expr, _) in args.iter() { + stack.push(expr); + } + } + + Tag { arguments, .. } => { + for (symbol, _) in arguments.iter() { + result.insert(*symbol); + } + } + + Struct(arguments) => { + for (expr, _) in arguments.iter() { + stack.push(expr); + } + } + + Array { elems, .. } => { + for expr in elems.iter() { + stack.push(expr); + } + } + + AccessAtIndex { expr, .. } => { + stack.push(expr); + } + + Int(_) + | Float(_) + | Str(_) + | Bool(_) + | Byte(_) + | EmptyArray + | RuntimeError(_) + | RuntimeErrorFunction(_) => {} + } + } + + result +} + +pub fn function_c<'a>(env: &mut Env<'a, '_>, body: Expr<'a>) -> Expr<'a> { + let fv = free_variables(&body); + + function_c_help(env, body, fv) +} + +pub fn function_c_help<'a>(env: &mut Env<'a, '_>, body: Expr<'a>, fv: MutSet) -> Expr<'a> { + use Expr::*; + + match body { + Tag { arguments, .. } => { + let symbols = arguments + .iter() + .map(|(x, _)| x) + .copied() + .collect::>(); + + function_c_app(env, &symbols, &fv, body) + } + _ => body, + } +} + +fn function_c_app<'a>( + env: &mut Env<'a, '_>, + arguments: &[Symbol], + orig_fv: &MutSet, + mut application: Expr<'a>, +) -> Expr<'a> { + // in the future, this will need to be a check + let is_owned = true; + + for (i, y) in arguments.iter().rev().enumerate() { + if is_owned { + let mut fv = orig_fv.clone(); + fv.extend(arguments[i..].iter().copied()); + + application = insert_increment(env, *y, fv, application) + } else { + unimplemented!("owned references are not implemented yet") + } + } + + application +} + +fn insert_increment<'a>( + env: &mut Env<'a, '_>, + symbol: Symbol, + live_variables: MutSet, + body: Expr<'a>, +) -> Expr<'a> { + // in the future, this will need to be a check + let is_owned = true; + + if is_owned && !live_variables.contains(&symbol) { + body + } else { + Expr::Inc(symbol, env.arena.alloc(body)) + } +} + +fn insert_decrement<'a>(env: &mut Env<'a, '_>, symbols: &[Symbol], mut body: Expr<'a>) -> Expr<'a> { + // in the future, this will need to be a check + let is_owned = true; + let fv = free_variables(&body); + + for symbol in symbols.iter() { + let is_dead = !fv.contains(&symbol); + + if is_owned && is_dead { + body = Expr::DecAfter(*symbol, env.arena.alloc(body)); + } + } + + body +} diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 438f5ec535..4f93ba9716 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -1,6 +1,9 @@ #[macro_use] extern crate pretty_assertions; +#[macro_use] +extern crate indoc; + extern crate bumpalo; extern crate roc_mono; @@ -9,28 +12,22 @@ mod helpers; // Test monomorphization #[cfg(test)] mod test_mono { - use crate::helpers::{can_expr, infer_expr, test_home, CanExprOut}; - use bumpalo::Bump; - use roc_module::symbol::{Interns, Symbol}; - use roc_mono::expr::Expr::{self, *}; - use roc_mono::expr::Procs; - use roc_mono::layout; - use roc_mono::layout::{Builtin, Layout, LayoutCache}; - use roc_types::subs::Subs; - // HELPERS - - const I64_LAYOUT: Layout<'static> = Layout::Builtin(Builtin::Int64); - const F64_LAYOUT: Layout<'static> = Layout::Builtin(Builtin::Float64); - - fn compiles_to(src: &str, expected: Expr<'_>) { - compiles_to_with_interns(src, |_| expected) + // NOTE because the Show instance of module names is different in --release mode, + // these tests would all fail. In the future, when we do interesting optimizations, + // we'll likely want some tests for --release too. + #[cfg(not(debug_assertions))] + fn compiles_to_ir(_src: &str, _expected: &str) { + // just do nothing } - fn compiles_to_with_interns<'a, F>(src: &str, get_expected: F) - where - F: FnOnce(Interns) -> Expr<'a>, - { + #[cfg(debug_assertions)] + fn compiles_to_ir(src: &str, expected: &str) { + use crate::helpers::{can_expr, infer_expr, CanExprOut}; + use bumpalo::Bump; + use roc_mono::layout::LayoutCache; + use roc_types::subs::Subs; + let arena = Bump::new(); let CanExprOut { loc_expr, @@ -47,570 +44,994 @@ mod test_mono { let (_content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); // Compile and add all the Procs before adding main - let mut procs = Procs::default(); + let mut procs = roc_mono::ir::Procs::default(); let mut ident_ids = interns.all_ident_ids.remove(&home).unwrap(); + // Put this module's ident_ids back in the interns + interns.all_ident_ids.insert(home, ident_ids.clone()); + // Populate Procs and Subs, and get the low-level Expr from the canonical Expr let mut mono_problems = Vec::new(); - let mut mono_env = roc_mono::expr::Env { + let mut mono_env = roc_mono::ir::Env { arena: &arena, subs: &mut subs, problems: &mut mono_problems, home, ident_ids: &mut ident_ids, }; - let mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs); - let procs = - roc_mono::expr::specialize_all(&mut mono_env, procs, &mut LayoutCache::default()); + + let mut layout_cache = LayoutCache::default(); + let ir_expr = + roc_mono::ir::from_can(&mut mono_env, loc_expr.value, &mut procs, &mut layout_cache); + + // let mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs); + let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut LayoutCache::default()); + + // apply inc/dec + let stmt = mono_env.arena.alloc(ir_expr); + let ir_expr = roc_mono::inc_dec::visit_declaration(mono_env.arena, stmt); assert_eq!( procs.runtime_errors, roc_collections::all::MutMap::default() ); - // Put this module's ident_ids back in the interns - interns.all_ident_ids.insert(home, ident_ids); + let mut procs_string = procs + .get_specialized_procs(mono_env.arena) + .values() + .map(|proc| proc.to_pretty(200)) + .collect::>(); - assert_eq!(get_expected(interns), mono_expr); + procs_string.push(ir_expr.to_pretty(200)); + + let result = procs_string.join("\n"); + + let the_same = result == expected; + if !the_same { + println!("{}", result); + } + + assert_eq!(result, expected); } #[test] - fn int_literal() { - compiles_to("5", Int(5)); - } - - #[test] - fn float_literal() { - compiles_to("0.5", Float(0.5)); - } - - #[test] - fn float_addition() { - compiles_to( - "3.0 + 4", - CallByName { - name: Symbol::NUM_ADD, - layout: Layout::FunctionPointer( - &[ - Layout::Builtin(Builtin::Float64), - Layout::Builtin(Builtin::Float64), - ], - &Layout::Builtin(Builtin::Float64), - ), - args: &[ - (Float(3.0), Layout::Builtin(Builtin::Float64)), - (Float(4.0), Layout::Builtin(Builtin::Float64)), - ], - }, - ); - } - - #[test] - fn int_addition() { - compiles_to( - "0xDEADBEEF + 4", - CallByName { - name: Symbol::NUM_ADD, - layout: Layout::FunctionPointer( - &[ - Layout::Builtin(Builtin::Int64), - Layout::Builtin(Builtin::Int64), - ], - &Layout::Builtin(Builtin::Int64), - ), - args: &[ - (Int(3735928559), Layout::Builtin(Builtin::Int64)), - (Int(4), Layout::Builtin(Builtin::Int64)), - ], - }, - ); - } - - #[test] - fn num_addition() { - // Default to Int for `Num *` - compiles_to( - "3 + 5", - CallByName { - name: Symbol::NUM_ADD, - layout: Layout::FunctionPointer( - &[ - Layout::Builtin(Builtin::Int64), - Layout::Builtin(Builtin::Int64), - ], - &Layout::Builtin(Builtin::Int64), - ), - args: &[ - (Int(3), Layout::Builtin(Builtin::Int64)), - (Int(5), Layout::Builtin(Builtin::Int64)), - ], - }, - ); - } - - #[test] - fn specialize_closure() { - compiles_to( + fn ir_int_literal() { + compiles_to_ir( r#" - f = \x -> x + 5 - - { y: f 3.14, x: f 0x4 } + 5 "#, - { - use self::Builtin::*; - let home = test_home(); - let gen_symbol_0 = Interns::from_index(home, 0); - - Struct(&[ - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[Layout::Builtin(Builtin::Int64)], - &Layout::Builtin(Builtin::Int64), - ), - args: &[(Int(4), Layout::Builtin(Int64))], - }, - Layout::Builtin(Int64), - ), - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[Layout::Builtin(Builtin::Float64)], - &Layout::Builtin(Builtin::Float64), - ), - args: &[(Float(3.14), Layout::Builtin(Float64))], - }, - Layout::Builtin(Float64), - ), - ]) - }, + indoc!( + r#" + let Test.0 = 5i64; + ret Test.0; + "# + ), ) } #[test] - fn if_expression() { - compiles_to( + fn ir_assignment() { + compiles_to_ir( r#" - if True then "bar" else "foo" + x = 5 + + x "#, - { - use self::Builtin::*; - use Layout::Builtin; - - let home = test_home(); - let gen_symbol_0 = Interns::from_index(home, 0); - - Store( - &[( - gen_symbol_0, - Layout::Builtin(layout::Builtin::Int1), - Expr::Bool(true), - )], - &Cond { - cond_symbol: gen_symbol_0, - branch_symbol: gen_symbol_0, - cond_layout: Builtin(Int1), - pass: (&[] as &[_], &Expr::Str("bar")), - fail: (&[] as &[_], &Expr::Str("foo")), - ret_layout: Builtin(Str), - }, - ) - }, + indoc!( + r#" + let Test.0 = 5i64; + ret Test.0; + "# + ), ) } #[test] - fn multiway_if_expression() { - compiles_to( + fn ir_if() { + compiles_to_ir( + r#" + if True then 1 else 2 + "#, + indoc!( + r#" + let Test.3 = true; + if Test.3 then + let Test.1 = 1i64; + jump Test.2 Test.1; + else + let Test.1 = 2i64; + jump Test.2 Test.1; + joinpoint Test.2 Test.0: + ret Test.0; + "# + ), + ) + } + + #[test] + fn ir_when_enum() { + compiles_to_ir( + r#" + when Blue is + Red -> 1 + White -> 2 + Blue -> 3 + "#, + indoc!( + r#" + let Test.1 = 0u8; + switch Test.1: + case 1: + let Test.3 = 1i64; + jump Test.2 Test.3; + + case 2: + let Test.4 = 2i64; + jump Test.2 Test.4; + + default: + let Test.5 = 3i64; + jump Test.2 Test.5; + + joinpoint Test.2 Test.0: + ret Test.0; + "# + ), + ) + } + + #[test] + fn ir_when_maybe() { + compiles_to_ir( + r#" + when Just 3 is + Just n -> n + Nothing -> 0 + "#, + indoc!( + r#" + let Test.11 = 0i64; + let Test.12 = 3i64; + let Test.2 = Just Test.11 Test.12; + let Test.7 = true; + let Test.9 = Index 0 Test.2; + let Test.8 = 0i64; + let Test.10 = lowlevel Eq Test.8 Test.9; + let Test.6 = lowlevel And Test.10 Test.7; + if Test.6 then + let Test.0 = Index 1 Test.2; + jump Test.3 Test.0; + else + let Test.5 = 0i64; + jump Test.3 Test.5; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn ir_when_these() { + compiles_to_ir( + r#" + when These 1 2 is + This x -> x + That y -> y + These x _ -> x + "#, + indoc!( + r#" + let Test.9 = 1i64; + let Test.10 = 1i64; + let Test.11 = 2i64; + let Test.4 = These Test.9 Test.10 Test.11; + switch Test.4: + case 2: + let Test.0 = Index 1 Test.4; + jump Test.5 Test.0; + + case 0: + let Test.1 = Index 1 Test.4; + jump Test.5 Test.1; + + default: + let Test.2 = Index 1 Test.4; + jump Test.5 Test.2; + + joinpoint Test.5 Test.3: + ret Test.3; + "# + ), + ) + } + + #[test] + fn ir_when_record() { + compiles_to_ir( + r#" + when { x: 1, y: 3.14 } is + { x } -> x + "#, + indoc!( + r#" + let Test.6 = 1i64; + let Test.7 = 3.14f64; + let Test.2 = Struct {Test.6, Test.7}; + let Test.0 = Index 0 Test.2; + jump Test.3 Test.0; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn ir_plus() { + compiles_to_ir( + r#" + 1 + 2 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.3 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.3; + + let Test.1 = 1i64; + let Test.2 = 2i64; + let Test.0 = CallByName Num.14 Test.1 Test.2; + ret Test.0; + "# + ), + ) + } + + #[test] + fn ir_round() { + compiles_to_ir( + r#" + Num.round 3.6 + "#, + indoc!( + r#" + procedure Num.36 (#Attr.2): + let Test.2 = lowlevel NumRound #Attr.2; + ret Test.2; + + let Test.1 = 3.6f64; + let Test.0 = CallByName Num.36 Test.1; + ret Test.0; + "# + ), + ) + } + + #[test] + fn ir_when_idiv() { + compiles_to_ir( + r#" + when 1000 // 10 is + Ok val -> val + Err _ -> -1 + "#, + indoc!( + r#" + procedure Num.32 (#Attr.2, #Attr.3): + let Test.21 = 0i64; + let Test.18 = lowlevel NotEq #Attr.3 Test.21; + if Test.18 then + let Test.19 = 1i64; + let Test.20 = lowlevel NumDivUnchecked #Attr.2 #Attr.3; + let Test.14 = Ok Test.19 Test.20; + jump Test.15 Test.14; + else + let Test.16 = 0i64; + let Test.17 = Struct {}; + let Test.14 = Err Test.16 Test.17; + jump Test.15 Test.14; + joinpoint Test.15 Test.13: + ret Test.13; + + let Test.11 = 1000i64; + let Test.12 = 10i64; + let Test.2 = CallByName Num.32 Test.11 Test.12; + let Test.7 = true; + let Test.9 = Index 0 Test.2; + let Test.8 = 1i64; + let Test.10 = lowlevel Eq Test.8 Test.9; + let Test.6 = lowlevel And Test.10 Test.7; + if Test.6 then + let Test.0 = Index 1 Test.2; + jump Test.3 Test.0; + else + let Test.5 = -1i64; + jump Test.3 Test.5; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn ir_two_defs() { + compiles_to_ir( + r#" + x = 3 + y = 4 + + x + y + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.3 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.3; + + let Test.1 = 4i64; + let Test.0 = 3i64; + let Test.2 = CallByName Num.14 Test.0 Test.1; + ret Test.2; + "# + ), + ) + } + + #[test] + fn ir_when_just() { + compiles_to_ir( + r#" + x : [ Nothing, Just Int ] + x = Just 41 + + when x is + Just v -> v + 0x1 + Nothing -> 0x1 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.6; + + let Test.13 = 0i64; + let Test.14 = 41i64; + let Test.0 = Just Test.13 Test.14; + let Test.9 = true; + let Test.11 = Index 0 Test.0; + let Test.10 = 0i64; + let Test.12 = lowlevel Eq Test.10 Test.11; + let Test.8 = lowlevel And Test.12 Test.9; + if Test.8 then + let Test.1 = Index 1 Test.0; + let Test.5 = 1i64; + let Test.4 = CallByName Num.14 Test.1 Test.5; + jump Test.3 Test.4; + else + let Test.7 = 1i64; + jump Test.3 Test.7; + joinpoint Test.3 Test.2: + ret Test.2; + "# + ), + ) + } + + #[test] + fn one_element_tag() { + compiles_to_ir( + r#" + x : [ Pair Int ] + x = Pair 2 + + x + "#, + indoc!( + r#" + let Test.2 = 2i64; + let Test.0 = Struct {Test.2}; + ret Test.0; + "# + ), + ) + } + + #[test] + fn join_points() { + compiles_to_ir( + r#" + x = + if True then 1 else 2 + + x + "#, + indoc!( + r#" + let Test.4 = true; + if Test.4 then + let Test.2 = 1i64; + jump Test.3 Test.2; + else + let Test.2 = 2i64; + jump Test.3 Test.2; + joinpoint Test.3 Test.0: + ret Test.0; + "# + ), + ) + } + + #[test] + fn guard_pattern_true() { + compiles_to_ir( + r#" + when 2 is + 2 if False -> 42 + _ -> 0 + "#, + indoc!( + r#" + let Test.1 = 2i64; + let Test.8 = true; + let Test.9 = 2i64; + let Test.12 = lowlevel Eq Test.9 Test.1; + let Test.10 = lowlevel And Test.12 Test.8; + let Test.5 = false; + jump Test.4 Test.5; + joinpoint Test.4 Test.11: + let Test.7 = lowlevel And Test.11 Test.10; + if Test.7 then + let Test.3 = 42i64; + jump Test.2 Test.3; + else + let Test.6 = 0i64; + jump Test.2 Test.6; + joinpoint Test.2 Test.0: + ret Test.0; + "# + ), + ) + } + + #[test] + fn when_on_record() { + compiles_to_ir( + r#" + when { x: 0x2 } is + { x } -> x + 3 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.6; + + let Test.7 = 2i64; + let Test.2 = Struct {Test.7}; + let Test.0 = Index 0 Test.2; + let Test.5 = 3i64; + let Test.4 = CallByName Num.14 Test.0 Test.5; + jump Test.3 Test.4; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn when_nested_maybe() { + compiles_to_ir( + r#" + Maybe a : [ Nothing, Just a ] + + x : Maybe (Maybe Int) + x = Just (Just 41) + + when x is + Just (Just v) -> v + 0x1 + _ -> 0x1 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.7 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.7; + + let Test.19 = 0i64; + let Test.21 = 0i64; + let Test.22 = 41i64; + let Test.20 = Just Test.21 Test.22; + let Test.1 = Just Test.19 Test.20; + let Test.11 = true; + let Test.13 = Index 0 Test.1; + let Test.12 = 0i64; + let Test.18 = lowlevel Eq Test.12 Test.13; + let Test.16 = lowlevel And Test.18 Test.11; + let Test.15 = Index 0 Test.1; + let Test.14 = 0i64; + let Test.17 = lowlevel Eq Test.14 Test.15; + let Test.10 = lowlevel And Test.17 Test.16; + if Test.10 then + let Test.8 = Index 1 Test.1; + let Test.2 = Index 1 Test.8; + let Test.6 = 1i64; + let Test.5 = CallByName Num.14 Test.2 Test.6; + jump Test.4 Test.5; + else + let Test.9 = 1i64; + jump Test.4 Test.9; + joinpoint Test.4 Test.3: + ret Test.3; + "# + ), + ) + } + + #[test] + fn when_on_two_values() { + compiles_to_ir( + r#" + when Pair 2 3 is + Pair 4 3 -> 9 + Pair a b -> a + b + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.7 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.7; + + let Test.17 = 2i64; + let Test.18 = 3i64; + let Test.3 = Struct {Test.17, Test.18}; + let Test.9 = true; + let Test.10 = 4i64; + let Test.11 = Index 0 Test.3; + let Test.16 = lowlevel Eq Test.10 Test.11; + let Test.14 = lowlevel And Test.16 Test.9; + let Test.12 = 3i64; + let Test.13 = Index 1 Test.3; + let Test.15 = lowlevel Eq Test.12 Test.13; + let Test.8 = lowlevel And Test.15 Test.14; + if Test.8 then + let Test.5 = 9i64; + jump Test.4 Test.5; + else + let Test.0 = Index 0 Test.3; + let Test.1 = Index 1 Test.3; + let Test.6 = CallByName Num.14 Test.0 Test.1; + jump Test.4 Test.6; + joinpoint Test.4 Test.2: + ret Test.2; + "# + ), + ) + } + + #[test] + fn list_append_closure() { + compiles_to_ir( + r#" + myFunction = \l -> List.append l 42 + + myFunction [ 1, 2 ] + "#, + indoc!( + r#" + procedure List.5 (#Attr.2, #Attr.3): + let Test.7 = lowlevel ListAppend #Attr.2 #Attr.3; + ret Test.7; + + procedure Test.0 (Test.2): + let Test.6 = 42i64; + let Test.5 = CallByName List.5 Test.2 Test.6; + ret Test.5; + + let Test.8 = 1i64; + let Test.9 = 2i64; + let Test.4 = Array [Test.8, Test.9]; + let Test.3 = CallByName Test.0 Test.4; + dec Test.4; + ret Test.3; + "# + ), + ) + } + + #[test] + fn list_append() { + compiles_to_ir( + r#" + List.append [1] 2 + "#, + indoc!( + r#" + procedure List.5 (#Attr.2, #Attr.3): + let Test.3 = lowlevel ListAppend #Attr.2 #Attr.3; + ret Test.3; + + let Test.4 = 1i64; + let Test.1 = Array [Test.4]; + let Test.2 = 2i64; + let Test.0 = CallByName List.5 Test.1 Test.2; + dec Test.1; + ret Test.0; + "# + ), + ) + } + + #[test] + fn list_len() { + compiles_to_ir( + r#" + x = [1,2,3] + y = [ 1.0 ] + + List.len x + List.len y + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.5; + + procedure List.7 (#Attr.2): + let Test.6 = lowlevel ListLen #Attr.2; + ret Test.6; + + let Test.10 = 1f64; + let Test.1 = Array [Test.10]; + let Test.7 = 1i64; + let Test.8 = 2i64; + let Test.9 = 3i64; + let Test.0 = Array [Test.7, Test.8, Test.9]; + let Test.3 = CallByName List.7 Test.0; + dec Test.0; + let Test.4 = CallByName List.7 Test.1; + dec Test.1; + let Test.2 = CallByName Num.14 Test.3 Test.4; + ret Test.2; + "# + ), + ) + } + + #[test] + fn when_joinpoint() { + compiles_to_ir( + r#" + x : [ Red, White, Blue ] + x = Blue + + y = + when x is + Red -> 1 + White -> 2 + Blue -> 3 + + y + "#, + indoc!( + r#" + let Test.0 = 0u8; + switch Test.0: + case 1: + let Test.4 = 1i64; + jump Test.3 Test.4; + + case 2: + let Test.5 = 2i64; + jump Test.3 Test.5; + + default: + let Test.6 = 3i64; + jump Test.3 Test.6; + + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn simple_if() { + compiles_to_ir( r#" if True then - "bar" + 1 + else + 2 + "#, + indoc!( + r#" + let Test.3 = true; + if Test.3 then + let Test.1 = 1i64; + jump Test.2 Test.1; + else + let Test.1 = 2i64; + jump Test.2 Test.1; + joinpoint Test.2 Test.0: + ret Test.0; + "# + ), + ) + } + + #[test] + fn if_multi_branch() { + compiles_to_ir( + r#" + if True then + 1 else if False then - "foo" - else - "baz" + 2 + else + 3 "#, - { - use self::Builtin::*; - use Layout::Builtin; - - let home = test_home(); - let gen_symbol_0 = Interns::from_index(home, 1); - let gen_symbol_1 = Interns::from_index(home, 0); - - Store( - &[( - gen_symbol_0, - Layout::Builtin(layout::Builtin::Int1), - Expr::Bool(true), - )], - &Cond { - cond_symbol: gen_symbol_0, - branch_symbol: gen_symbol_0, - cond_layout: Builtin(Int1), - pass: (&[] as &[_], &Expr::Str("bar")), - fail: ( - &[] as &[_], - &Store( - &[( - gen_symbol_1, - Layout::Builtin(layout::Builtin::Int1), - Expr::Bool(false), - )], - &Cond { - cond_symbol: gen_symbol_1, - branch_symbol: gen_symbol_1, - cond_layout: Builtin(Int1), - pass: (&[] as &[_], &Expr::Str("foo")), - fail: (&[] as &[_], &Expr::Str("baz")), - ret_layout: Builtin(Str), - }, - ), - ), - ret_layout: Builtin(Str), - }, - ) - }, + indoc!( + r#" + let Test.6 = true; + if Test.6 then + let Test.1 = 1i64; + jump Test.2 Test.1; + else + let Test.5 = false; + if Test.5 then + let Test.3 = 2i64; + jump Test.4 Test.3; + else + let Test.3 = 3i64; + jump Test.4 Test.3; + joinpoint Test.4 Test.1: + jump Test.2 Test.1; + joinpoint Test.2 Test.0: + ret Test.0; + "# + ), ) } #[test] - fn annotated_if_expression() { - // an if with an annotation gets constrained differently. Make sure the result is still correct. - compiles_to( + fn when_on_result() { + compiles_to_ir( r#" - x : Str - x = if True then "bar" else "foo" + x : Result Int Int + x = Ok 2 - x + y = + when x is + Ok 3 -> 1 + Ok _ -> 2 + Err _ -> 3 + y "#, - { - use self::Builtin::*; - use Layout::Builtin; - - let home = test_home(); - let gen_symbol_0 = Interns::from_index(home, 1); - let symbol_x = Interns::from_index(home, 0); - - Store( - &[( - symbol_x, - Builtin(Str), - Store( - &[( - gen_symbol_0, - Layout::Builtin(layout::Builtin::Int1), - Expr::Bool(true), - )], - &Cond { - cond_symbol: gen_symbol_0, - branch_symbol: gen_symbol_0, - cond_layout: Builtin(Int1), - pass: (&[] as &[_], &Expr::Str("bar")), - fail: (&[] as &[_], &Expr::Str("foo")), - ret_layout: Builtin(Str), - }, - ), - )], - &Load(symbol_x), - ) - }, + indoc!( + r#" + let Test.17 = 1i64; + let Test.18 = 2i64; + let Test.0 = Ok Test.17 Test.18; + let Test.13 = true; + let Test.15 = Index 0 Test.0; + let Test.14 = 1i64; + let Test.16 = lowlevel Eq Test.14 Test.15; + let Test.12 = lowlevel And Test.16 Test.13; + if Test.12 then + let Test.8 = true; + let Test.9 = 3i64; + let Test.10 = Index 0 Test.0; + let Test.11 = lowlevel Eq Test.9 Test.10; + let Test.7 = lowlevel And Test.11 Test.8; + if Test.7 then + let Test.4 = 1i64; + jump Test.3 Test.4; + else + let Test.5 = 2i64; + jump Test.3 Test.5; + else + let Test.6 = 3i64; + jump Test.3 Test.6; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), ) } - // #[test] - // fn record_pattern() { - // compiles_to( - // r#" - // \{ x } -> x + 0x5 - // "#, - // { Float(3.45) }, - // ) - // } - // - // #[test] - // fn tag_pattern() { - // compiles_to( - // r#" - // \Foo x -> x + 0x5 - // "#, - // { Float(3.45) }, - // ) - // } - #[test] - fn polymorphic_identity() { - compiles_to( + fn let_with_record_pattern() { + compiles_to_ir( r#" - id = \x -> x + { x } = { x: 0x2, y: 3.14 } - id { x: id 0x4, y: 0.1 } + x "#, - { - let home = test_home(); - - let gen_symbol_0 = Interns::from_index(home, 0); - let struct_layout = Layout::Struct(&[I64_LAYOUT, F64_LAYOUT]); - - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[struct_layout.clone()], - &struct_layout.clone(), - ), - args: &[( - Struct(&[ - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer(&[I64_LAYOUT], &I64_LAYOUT), - args: &[(Int(4), I64_LAYOUT)], - }, - I64_LAYOUT, - ), - (Float(0.1), F64_LAYOUT), - ]), - struct_layout, - )], - } - }, + indoc!( + r#" + let Test.4 = 2i64; + let Test.5 = 3.14f64; + let Test.3 = Struct {Test.4, Test.5}; + let Test.0 = Index 0 Test.3; + ret Test.0; + "# + ), ) } - // #[test] - // fn list_get_unique() { - // compiles_to( - // r#" - // unique = [ 2, 4 ] - - // List.get unique 1 - // "#, - // { - // use self::Builtin::*; - // let home = test_home(); - - // let gen_symbol_0 = Interns::from_index(home, 0); - // let list_layout = Layout::Builtin(Builtin::List(&I64_LAYOUT)); - - // CallByName { - // name: gen_symbol_0, - // layout: Layout::FunctionPointer(&[list_layout.clone()], &list_layout.clone()), - // args: &[( - // Struct(&[( - // CallByName { - // name: gen_symbol_0, - // layout: Layout::FunctionPointer( - // &[Layout::Builtin(Builtin::Int64)], - // &Layout::Builtin(Builtin::Int64), - // ), - // args: &[(Int(4), Layout::Builtin(Int64))], - // }, - // Layout::Builtin(Int64), - // )]), - // Layout::Struct(&[Layout::Builtin(Int64)]), - // )], - // } - // }, - // ) - // } - - // needs LetRec to be converted to mono - // #[test] - // fn polymorphic_recursive() { - // compiles_to( - // r#" - // f = \x -> - // when x < 10 is - // True -> f (x + 1) - // False -> x - // - // { x: f 0x4, y: f 3.14 } - // "#, - // { - // use self::Builtin::*; - // use Layout::Builtin; - // let home = test_home(); - // - // let gen_symbol_3 = Interns::from_index(home, 3); - // let gen_symbol_4 = Interns::from_index(home, 4); - // - // Float(3.4) - // - // }, - // ) - // } - - // needs layout for non-empty tag union - // #[test] - // fn is_nil() { - // let arena = Bump::new(); - // - // compiles_to_with_interns( - // r#" - // LinkedList a : [ Cons a (LinkedList a), Nil ] - // - // isNil : LinkedList a -> Bool - // isNil = \list -> - // when list is - // Nil -> True - // Cons _ _ -> False - // - // listInt : LinkedList Int - // listInt = Nil - // - // isNil listInt - // "#, - // |interns| { - // let home = test_home(); - // let var_is_nil = interns.symbol(home, "isNil".into()); - // }, - // ); - // } - #[test] - fn bool_literal() { - let arena = Bump::new(); - - compiles_to_with_interns( + fn let_with_record_pattern_list() { + compiles_to_ir( r#" - x : Bool - x = True - - x - "#, - |interns| { - let home = test_home(); - let var_x = interns.symbol(home, "x".into()); - - let stores = [(var_x, Layout::Builtin(Builtin::Int1), Bool(true))]; - - let load = Load(var_x); - - Store(arena.alloc(stores), arena.alloc(load)) - }, - ); - } - - #[test] - fn two_element_enum() { - let arena = Bump::new(); - - compiles_to_with_interns( - r#" - x : [ Yes, No ] - x = No + { x } = { x: [ 1, 3, 4 ], y: 3.14 } x "#, - |interns| { - let home = test_home(); - let var_x = interns.symbol(home, "x".into()); - - let stores = [(var_x, Layout::Builtin(Builtin::Int1), Bool(false))]; - - let load = Load(var_x); - - Store(arena.alloc(stores), arena.alloc(load)) - }, - ); + indoc!( + r#" + let Test.6 = 1i64; + let Test.7 = 3i64; + let Test.8 = 4i64; + let Test.4 = Array [Test.6, Test.7, Test.8]; + let Test.5 = 3.14f64; + let Test.3 = Struct {Test.4, Test.5}; + let Test.0 = Index 0 Test.3; + inc Test.0; + dec Test.3; + ret Test.0; + "# + ), + ) } #[test] - fn three_element_enum() { - let arena = Bump::new(); + fn if_guard_bind_variable_false() { + compiles_to_ir( + indoc!( + r#" + when 10 is + x if x == 5 -> 0 + _ -> 42 + "# + ), + indoc!( + r#" + procedure Bool.5 (#Attr.2, #Attr.3): + let Test.8 = lowlevel Eq #Attr.2 #Attr.3; + ret Test.8; - compiles_to_with_interns( - r#" - # this test is brought to you by fruits.com! - x : [ Apple, Orange, Banana ] - x = Orange - - x - "#, - |interns| { - let home = test_home(); - let var_x = interns.symbol(home, "x".into()); - - // orange gets index (and therefore tag_id) 1 - let stores = [(var_x, Layout::Builtin(Builtin::Int8), Byte(2))]; - - let load = Load(var_x); - - Store(arena.alloc(stores), arena.alloc(load)) - }, - ); + let Test.2 = 10i64; + let Test.11 = true; + let Test.7 = 5i64; + let Test.6 = CallByName Bool.5 Test.2 Test.7; + jump Test.5 Test.6; + joinpoint Test.5 Test.12: + let Test.10 = lowlevel And Test.12 Test.11; + if Test.10 then + let Test.4 = 0i64; + jump Test.3 Test.4; + else + let Test.9 = 42i64; + jump Test.3 Test.9; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) } #[test] - fn set_unique_int_list() { - compiles_to("List.get (List.set [ 12, 9, 7, 3 ] 1 42) 1", { - CallByName { - name: Symbol::LIST_GET, - layout: Layout::FunctionPointer( - &[Layout::Builtin(Builtin::List(&I64_LAYOUT)), I64_LAYOUT], - &Layout::Union(&[&[I64_LAYOUT], &[I64_LAYOUT, I64_LAYOUT]]), - ), - args: &vec![ - ( - CallByName { - name: Symbol::LIST_SET, - layout: Layout::FunctionPointer( - &[ - Layout::Builtin(Builtin::List(&I64_LAYOUT)), - I64_LAYOUT, - I64_LAYOUT, - ], - &Layout::Builtin(Builtin::List(&I64_LAYOUT)), - ), - args: &vec![ - ( - Array { - elem_layout: I64_LAYOUT, - elems: &vec![Int(12), Int(9), Int(7), Int(3)], - }, - Layout::Builtin(Builtin::List(&I64_LAYOUT)), - ), - (Int(1), I64_LAYOUT), - (Int(42), I64_LAYOUT), - ], - }, - Layout::Builtin(Builtin::List(&I64_LAYOUT)), - ), - (Int(1), I64_LAYOUT), - ], - } - }); + fn alias_variable() { + compiles_to_ir( + indoc!( + r#" + x = 5 + y = x + + 3 + "# + ), + indoc!( + r#" + let Test.0 = 5i64; + ret Test.0; + "# + ), + ); + + compiles_to_ir( + indoc!( + r#" + x = 5 + y = x + + y + "# + ), + indoc!( + r#" + let Test.0 = 5i64; + ret Test.0; + "# + ), + ) } - // #[test] - // fn when_on_result() { - // compiles_to( - // r#" - // when 1 is - // 1 -> 12 - // _ -> 34 - // "#, - // { - // use self::Builtin::*; - // use Layout::Builtin; - // let home = test_home(); - // - // let gen_symbol_3 = Interns::from_index(home, 3); - // let gen_symbol_4 = Interns::from_index(home, 4); - // - // CallByName( - // gen_symbol_3, - // &[( - // Struct(&[( - // CallByName(gen_symbol_4, &[(Int(4), Builtin(Int64))]), - // Builtin(Int64), - // )]), - // Layout::Struct(&[("x".into(), Builtin(Int64))]), - // )], - // ) - // }, - // ) - // } + #[test] + fn branch_store_variable() { + compiles_to_ir( + indoc!( + r#" + when 0 is + 1 -> 12 + a -> a + "# + ), + indoc!( + r#" + let Test.2 = 0i64; + let Test.7 = true; + let Test.8 = 1i64; + let Test.9 = lowlevel Eq Test.8 Test.2; + let Test.6 = lowlevel And Test.9 Test.7; + if Test.6 then + let Test.4 = 12i64; + jump Test.3 Test.4; + else + jump Test.3 Test.2; + joinpoint Test.3 Test.1: + ret Test.1; + "# + ), + ) + } + + #[test] + fn list_pass_to_function() { + compiles_to_ir( + indoc!( + r#" + x : List Int + x = [1,2,3] + + id : List Int -> List Int + id = \y -> List.set y 0 0 + + id x + "# + ), + indoc!( + r#" + procedure List.4 (#Attr.2, #Attr.3, #Attr.4): + let Test.12 = lowlevel ListLen #Attr.2; + let Test.11 = lowlevel NumLt #Attr.3 Test.12; + if Test.11 then + let Test.9 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; + jump Test.10 Test.9; + else + jump Test.10 #Attr.2; + joinpoint Test.10 Test.8: + ret Test.8; + + procedure Test.1 (Test.3): + let Test.6 = 0i64; + let Test.7 = 0i64; + let Test.5 = CallByName List.4 Test.3 Test.6 Test.7; + ret Test.5; + + let Test.13 = 1i64; + let Test.14 = 2i64; + let Test.15 = 3i64; + let Test.0 = Array [Test.13, Test.14, Test.15]; + let Test.4 = CallByName Test.1 Test.0; + dec Test.0; + ret Test.4; + "# + ), + ) + } } diff --git a/compiler/problem/src/can.rs b/compiler/problem/src/can.rs index a85e65b2bb..523ccd965c 100644 --- a/compiler/problem/src/can.rs +++ b/compiler/problem/src/can.rs @@ -123,6 +123,8 @@ pub enum RuntimeError { InvalidInt(IntErrorKind, Base, Region, Box), CircularDef(Vec, Vec<(Region /* pattern */, Region /* expr */)>), + NonExhaustivePattern, + /// When the author specifies a type annotation but no implementation NoImplementation, } diff --git a/compiler/reporting/src/error/canonicalize.rs b/compiler/reporting/src/error/canonicalize.rs index 5079b00850..918b6ebb45 100644 --- a/compiler/reporting/src/error/canonicalize.rs +++ b/compiler/reporting/src/error/canonicalize.rs @@ -525,6 +525,9 @@ fn pretty_runtime_error<'b>( alloc.reflow("Only variables can be updated with record update syntax."), ]), RuntimeError::NoImplementation => todo!("no implementation, unreachable"), + RuntimeError::NonExhaustivePattern => { + unreachable!("not currently reported (but can blow up at runtime)") + } } } diff --git a/compiler/reporting/src/error/mono.rs b/compiler/reporting/src/error/mono.rs index 65dc625652..2f69644598 100644 --- a/compiler/reporting/src/error/mono.rs +++ b/compiler/reporting/src/error/mono.rs @@ -5,11 +5,11 @@ use ven_pretty::DocAllocator; pub fn mono_problem<'b>( alloc: &'b RocDocAllocator<'b>, filename: PathBuf, - problem: roc_mono::expr::MonoProblem, + problem: roc_mono::ir::MonoProblem, ) -> Report<'b> { - use roc_mono::expr::MonoProblem::*; - use roc_mono::pattern::Context::*; - use roc_mono::pattern::Error::*; + use roc_mono::exhaustive::Context::*; + use roc_mono::exhaustive::Error::*; + use roc_mono::ir::MonoProblem::*; match problem { PatternProblem(Incomplete(region, context, missing)) => match context { @@ -111,7 +111,7 @@ pub fn mono_problem<'b>( pub fn unhandled_patterns_to_doc_block<'b>( alloc: &'b RocDocAllocator<'b>, - patterns: Vec, + patterns: Vec, ) -> RocDocBuilder<'b> { alloc .vcat(patterns.into_iter().map(|v| pattern_to_doc(alloc, v))) @@ -121,19 +121,19 @@ pub fn unhandled_patterns_to_doc_block<'b>( fn pattern_to_doc<'b>( alloc: &'b RocDocAllocator<'b>, - pattern: roc_mono::pattern::Pattern, + pattern: roc_mono::exhaustive::Pattern, ) -> RocDocBuilder<'b> { pattern_to_doc_help(alloc, pattern, false) } fn pattern_to_doc_help<'b>( alloc: &'b RocDocAllocator<'b>, - pattern: roc_mono::pattern::Pattern, + pattern: roc_mono::exhaustive::Pattern, in_type_param: bool, ) -> RocDocBuilder<'b> { - use roc_mono::pattern::Literal::*; - use roc_mono::pattern::Pattern::*; - use roc_mono::pattern::RenderAs; + use roc_mono::exhaustive::Literal::*; + use roc_mono::exhaustive::Pattern::*; + use roc_mono::exhaustive::RenderAs; match pattern { Anything => alloc.text("_"), diff --git a/compiler/reporting/tests/test_reporting.rs b/compiler/reporting/tests/test_reporting.rs index f9887083ca..986494edc6 100644 --- a/compiler/reporting/tests/test_reporting.rs +++ b/compiler/reporting/tests/test_reporting.rs @@ -12,7 +12,7 @@ mod test_reporting { use crate::helpers::test_home; use bumpalo::Bump; use roc_module::symbol::{Interns, ModuleId}; - use roc_mono::expr::{Expr, Procs}; + use roc_mono::ir::{Procs, Stmt}; use roc_reporting::report::{ can_problem, mono_problem, parse_problem, type_problem, Report, BLUE_CODE, BOLD_CODE, CYAN_CODE, DEFAULT_PALETTE, GREEN_CODE, MAGENTA_CODE, RED_CODE, RESET_CODE, UNDERLINE_CODE, @@ -47,7 +47,7 @@ mod test_reporting { ( Vec, Vec, - Vec, + Vec, ModuleId, Interns, ), @@ -87,14 +87,14 @@ mod test_reporting { let mut ident_ids = interns.all_ident_ids.remove(&home).unwrap(); // Populate Procs and Subs, and get the low-level Expr from the canonical Expr - let mut mono_env = roc_mono::expr::Env { + let mut mono_env = roc_mono::ir::Env { arena: &arena, subs: &mut subs, problems: &mut mono_problems, home, ident_ids: &mut ident_ids, }; - let _mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs); + let _mono_expr = Stmt::new(&mut mono_env, loc_expr.value, &mut procs); } Ok((unify_problems, can_problems, mono_problems, home, interns)) diff --git a/compiler/uniq/src/sharing.rs b/compiler/uniq/src/sharing.rs index 2300785f1a..7caa549f91 100644 --- a/compiler/uniq/src/sharing.rs +++ b/compiler/uniq/src/sharing.rs @@ -606,6 +606,7 @@ impl VarUsage { ); closure_signatures.insert(Symbol::LIST_IS_EMPTY, vec![Usage::Simple(Mark::Seen)]); + closure_signatures.insert(Symbol::LIST_LEN, vec![Usage::Simple(Mark::Seen)]); closure_signatures.insert( Symbol::LIST_SET, diff --git a/examples/quicksort/host.rs b/examples/quicksort/host.rs index 9dc9305700..54d19892e4 100644 --- a/examples/quicksort/host.rs +++ b/examples/quicksort/host.rs @@ -9,4 +9,10 @@ pub fn main() { let list = unsafe { list_from_roc() }; println!("Roc quicksort says: {:?}", list); + + // the pointer is to the first _element_ of the list, + // but the refcount precedes it. Thus calling free() on + // this pointer would segfault/cause badness. Therefore, we + // leak it for now + Box::leak(list); }