diff --git a/cli/src/repl.rs b/cli/src/repl.rs index f56a8f1256..ef9d733ccd 100644 --- a/cli/src/repl.rs +++ b/cli/src/repl.rs @@ -258,8 +258,13 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S }; let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); - let main_body = - roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); + + let param_map = roc_mono::borrow::ParamMap::default(); + let main_body = roc_mono::inc_dec::visit_declaration( + mono_env.arena, + mono_env.arena.alloc(param_map), + mono_env.arena.alloc(main_body), + ); let mut headers = { let num_headers = match &procs.pending_specializations { Some(map) => map.len(), diff --git a/compiler/build/src/program.rs b/compiler/build/src/program.rs index 61c980e2ce..223b862e32 100644 --- a/compiler/build/src/program.rs +++ b/compiler/build/src/program.rs @@ -158,7 +158,7 @@ pub fn gen( pattern_symbols: bumpalo::collections::Vec::new_in( mono_env.arena, ), - is_tail_recursive: false, + is_self_recursive: false, body, }; diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 6eee7eecbf..fcf814c1c6 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -37,6 +37,9 @@ const PRINT_FN_VERIFICATION_OUTPUT: bool = true; #[cfg(not(debug_assertions))] const PRINT_FN_VERIFICATION_OUTPUT: bool = false; +pub const REFCOUNT_0: usize = std::usize::MAX; +pub const REFCOUNT_1: usize = REFCOUNT_0 - 1; + #[derive(Debug, Clone, Copy)] pub enum OptLevel { Normal, @@ -904,7 +907,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( match layout { Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => { - increment_refcount_list(env, value.into_struct_value()); + increment_refcount_list(env, parent, value.into_struct_value()); build_exp_stmt(env, layout_ids, scope, parent, cont) } _ => build_exp_stmt(env, layout_ids, scope, parent, cont), @@ -929,11 +932,7 @@ fn 
refcount_is_one_comparison<'ctx>( context: &'ctx Context, refcount: IntValue<'ctx>, ) -> IntValue<'ctx> { - let refcount_one: IntValue<'ctx> = context.i64_type().const_int((std::usize::MAX) as _, false); - // Note: Check for refcount < refcount_1 as the "true" condition, - // to avoid misprediction. (In practice this should usually pass, - // and CPUs generally default to predicting that a forward jump - // shouldn't be taken; that is, they predict "else" won't be taken.) + let refcount_one: IntValue<'ctx> = context.i64_type().const_int(REFCOUNT_1 as _, false); builder.build_int_compare( IntPredicate::EQ, refcount, @@ -998,6 +997,7 @@ fn decrement_refcount_layout<'a, 'ctx, 'env>( } } } + RecursiveUnion(_) => todo!("TODO implement decrement layout of recursive tag union"), Union(tags) => { debug_assert!(!tags.is_empty()); let wrapper_struct = value.into_struct_value(); @@ -1086,11 +1086,29 @@ fn decrement_refcount_builtin<'a, 'ctx, 'env>( fn increment_refcount_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, original_wrapper: StructValue<'ctx>, ) { let builder = env.builder; let ctx = env.context; + let len = list_len(builder, original_wrapper); + + let is_non_empty = builder.build_int_compare( + IntPredicate::UGT, + len, + ctx.i64_type().const_zero(), + "len > 0", + ); + + // build blocks + let increment_block = ctx.append_basic_block(parent, "increment_block"); + let cont_block = ctx.append_basic_block(parent, "after_increment_block"); + + builder.build_conditional_branch(is_non_empty, increment_block, cont_block); + + builder.position_at_end(increment_block); + let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); let refcount = env @@ -1107,6 +1125,9 @@ fn increment_refcount_list<'a, 'ctx, 'env>( // Mutate the new array in-place to change the element. 
builder.build_store(refcount_ptr, decremented); + builder.build_unconditional_branch(cont_block); + + builder.position_at_end(cont_block); } fn decrement_refcount_list<'a, 'ctx, 'env>( @@ -1117,6 +1138,30 @@ fn decrement_refcount_list<'a, 'ctx, 'env>( let builder = env.builder; let ctx = env.context; + // the block we'll always jump to when we're done + let cont_block = ctx.append_basic_block(parent, "after_decrement_block"); + let decrement_block = ctx.append_basic_block(parent, "decrement_block"); + + // currently, an empty list has a null-pointer if its length is 0 + // so we must first check the length + + let len = list_len(builder, original_wrapper); + let is_non_empty = builder.build_int_compare( + IntPredicate::UGT, + len, + ctx.i64_type().const_zero(), + "len > 0", + ); + + // if the length is 0, we're done and jump to the continuation block + // otherwise, actually read and check the refcount + builder.build_conditional_branch(is_non_empty, decrement_block, cont_block); + builder.position_at_end(decrement_block); + + // build blocks + let then_block = ctx.append_basic_block(parent, "then"); + let else_block = ctx.append_basic_block(parent, "else"); + let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); let refcount = env @@ -1126,16 +1171,24 @@ fn decrement_refcount_list<'a, 'ctx, 'env>( let comparison = refcount_is_one_comparison(builder, env.context, refcount); - // build blocks - let then_block = ctx.append_basic_block(parent, "then"); - let else_block = ctx.append_basic_block(parent, "else"); - let cont_block = ctx.append_basic_block(parent, "dec_ref_branchcont"); - + // TODO what would be most optimal for the branch predictor + // + // are most refcounts 1 most of the time? or not? 
builder.build_conditional_branch(comparison, then_block, else_block); // build then block { builder.position_at_end(then_block); + if !env.leak { + let free = builder.build_free(refcount_ptr); + builder.insert_instruction(&free, None); + } + builder.build_unconditional_branch(cont_block); + } + + // build else block + { + builder.position_at_end(else_block); // our refcount 0 is actually usize::MAX, so decrementing the refcount means incrementing this value. let decremented = env.builder.build_int_add( ctx.i64_type().const_int(1 as u64, false), @@ -1149,16 +1202,6 @@ fn decrement_refcount_list<'a, 'ctx, 'env>( builder.build_unconditional_branch(cont_block); } - // build else block - { - builder.position_at_end(else_block); - if !env.leak { - let free = builder.build_free(refcount_ptr); - builder.insert_instruction(&free, None); - } - builder.build_unconditional_branch(cont_block); - } - // emit merge block builder.position_at_end(cont_block); } @@ -1804,14 +1847,9 @@ fn run_low_level<'a, 'ctx, 'env>( list_get_unsafe(env, list_layout, elem_index, wrapper_struct) } - ListSet => { + ListSetInPlace => { let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]); - let in_place = match &list_layout { - Layout::Builtin(Builtin::List(MemoryMode::Unique, _)) => InPlace::InPlace, - _ => InPlace::Clone, - }; - list_set( parent, &[ @@ -1820,19 +1858,57 @@ fn run_low_level<'a, 'ctx, 'env>( (load_symbol_and_layout(env, scope, &args[2])), ], env, - in_place, + InPlace::InPlace, ) } - ListSetInPlace => list_set( - parent, - &[ - (load_symbol_and_layout(env, scope, &args[0])), + ListSet => { + let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]); + + let arguments = &[ + (list_symbol, list_layout), (load_symbol_and_layout(env, scope, &args[1])), (load_symbol_and_layout(env, scope, &args[2])), - ], - env, - InPlace::InPlace, - ), + ]; + + match list_layout { + Layout::Builtin(Builtin::List(MemoryMode::Unique, _)) => { + // the layout 
tells us this List.set can be done in-place + list_set(parent, arguments, env, InPlace::InPlace) + } + Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => { + // no static guarantees, but all is not lost: we can check the refcount + // if it is one, we hold the final reference, and can mutate it in-place! + let builder = env.builder; + let ctx = env.context; + + let ret_type = + basic_type_from_layout(env.arena, ctx, list_layout, env.ptr_bytes); + + let refcount_ptr = list_get_refcount_ptr(env, list_symbol.into_struct_value()); + + let refcount = env + .builder + .build_load(refcount_ptr, "get_refcount") + .into_int_value(); + + let comparison = refcount_is_one_comparison(builder, env.context, refcount); + + // build then block + // refcount is 1, so work in-place + let build_pass = || list_set(parent, arguments, env, InPlace::InPlace); + + // build else block + // refcount != 1, so clone first + let build_fail = || list_set(parent, arguments, env, InPlace::Clone); + + crate::llvm::build_list::build_basic_phi2( + env, parent, comparison, build_pass, build_fail, ret_type, + ) + } + Layout::Builtin(Builtin::EmptyList) => list_symbol, + other => unreachable!("List.set: weird layout {:?}", other), + } + } } } diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index aa1a44dadf..b52f47b56a 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -207,10 +207,7 @@ pub fn list_prepend<'a, 'ctx, 'env>( let ptr_bytes = env.ptr_bytes; // Allocate space for the new array that we'll copy into. 
- let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let clone_ptr = builder - .build_array_malloc(elem_type, new_list_len, "list_ptr") - .unwrap(); + let clone_ptr = allocate_list(env, elem_layout, new_list_len); let int_type = ptr_int(ctx, ptr_bytes); let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr"); @@ -355,9 +352,7 @@ pub fn list_join<'a, 'ctx, 'env>( .build_load(list_len_sum_alloca, list_len_sum_name) .into_int_value(); - let final_list_ptr = builder - .build_array_malloc(elem_type, final_list_sum, "final_list_sum") - .unwrap(); + let final_list_ptr = allocate_list(env, elem_layout, final_list_sum); let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem"); @@ -1375,9 +1370,12 @@ pub fn allocate_list<'a, 'ctx, 'env>( "make ptr", ); - // put our "refcount 0" in the first slot - let ref_count_zero = ctx.i64_type().const_int(std::usize::MAX as u64, false); - builder.build_store(refcount_ptr, ref_count_zero); + // the refcount of a new list is initially 1 + // we assume that the list is indeed used (dead variables are eliminated) + let ref_count_one = ctx + .i64_type() + .const_int(crate::llvm::build::REFCOUNT_1 as _, false); + builder.build_store(refcount_ptr, ref_count_one); list_element_ptr } diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index 65a6b5793a..68c283a1bc 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -107,6 +107,7 @@ pub fn basic_type_from_layout<'ctx>( .struct_type(field_types.into_bump_slice(), false) .as_basic_type_enum() } + RecursiveUnion(_) => todo!("TODO implement layout of recursive tag union"), Union(_) => { // TODO make this dynamic let ptr_size = std::mem::size_of::(); diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index 151c4a8543..b89f67fd52 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -210,7 +210,28 @@ mod 
gen_list { } #[test] - fn list_concat() { + fn foobarbaz() { + assert_evals_to!( + indoc!( + r#" + firstList : List Int + firstList = + [] + + secondList : List Int + secondList = + [] + + List.concat firstList secondList + "# + ), + &[], + &'static [i64] + ); + } + + #[test] + fn list_concat_vanilla() { assert_evals_to!("List.concat [] []", &[], &'static [i64]); assert_evals_to!( @@ -516,7 +537,7 @@ mod gen_list { assert_evals_to!( indoc!( r#" - shared = [ 2.1, 4.3 ] + main = \shared -> # This should not mutate the original x = @@ -530,6 +551,8 @@ mod gen_list { Err _ -> 0 { x, y } + + main [ 2.1, 4.3 ] "# ), (7.7, 4.3), @@ -542,6 +565,7 @@ mod gen_list { assert_evals_to!( indoc!( r#" + main = \{} -> shared = [ 2, 4 ] # This List.set is out of bounds, and should have no effect @@ -556,6 +580,8 @@ mod gen_list { Err _ -> 0 { x, y } + + main {} "# ), (4, 4), diff --git a/compiler/gen/tests/gen_num.rs b/compiler/gen/tests/gen_num.rs index e460e42340..e329cc7a92 100644 --- a/compiler/gen/tests/gen_num.rs +++ b/compiler/gen/tests/gen_num.rs @@ -482,9 +482,12 @@ mod gen_num { assert_evals_to!( indoc!( r#" - when 10 is - x if x == 5 -> 0 - _ -> 42 + main = \{} -> + when 10 is + x if x == 5 -> 0 + _ -> 42 + + main {} "# ), 42, @@ -497,9 +500,12 @@ mod gen_num { assert_evals_to!( indoc!( r#" - when 10 is - x if x == 10 -> 42 - _ -> 0 + main = \{} -> + when 10 is + x if x == 10 -> 42 + _ -> 0 + + main {} "# ), 42, diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 97f69c56ca..6c310b5561 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -283,7 +283,10 @@ mod gen_primitives { assert_evals_to!( indoc!( r#" + main = \{} -> (\a -> a) 5 + + main {} "# ), 5, @@ -296,11 +299,14 @@ mod gen_primitives { assert_evals_to!( indoc!( r#" + main = \{} -> alwaysFloatIdentity : Int -> (Float -> Float) alwaysFloatIdentity = \num -> (\a -> a) (alwaysFloatIdentity 2) 3.14 + + main {} "# ), 3.14, diff --git 
a/compiler/gen/tests/gen_tags.rs b/compiler/gen/tests/gen_tags.rs index 28d93a7668..67bf2e8830 100644 --- a/compiler/gen/tests/gen_tags.rs +++ b/compiler/gen/tests/gen_tags.rs @@ -455,9 +455,12 @@ mod gen_tags { assert_evals_to!( indoc!( r#" - when 2 is - 2 if False -> 0 - _ -> 42 + main = \{} -> + when 2 is + 2 if False -> 0 + _ -> 42 + + main {} "# ), 42, @@ -470,9 +473,12 @@ mod gen_tags { assert_evals_to!( indoc!( r#" - when 2 is - 2 if True -> 42 - _ -> 0 + main = \{} -> + when 2 is + 2 if True -> 42 + _ -> 0 + + main {} "# ), 42, @@ -485,9 +491,12 @@ mod gen_tags { assert_evals_to!( indoc!( r#" - when 2 is - _ if False -> 0 - _ -> 42 + main = \{} -> + when 2 is + _ if False -> 0 + _ -> 42 + + main {} "# ), 42, @@ -665,16 +674,19 @@ mod gen_tags { assert_evals_to!( indoc!( r#" - x : [ Red, White, Blue ] - x = Blue + main = \{} -> + x : [ Red, White, Blue ] + x = Blue - y = - when x is - Red -> 1 - White -> 2 - Blue -> 3.1 + y = + when x is + Red -> 1 + White -> 2 + Blue -> 3.1 - y + y + + main {} "# ), 3.1, @@ -687,13 +699,16 @@ mod gen_tags { assert_evals_to!( indoc!( r#" - y = - when 1 + 2 is - 3 -> 3 - 1 -> 1 - _ -> 0 + main = \{} -> + y = + when 1 + 2 is + 3 -> 3 + 1 -> 1 + _ -> 0 - y + y + + main {} "# ), 3, diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index c80ee80f07..42834c10f9 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -106,8 +106,6 @@ pub fn helper_without_uniqueness<'a>( }; let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); - let main_body = - roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); let mut headers = { let num_headers = match &procs.pending_specializations { @@ -125,6 +123,13 @@ pub fn helper_without_uniqueness<'a>( roc_collections::all::MutMap::default() ); + let (mut procs, param_map) = procs.get_specialized_procs_help(mono_env.arena); + let main_body = 
roc_mono::inc_dec::visit_declaration( + mono_env.arena, + param_map, + mono_env.arena.alloc(main_body), + ); + // Put this module's ident_ids back in the interns, so we can use them in env. // This must happen *after* building the headers, because otherwise there's // a conflicting mutable borrow on ident_ids. @@ -133,8 +138,7 @@ pub fn helper_without_uniqueness<'a>( // Add all the Proc headers to the module. // We have to do this in a separate pass first, // because their bodies may reference each other. - - for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() { + for ((symbol, layout), proc) in procs.drain() { let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); headers.push((proc, fn_val)); @@ -296,8 +300,6 @@ pub fn helper_with_uniqueness<'a>( }; let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); - let main_body = - roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); let mut headers = { let num_headers = match &procs.pending_specializations { Some(map) => map.len(), @@ -314,6 +316,13 @@ pub fn helper_with_uniqueness<'a>( roc_collections::all::MutMap::default() ); + let (mut procs, param_map) = procs.get_specialized_procs_help(mono_env.arena); + let main_body = roc_mono::inc_dec::visit_declaration( + mono_env.arena, + param_map, + mono_env.arena.alloc(main_body), + ); + // Put this module's ident_ids back in the interns, so we can use them in env. // This must happen *after* building the headers, because otherwise there's // a conflicting mutable borrow on ident_ids. @@ -322,7 +331,7 @@ pub fn helper_with_uniqueness<'a>( // Add all the Proc headers to the module. // We have to do this in a separate pass first, // because their bodies may reference each other. 
- for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() { + for ((symbol, layout), proc) in procs.drain() { let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); headers.push((proc, fn_val)); diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs new file mode 100644 index 0000000000..553e75f321 --- /dev/null +++ b/compiler/mono/src/borrow.rs @@ -0,0 +1,497 @@ +use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt}; +use crate::layout::Layout; +use bumpalo::collections::Vec; +use bumpalo::Bump; +use roc_collections::all::{MutMap, MutSet}; +use roc_module::low_level::LowLevel; +use roc_module::symbol::Symbol; + +pub fn infer_borrow<'a>( + arena: &'a Bump, + procs: &MutMap<(Symbol, Layout<'a>), Proc<'a>>, +) -> ParamMap<'a> { + let mut param_map = ParamMap { + items: MutMap::default(), + }; + + for proc in procs.values() { + param_map.visit_proc(arena, proc); + } + + let mut env = BorrowInfState { + current_proc: Symbol::ATTR_ATTR, + param_set: MutSet::default(), + owned: MutMap::default(), + modified: false, + param_map, + arena, + }; + + // This is a fixed-point analysis + // + // all functions initially own all their parameters + // through a series of checks and heuristics, some arguments are set to borrowed + // when that doesn't lead to conflicts the change is kept, otherwise it may be reverted + // + // when the signatures no longer change, the analysis stops and returns the signatures + loop { + // sort the symbols (roughly) in definition order. 
+ // TODO in the future I think we need to do this properly, and group + // mutually recursive functions (or just make all their arguments owned) + + for proc in procs.values() { + env.collect_proc(proc); + } + + if !env.modified { + // if there were no modifications, we're done + break; + } else { + // otherwise see if there are changes after another iteration + env.modified = false; + } + } + + env.param_map +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +enum Key { + Declaration(Symbol), + JoinPoint(JoinPointId), +} + +#[derive(Debug, Clone, Default)] +pub struct ParamMap<'a> { + items: MutMap<Key, &'a [Param<'a>]>, +} + +impl<'a> ParamMap<'a> { + pub fn get_symbol(&self, symbol: Symbol) -> Option<&'a [Param<'a>]> { + let key = Key::Declaration(symbol); + + self.items.get(&key).copied() + } + pub fn get_join_point(&self, id: JoinPointId) -> &'a [Param<'a>] { + let key = Key::JoinPoint(id); + + match self.items.get(&key) { + Some(slice) => slice, + None => unreachable!("join point not in param map: {:?}", id), + } + } +} + +impl<'a> ParamMap<'a> { + fn init_borrow_params(arena: &'a Bump, ps: &'a [Param<'a>]) -> &'a [Param<'a>] { + Vec::from_iter_in( + ps.iter().map(|p| Param { + borrow: p.layout.is_refcounted(), + layout: p.layout.clone(), + symbol: p.symbol, + }), + arena, + ) + .into_bump_slice() + } + + fn init_borrow_args(arena: &'a Bump, ps: &'a [(Layout<'a>, Symbol)]) -> &'a [Param<'a>] { + Vec::from_iter_in( + ps.iter().map(|(layout, symbol)| Param { + borrow: layout.is_refcounted(), + layout: layout.clone(), + symbol: *symbol, + }), + arena, + ) + .into_bump_slice() + } + + fn visit_proc(&mut self, arena: &'a Bump, proc: &Proc<'a>) { + self.items.insert( + Key::Declaration(proc.name), + Self::init_borrow_args(arena, proc.args), + ); + + self.visit_stmt(arena, proc.name, &proc.body); + } + + fn visit_stmt(&mut self, arena: &'a Bump, _fnid: Symbol, stmt: &Stmt<'a>) { + use Stmt::*; + + let mut stack = bumpalo::vec![ in arena; stmt ]; + + while let Some(stmt) = stack.pop() 
{ + match stmt { + Join { + id: j, + parameters: xs, + remainder: v, + continuation: b, + } => { + self.items + .insert(Key::JoinPoint(*j), Self::init_borrow_params(arena, xs)); + + stack.push(v); + stack.push(b); + } + Let(_, _, _, cont) => { + stack.push(cont); + } + Cond { pass, fail, .. } => { + stack.push(pass); + stack.push(fail); + } + Switch { + branches, + default_branch, + .. + } => { + stack.extend(branches.iter().map(|b| &b.1)); + stack.push(default_branch); + } + Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), + + Ret(_) | Jump(_, _) | RuntimeError(_) => { + // these are terminal, do nothing + } + } + } + } +} + +// Apply the inferred borrow annotations stored in ParamMap to a block of mutually recursive procs + +struct BorrowInfState<'a> { + current_proc: Symbol, + param_set: MutSet<Symbol>, + owned: MutMap<Symbol, MutSet<Symbol>>, + modified: bool, + param_map: ParamMap<'a>, + arena: &'a Bump, +} + +impl<'a> BorrowInfState<'a> { + pub fn own_var(&mut self, x: Symbol) { + let current = self.owned.get_mut(&self.current_proc).unwrap(); + + if current.contains(&x) { + // do nothing + } else { + current.insert(x); + self.modified = true; + } + } + + fn is_owned(&self, x: Symbol) -> bool { + match self.owned.get(&self.current_proc) { + None => unreachable!( + "the current procedure symbol {:?} is not in the owned map", + self.current_proc + ), + Some(set) => set.contains(&x), + } + } + + fn update_param_map(&mut self, k: Key) { + let arena = self.arena; + if let Some(ps) = self.param_map.items.get(&k) { + let ps = Vec::from_iter_in( + ps.iter().map(|p| { + if !p.borrow { + p.clone() + } else if self.is_owned(p.symbol) { + self.modified = true; + let mut p = p.clone(); + p.borrow = false; + + p + } else { + p.clone() + } + }), + arena, + ); + + self.param_map.items.insert(k, ps.into_bump_slice()); + } + } + + /// This looks at an application `f x1 x2 x3` + /// If the parameter (based on the definition of `f`) is owned, + /// then the argument must also be 
owned + fn own_args_using_params(&mut self, xs: &[Symbol], ps: &[Param<'a>]) { + debug_assert_eq!(xs.len(), ps.len()); + + for (x, p) in xs.iter().zip(ps.iter()) { + if !p.borrow { + self.own_var(*x); + } + } + } + + /// This looks at an application `f x1 x2 x3` + /// If the parameter (based on the definition of `f`) is owned, + /// then the argument must also be owned + fn own_args_using_bools(&mut self, xs: &[Symbol], ps: &[bool]) { + debug_assert_eq!(xs.len(), ps.len()); + + for (x, borrow) in xs.iter().zip(ps.iter()) { + if !borrow { + self.own_var(*x); + } + } + } + + /// For each xs[i], if xs[i] is owned, then mark ps[i] as owned. + /// We use this action to preserve tail calls. That is, if we have + /// a tail call `f xs`, if the i-th parameter is borrowed, but `xs[i]` is owned + /// we would have to insert a `dec xs[i]` after `f xs` and consequently + /// "break" the tail call. + fn own_params_using_args(&mut self, xs: &[Symbol], ps: &[Param<'a>]) { + debug_assert_eq!(xs.len(), ps.len()); + + for (x, p) in xs.iter().zip(ps.iter()) { + if self.is_owned(*x) { + self.own_var(p.symbol); + } + } + } + + /// Mark `xs[i]` as owned if it is one of the parameters `ps`. + /// We use this action to mark function parameters that are being "packed" inside constructors. + /// This is a heuristic, and is not related with the effectiveness of the reset/reuse optimization. + /// It is useful for code such as + /// + /// > def f (x y : obj) := + /// > let z := ctor_1 x y; + /// > ret z + fn own_args_if_param(&mut self, xs: &[Symbol]) { + for x in xs.iter() { + // TODO may also be asking for the index here? see Lean + if self.param_set.contains(x) { + self.own_var(*x); + } + } + } + + /// This looks at the assignment + /// + /// let z = e in ... + /// + /// and determines whether z and which of the symbols used in e + /// must be taken as owned parameters + fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) { + use Expr::*; + + match e { + Tag { arguments: xs, .. 
} | Struct(xs) | Array { elems: xs, .. } => { + self.own_var(z); + + // if the used symbol is an argument to the current function, + // the function must take it as an owned parameter + self.own_args_if_param(xs); + } + EmptyArray => { + self.own_var(z); + } + AccessAtIndex { structure: x, .. } => { + // if the structure (record/tag/array) is owned, the extracted value is + if self.is_owned(*x) { + self.own_var(z); + } + + // if the extracted value is owned, the structure must be too + if self.is_owned(z) { + self.own_var(*x); + } + } + FunctionCall { + call_type, + args, + arg_layouts, + .. + } => { + // get the borrow signature of the applied function + let ps = match self.param_map.get_symbol(call_type.get_inner()) { + Some(slice) => slice, + None => Vec::from_iter_in( + arg_layouts.iter().cloned().map(|layout| Param { + symbol: Symbol::UNDERSCORE, + borrow: false, + layout, + }), + self.arena, + ) + .into_bump_slice(), + }; + + // the return value will be owned + self.own_var(z); + + // if the function expects an owned argument (ps), the argument must be owned (args) + self.own_args_using_params(args, ps); + } + + RunLowLevel(op, args) => { + // very unsure what demand RunLowLevel should place upon its arguments + self.own_var(z); + + let ps = lowlevel_borrow_signature(self.arena, *op); + + self.own_args_using_bools(args, ps); + } + + Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {} + } + } + + fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) { + if let ( + Expr::FunctionCall { + call_type, + args: ys, + .. 
+ }, + Stmt::Ret(z), + ) = (v, b) + { + let g = call_type.get_inner(); + if self.current_proc == g && x == *z { + // anonymous functions (for which the ps may not be known) + // can never be tail-recursive, so this is fine + if let Some(ps) = self.param_map.get_symbol(g) { + self.own_params_using_args(ys, ps) + } + } + } + } + + fn update_param_set(&mut self, ps: &[Param<'a>]) { + for p in ps.iter() { + self.param_set.insert(p.symbol); + } + } + + fn update_param_set_symbols(&mut self, ps: &[Symbol]) { + for p in ps.iter() { + self.param_set.insert(*p); + } + } + + fn collect_stmt(&mut self, stmt: &Stmt<'a>) { + use Stmt::*; + + match stmt { + Join { + id: j, + parameters: ys, + remainder: v, + continuation: b, + } => { + let old = self.param_set.clone(); + self.update_param_set(ys); + self.collect_stmt(v); + self.param_set = old; + self.update_param_map(Key::JoinPoint(*j)); + + self.collect_stmt(b); + } + + Let(x, Expr::FunctionPointer(fsymbol, layout), _, b) => { + // ensure that the function pointed to is in the param map + if let Some(params) = self.param_map.get_symbol(*fsymbol) { + self.param_map.items.insert(Key::Declaration(*x), params); + } + + self.collect_stmt(b); + self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b); + } + Let(x, v, _, b) => { + self.collect_stmt(b); + self.collect_expr(*x, v); + self.preserve_tail_call(*x, v, b); + } + Jump(j, ys) => { + let ps = self.param_map.get_join_point(*j); + + // for making sure the join point can reuse + self.own_args_using_params(ys, ps); + + // for making sure the tail call is preserved + self.own_params_using_args(ys, ps); + } + Cond { pass, fail, .. } => { + self.collect_stmt(pass); + self.collect_stmt(fail); + } + Switch { + branches, + default_branch, + .. 
+ } => { + for (_, b) in branches.iter() { + self.collect_stmt(b); + } + self.collect_stmt(default_branch); + } + Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), + + Ret(_) | RuntimeError(_) => { + // these are terminal, do nothing + } + } + } + + fn collect_proc(&mut self, proc: &Proc<'a>) { + let old = self.param_set.clone(); + + let ys = Vec::from_iter_in(proc.args.iter().map(|t| t.1), self.arena).into_bump_slice(); + self.update_param_set_symbols(ys); + self.current_proc = proc.name; + + // ensure that current_proc is in the owned map + self.owned.entry(proc.name).or_default(); + + self.collect_stmt(&proc.body); + self.update_param_map(Key::Declaration(proc.name)); + + self.param_set = old; + } +} + +pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { + use LowLevel::*; + + // TODO is true or false more efficient for non-refcounted layouts? + let irrelevant = false; + let owned = false; + let borrowed = true; + + // Here we define the borrow signature of low-level operations + // + // - arguments with non-refcounted layouts (ints, floats) are `irrelevant` + // - arguments that we may want to update destructively must be Owned + // - other refcounted arguments are Borrowed + match op { + ListLen => arena.alloc_slice_copy(&[borrowed]), + ListSet => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), + ListSetInPlace => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), + ListGetUnsafe => arena.alloc_slice_copy(&[borrowed, irrelevant]), + + ListSingle => arena.alloc_slice_copy(&[irrelevant]), + ListRepeat => arena.alloc_slice_copy(&[irrelevant, irrelevant]), + ListReverse => arena.alloc_slice_copy(&[owned]), + ListConcat => arena.alloc_slice_copy(&[irrelevant, irrelevant]), + ListAppend => arena.alloc_slice_copy(&[owned, owned]), + ListPrepend => arena.alloc_slice_copy(&[owned, owned]), + ListJoin => arena.alloc_slice_copy(&[irrelevant]), + + Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | 
NumGte | NumLt | NumLte + | NumDivUnchecked | NumRemUnchecked => arena.alloc_slice_copy(&[irrelevant, irrelevant]), + + NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => { + arena.alloc_slice_copy(&[irrelevant]) + } + } +} diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 68bb08e2a5..c0b43138f4 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -1,3 +1,4 @@ +use crate::borrow::ParamMap; use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt}; use crate::layout::Layout; use bumpalo::collections::Vec; @@ -113,7 +114,11 @@ pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet) { result.extend(arguments.iter().copied()); } - RunLowLevel(_, _) | EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {} + RunLowLevel(_, args) => { + result.extend(args.iter()); + } + + EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {} } } @@ -139,7 +144,7 @@ pub struct Context<'a> { vars: VarMap, jp_live_vars: JPLiveVarMap, // map: join point => live variables local_context: LocalContext<'a>, // we use it to store the join point declarations - function_params: MutMap]>, + param_map: &'a ParamMap<'a>, } fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet { @@ -150,6 +155,7 @@ fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet { v } +/// `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` fn is_first_occurence(xs: &[Symbol], i: usize) -> bool { match xs.get(i) { None => unreachable!(), @@ -157,6 +163,9 @@ fn is_first_occurence(xs: &[Symbol], i: usize) -> bool { } } +/// Return `n`, the number of times `x` is consumed. +/// - `ys` is a sequence of instruction parameters where we search for `x`. +/// - `consumeParamPred i = true` if parameter `i` is consumed. 
fn get_num_consumptions(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> usize where F: Fn(usize) -> bool, @@ -171,6 +180,8 @@ where n } +/// Return true if `x` also occurs in `ys` in a position that is not consumed. +/// That is, it is also passed as a borrow reference. fn is_borrow_param_help(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> bool where F: Fn(usize) -> bool, @@ -182,11 +193,11 @@ where fn is_borrow_param(x: Symbol, ys: &[Symbol], ps: &[Param]) -> bool { // default to owned arguments - let pred = |i: usize| match ps.get(i) { + let is_owned = |i: usize| match ps.get(i) { Some(param) => !param.borrow, - None => true, + None => unreachable!("or?"), }; - is_borrow_param_help(x, ys, pred) + is_borrow_param_help(x, ys, is_owned) } // We do not need to consume the projection of a variable that is not consumed @@ -201,13 +212,13 @@ fn consume_expr(m: &VarMap, e: &Expr<'_>) -> bool { } impl<'a> Context<'a> { - pub fn new(arena: &'a Bump) -> Self { + pub fn new(arena: &'a Bump, param_map: &'a ParamMap<'a>) -> Self { Self { arena, vars: MutMap::default(), jp_live_vars: MutMap::default(), local_context: LocalContext::default(), - function_params: MutMap::default(), + param_map, } } @@ -253,56 +264,13 @@ impl<'a> Context<'a> { self.arena.alloc(Stmt::Dec(symbol, stmt)) } - fn add_inc_before_consume_all_help( - &self, - xs: &[Symbol], - consume_param_pred: F, - mut b: &'a Stmt<'a>, - live_vars_after: &LiveVarSet, - ) -> &'a Stmt<'a> - where - F: Fn(usize) -> bool + Clone, - { - for (i, x) in xs.iter().enumerate() { - let info = self.get_var_info(*x); - if !info.reference || !is_first_occurence(xs, i) { - // do nothing - } else { - // number of times the argument is used (in the body?) 
- let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); - - // `x` is not a variable that must be consumed by the current procedure - // `x` is live after executing instruction - // `x` is used in a position that is passed as a borrow reference - let lives_on = !info.consume - || live_vars_after.contains(x) - || is_borrow_param_help(*x, xs, consume_param_pred.clone()); - - let num_incs = if lives_on { - num_consumptions - } else { - num_consumptions - 1 - }; - - // Lean can increment by more than 1 at once. Is that needed? - debug_assert!(num_incs <= 1); - - if num_incs == 1 { - b = self.add_inc(*x, b); - } - } - } - - b - } - fn add_inc_before_consume_all( &self, xs: &[Symbol], b: &'a Stmt<'a>, live_vars_after: &LiveVarSet, ) -> &'a Stmt<'a> { - self.add_inc_before_consume_all_help(xs, |_: usize| true, b, live_vars_after) + self.add_inc_before_help(xs, |_: usize| true, b, live_vars_after) } fn add_inc_before_help( @@ -321,11 +289,17 @@ impl<'a> Context<'a> { // do nothing } else { let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); // number of times the argument is used - let num_incs = if !info.consume || // `x` is not a variable that must be consumed by the current procedure - live_vars_after.contains(x) || // `x` is live after executing instruction - is_borrow_param_help( *x ,xs, consume_param_pred.clone()) + + // `x` is not a variable that must be consumed by the current procedure + let need_not_consume = !info.consume; + + // `x` is live after executing instruction + let is_live_after = live_vars_after.contains(x); + // `x` is used in a position that is passed as a borrow reference - { + let is_borrowed = is_borrow_param_help(*x, xs, consume_param_pred.clone()); + + let num_incs = if need_not_consume || is_live_after || is_borrowed { num_consumptions } else { num_consumptions - 1 @@ -352,7 +326,7 @@ impl<'a> Context<'a> { // default to owned arguments let pred = |i: usize| match ps.get(i) { Some(param) 
=> !param.borrow, - None => true, + None => unreachable!("or?"), }; self.add_inc_before_help(xs, pred, b, live_vars_after) } @@ -383,10 +357,10 @@ impl<'a> Context<'a> { b_live_vars: &LiveVarSet, ) -> &'a Stmt<'a> { for (i, x) in xs.iter().enumerate() { - /* We must add a `dec` if `x` must be consumed, it is alive after the application, - and it has been borrowed by the application. - Remark: `x` may occur multiple times in the application (e.g., `f x y x`). - This is why we check whether it is the first occurrence. */ + // We must add a `dec` if `x` must be consumed, it is alive after the application, + // and it has been borrowed by the application. + // Remark: `x` may occur multiple times in the application (e.g., `f x y x`). + // This is why we check whether it is the first occurrence. if self.must_consume(*x) && is_first_occurence(xs, i) && is_borrow_param(*x, xs, ps) @@ -399,6 +373,31 @@ impl<'a> Context<'a> { b } + fn add_dec_after_lowlevel( + &self, + xs: &[Symbol], + ps: &[bool], + mut b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> &'a Stmt<'a> { + for (i, (x, is_borrow)) in xs.iter().zip(ps.iter()).enumerate() { + /* We must add a `dec` if `x` must be consumed, it is alive after the application, + and it has been borrowed by the application. + Remark: `x` may occur multiple times in the application (e.g., `f x y x`). + This is why we check whether it is the first occurrence. */ + + if self.must_consume(*x) + && is_first_occurence(xs, i) + && *is_borrow + && !b_live_vars.contains(x) + { + b = self.add_dec(*x, b); + } + } + + b + } + #[allow(clippy::many_single_char_names)] fn visit_variable_declaration( &self, @@ -432,54 +431,37 @@ impl<'a> Context<'a> { self.arena.alloc(Stmt::Let(z, v, l, b)) } - RunLowLevel(_, _) => { - // THEORY: runlowlevel only occurs - // - // - in a custom hard-coded function - // - when we insert them as compiler authors - // - // if we're carefule to only use RunLowLevel for non-rc'd types - // (e.g. 
when building a cond/switch, we check equality on integers, and to boolean and) - // then RunLowLevel should not change in any way the refcounts. + RunLowLevel(op, args) => { + let ps = crate::borrow::lowlevel_borrow_signature(self.arena, op); + let b = self.add_dec_after_lowlevel(args, ps, b, b_live_vars); - // let b = self.add_dec_after_application(ys, ps, b, b_live_vars); self.arena.alloc(Stmt::Let(z, v, l, b)) } FunctionCall { args: ys, - call_type, arg_layouts, + call_type, .. } => { - // this is where the borrow signature would come in - //let ps := (getDecl ctx f).params; - use crate::ir::CallType; - use crate::layout::Builtin; - let symbol = match call_type { - CallType::ByName(s) => s, - CallType::ByPointer(s) => s, + // get the borrow signature + let ps = match self.param_map.get_symbol(call_type.get_inner()) { + Some(slice) => slice, + None => Vec::from_iter_in( + arg_layouts.iter().cloned().map(|layout| Param { + symbol: Symbol::UNDERSCORE, + borrow: false, + layout, + }), + self.arena, + ) + .into_bump_slice(), }; - let ps = Vec::from_iter_in( - arg_layouts.iter().map(|layout| { - let borrow = match layout { - Layout::Builtin(Builtin::List(_, _)) => true, - _ => false, - }; - - Param { - symbol, - borrow, - layout: layout.clone(), - } - }), - self.arena, - ) - .into_bump_slice(); - let b = self.add_dec_after_application(ys, ps, b, b_live_vars); - self.arena.alloc(Stmt::Let(z, v, l, b)) + let b = self.arena.alloc(Stmt::Let(z, v, l, b)); + + self.add_inc_before(ys, ps, b, b_live_vars) } EmptyArray | FunctionPointer(_, _) | Literal(_) | RuntimeErrorFunction(_) => { @@ -495,13 +477,15 @@ impl<'a> Context<'a> { fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { let mut ctx = self.clone(); - // TODO actually make these non-constant - // can this type be reference-counted at runtime? let reference = layout.contains_refcounted(); // is this value a constant? 
- let persistent = false; + // TODO do function pointers also fall into this category? + let persistent = match expr { + Expr::FunctionCall { args, .. } => args.is_empty(), + _ => false, + }; // must this value be consumed? let consume = consume_expr(&ctx.vars, expr); @@ -518,9 +502,6 @@ impl<'a> Context<'a> { } fn update_var_info_with_params(&self, ps: &[Param]) -> Self { - //def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context := - //let m := ps.foldl (fun (m : VarMap) p => m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }) ctx.varMap; - //{ ctx with varMap := m } let mut ctx = self.clone(); for p in ps.iter() { @@ -535,8 +516,13 @@ impl<'a> Context<'a> { ctx } - /* Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow. - That is, we must make sure these parameters are consumed. */ + // Add `dec` instructions for parameters that are + // + // - references + // - not alive in `b` + // - not borrow. + // + // That is, we must make sure these parameters are consumed. 
fn add_dec_for_dead_params( &self, ps: &[Param<'a>], @@ -619,25 +605,20 @@ impl<'a> Context<'a> { Join { id: j, - parameters: xs, + parameters: _, remainder: b, continuation: v, } => { - let xs = *xs; - - let v_orig = v; - - // NOTE deviation from lean, insert into local context - let mut ctx = self.clone(); - ctx.local_context.join_points.insert(*j, (xs, v_orig)); + // get the parameters with borrow signature + let xs = self.param_map.get_join_point(*j); let (v, v_live_vars) = { - let ctx = ctx.update_var_info_with_params(xs); + let ctx = self.update_var_info_with_params(xs); ctx.visit_stmt(v) }; + let mut ctx = self.clone(); let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars); - let mut ctx = ctx.clone(); update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars); @@ -673,7 +654,10 @@ impl<'a> Context<'a> { Some(vars) => vars, None => &empty, }; - let ps = self.local_context.join_points.get(j).unwrap().0; + // TODO use borrow signature here? + let ps = self.param_map.get_join_point(*j); + // let ps = self.local_context.join_points.get(j).unwrap().0; + let b = self.add_inc_before(xs, ps, stmt, j_live_vars); let b_live_vars = collect_stmt(b, &self.jp_live_vars, MutSet::default()); @@ -796,8 +780,15 @@ pub fn collect_stmt( collect_stmt(cont, jp_live_vars, vars) } - Jump(_, arguments) => { + Jump(id, arguments) => { vars.extend(arguments.iter().copied()); + + // NOTE deviation from Lean + // we fall through when no join point is available + if let Some(jvars) = jp_live_vars.get(id) { + vars.extend(jvars); + } + vars } @@ -866,8 +857,13 @@ fn update_jp_live_vars(j: JoinPointId, ys: &[Param], v: &Stmt<'_>, m: &mut JPLiv m.insert(j, j_live_vars); } -pub fn visit_declaration<'a>(arena: &'a Bump, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { - let ctx = Context::new(arena); +/// used to process the main function in the repl +pub fn visit_declaration<'a>( + arena: &'a Bump, + param_map: &'a ParamMap<'a>, + stmt: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + let ctx = Context::new(arena, 
param_map); let params = &[] as &[_]; let ctx = ctx.update_var_info_with_params(params); @@ -875,23 +871,21 @@ pub fn visit_declaration<'a>(arena: &'a Bump, stmt: &'a Stmt<'a>) -> &'a Stmt<'a ctx.add_dec_for_dead_params(params, b, &b_live_vars) } -pub fn visit_proc<'a>(arena: &'a Bump, proc: &mut Proc<'a>) { - let ctx = Context::new(arena); +pub fn visit_proc<'a>(arena: &'a Bump, param_map: &'a ParamMap<'a>, proc: &mut Proc<'a>) { + let ctx = Context::new(arena, param_map); - if proc.name.is_builtin() { - // we must take care of our own refcounting in builtins - return; - } - - let params = Vec::from_iter_in( - proc.args.iter().map(|(layout, symbol)| Param { - symbol: *symbol, - layout: layout.clone(), - borrow: layout.contains_refcounted(), - }), - arena, - ) - .into_bump_slice(); + let params = match param_map.get_symbol(proc.name) { + Some(slice) => slice, + None => Vec::from_iter_in( + proc.args.iter().cloned().map(|(layout, symbol)| Param { + symbol, + borrow: false, + layout, + }), + arena, + ) + .into_bump_slice(), + }; let stmt = arena.alloc(proc.body.clone()); let ctx = ctx.update_var_info_with_params(params); diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 5db5cd815f..d6dff5a32b 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -23,7 +23,7 @@ pub struct PartialProc<'a> { pub annotation: Variable, pub pattern_symbols: Vec<'a, Symbol>, pub body: roc_can::expr::Expr, - pub is_tail_recursive: bool, + pub is_self_recursive: bool, } #[derive(Clone, Debug, PartialEq)] @@ -40,7 +40,13 @@ pub struct Proc<'a> { pub body: Stmt<'a>, pub closes_over: Layout<'a>, pub ret_layout: Layout<'a>, - pub is_tail_recursive: bool, + pub is_self_recursive: SelfRecursive, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum SelfRecursive { + NotSelfRecursive, + SelfRecursive(JoinPointId), } impl<'a> Proc<'a> { @@ -111,15 +117,74 @@ impl<'a> Procs<'a> { for (key, in_prog_proc) in self.specialized.into_iter() { match in_prog_proc { InProgress 
=> unreachable!("The procedure {:?} should have be done by now", key), - Done(mut proc) => { - crate::inc_dec::visit_proc(arena, &mut proc); + Done(proc) => { result.insert(key, proc); } } } + + for (_, proc) in result.iter_mut() { + use self::SelfRecursive::*; + if let SelfRecursive(id) = proc.is_self_recursive { + proc.body = crate::tail_recursion::make_tail_recursive( + arena, + id, + proc.name, + proc.body.clone(), + proc.args, + ); + } + } + + let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result)); + + for (_, proc) in result.iter_mut() { + crate::inc_dec::visit_proc(arena, borrow_params, proc); + } + result } + pub fn get_specialized_procs_help( + self, + arena: &'a Bump, + ) -> ( + MutMap<(Symbol, Layout<'a>), Proc<'a>>, + &'a crate::borrow::ParamMap<'a>, + ) { + let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher()); + + for (key, in_prog_proc) in self.specialized.into_iter() { + match in_prog_proc { + InProgress => unreachable!("The procedure {:?} should have be done by now", key), + Done(proc) => { + result.insert(key, proc); + } + } + } + + for (_, proc) in result.iter_mut() { + use self::SelfRecursive::*; + if let SelfRecursive(id) = proc.is_self_recursive { + proc.body = crate::tail_recursion::make_tail_recursive( + arena, + id, + proc.name, + proc.body.clone(), + proc.args, + ); + } + } + + let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result)); + + for (_, proc) in result.iter_mut() { + crate::inc_dec::visit_proc(arena, borrow_params, proc); + } + + (result, borrow_params) + } + // TODO trim down these arguments! 
#[allow(clippy::too_many_arguments)] pub fn insert_named( @@ -130,7 +195,7 @@ impl<'a> Procs<'a> { annotation: Variable, loc_args: std::vec::Vec<(Variable, Located)>, loc_body: Located, - is_tail_recursive: bool, + is_self_recursive: bool, ret_var: Variable, ) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { @@ -145,7 +210,7 @@ impl<'a> Procs<'a> { annotation, pattern_symbols, body: body.value, - is_tail_recursive, + is_self_recursive, }, ); } @@ -179,7 +244,7 @@ impl<'a> Procs<'a> { layout_cache: &mut LayoutCache<'a>, ) -> Result, RuntimeError> { // anonymous functions cannot reference themselves, therefore cannot be tail-recursive - let is_tail_recursive = false; + let is_self_recursive = false; match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { Ok((pattern_vars, pattern_symbols, body)) => { @@ -219,7 +284,7 @@ impl<'a> Procs<'a> { annotation, pattern_symbols, body: body.value, - is_tail_recursive, + is_self_recursive, }, ); } @@ -229,7 +294,7 @@ impl<'a> Procs<'a> { annotation, pattern_symbols, body: body.value, - is_tail_recursive, + is_self_recursive, }; // Mark this proc as in-progress, so if we're dealing with @@ -459,6 +524,15 @@ pub enum CallType { ByPointer(Symbol), } +impl CallType { + pub fn get_inner(&self) -> Symbol { + match self { + CallType::ByName(s) => *s, + CallType::ByPointer(s) => *s, + } + } +} + #[derive(Clone, Debug, PartialEq)] pub enum Expr<'a> { Literal(Literal<'a>), @@ -1001,7 +1075,7 @@ fn specialize<'a>( annotation, pattern_symbols, body, - is_tail_recursive, + is_self_recursive, } = partial_proc; // unify the called function with the specialized signature, then specialize the function body @@ -1031,9 +1105,6 @@ fn specialize<'a>( let proc_args = proc_args.into_bump_slice(); - let specialized_body = - crate::tail_recursion::make_tail_recursive(env, proc_name, specialized_body, proc_args); - let ret_layout = layout_cache .from_var(&env.arena, ret_var, env.subs) .unwrap_or_else(|err| 
panic!("TODO handle invalid function {:?}", err)); @@ -1041,13 +1112,19 @@ fn specialize<'a>( // TODO WRONG let closes_over_layout = Layout::Struct(&[]); + let recursivity = if is_self_recursive { + SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol())) + } else { + SelfRecursive::NotSelfRecursive + }; + let proc = Proc { name: proc_name, args: proc_args, body: specialized_body, closes_over: closes_over_layout, ret_layout, - is_tail_recursive, + is_self_recursive: recursivity, }; Ok(proc) @@ -1109,8 +1186,8 @@ pub fn with_hole<'a>( let (loc_body, ret_var) = *boxed_body; - let is_tail_recursive = - matches!(recursivity, roc_can::expr::Recursive::TailRecursive); + let is_self_recursive = + !matches!(recursivity, roc_can::expr::Recursive::NotRecursive); procs.insert_named( env, @@ -1119,7 +1196,7 @@ pub fn with_hole<'a>( ann, loc_args, loc_body, - is_tail_recursive, + is_self_recursive, ret_var, ); @@ -1193,8 +1270,8 @@ pub fn with_hole<'a>( let (loc_body, ret_var) = *boxed_body; - let is_tail_recursive = - matches!(recursivity, roc_can::expr::Recursive::TailRecursive); + let is_self_recursive = + !matches!(recursivity, roc_can::expr::Recursive::NotRecursive); procs.insert_named( env, @@ -1203,7 +1280,7 @@ pub fn with_hole<'a>( ann, loc_args, loc_body, - is_tail_recursive, + is_self_recursive, ret_var, ); @@ -2060,8 +2137,8 @@ pub fn from_can<'a>( let (loc_body, ret_var) = *boxed_body; - let is_tail_recursive = - matches!(recursivity, roc_can::expr::Recursive::TailRecursive); + let is_self_recursive = + !matches!(recursivity, roc_can::expr::Recursive::NotRecursive); procs.insert_named( env, @@ -2070,7 +2147,7 @@ pub fn from_can<'a>( ann, loc_args, loc_body, - is_tail_recursive, + is_self_recursive, ret_var, ); @@ -2096,8 +2173,8 @@ pub fn from_can<'a>( let (loc_body, ret_var) = *boxed_body; - let is_tail_recursive = - matches!(recursivity, roc_can::expr::Recursive::TailRecursive); + let is_self_recursive = + !matches!(recursivity, 
roc_can::expr::Recursive::NotRecursive); procs.insert_named( env, @@ -2106,7 +2183,7 @@ pub fn from_can<'a>( ann, loc_args, loc_body, - is_tail_recursive, + is_self_recursive, ret_var, ); diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 64f21c3616..74d844be2d 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -22,6 +22,7 @@ pub enum Layout<'a> { Builtin(Builtin<'a>), Struct(&'a [Layout<'a>]), Union(&'a [&'a [Layout<'a>]]), + RecursiveUnion(&'a [&'a [Layout<'a>]]), /// A function. The types of its arguments, then the type of its return value. FunctionPointer(&'a [Layout<'a>], &'a Layout<'a>), Pointer(&'a Layout<'a>), @@ -96,6 +97,10 @@ impl<'a> Layout<'a> { Union(tags) => tags .iter() .all(|tag_layout| tag_layout.iter().all(|field| field.safe_to_memcpy())), + RecursiveUnion(_) => { + // a recursive union will always contain a pointer, and are thus not safe to memcpy + false + } FunctionPointer(_, _) => { // Function pointers are immutable and can always be safely copied true @@ -138,6 +143,16 @@ impl<'a> Layout<'a> { }) .max() .unwrap_or_default(), + RecursiveUnion(fields) => fields + .iter() + .map(|tag_layout| { + tag_layout + .iter() + .map(|field| field.stack_size(pointer_size)) + .sum() + }) + .max() + .unwrap_or_default(), FunctionPointer(_, _) => pointer_size, Pointer(_) => pointer_size, } @@ -146,6 +161,7 @@ impl<'a> Layout<'a> { pub fn is_refcounted(&self) -> bool { match self { Layout::Builtin(Builtin::List(_, _)) => true, + Layout::RecursiveUnion(_) => true, _ => false, } } @@ -164,6 +180,7 @@ impl<'a> Layout<'a> { .map(|ls| ls.iter()) .flatten() .any(|f| f.is_refcounted()), + RecursiveUnion(_) => true, FunctionPointer(_, _) | Pointer(_) => false, } } @@ -406,8 +423,41 @@ fn layout_from_flat_type<'a>( Ok(layout_from_tag_union(arena, tags, subs)) } - RecursiveTagUnion(_rec_var, _tags, _ext_var) => { - panic!("TODO make Layout for empty RecursiveTagUnion"); + RecursiveTagUnion(_rec_var, _tags, ext_var) 
=> { + debug_assert!(ext_var_is_empty_tag_union(subs, ext_var)); + + // some observations + // + // * recursive tag unions are always recursive + // * therefore at least one tag has a pointer (non-zero sized) field + // * they must (to be instantiated) have 2 or more tags + // + // That means none of the optimizations for enums or single tag tag unions apply + + // let rec_var = subs.get_root_key_without_compacting(rec_var); + // let mut tag_layouts = Vec::with_capacity_in(tags.len(), arena); + // + // // tags: MutMap>, + // for (_name, variables) in tags { + // let mut tag_layout = Vec::with_capacity_in(variables.len(), arena); + // + // for var in variables { + // // TODO does this still cause problems with mutually recursive unions? + // if rec_var == subs.get_root_key_without_compacting(var) { + // // TODO make this a pointer? + // continue; + // } + // + // let var_content = subs.get_without_compacting(var).content; + // + // tag_layout.push(Layout::new(arena, var_content, subs)?); + // } + // + // tag_layouts.push(tag_layout.into_bump_slice()); + // } + // + // Ok(Layout::RecursiveUnion(tag_layouts.into_bump_slice())) + Ok(Layout::RecursiveUnion(&[])) } EmptyTagUnion => { panic!("TODO make Layout for empty Tag Union"); diff --git a/compiler/mono/src/lib.rs b/compiler/mono/src/lib.rs index 4feb7c8696..e44b3d7fae 100644 --- a/compiler/mono/src/lib.rs +++ b/compiler/mono/src/lib.rs @@ -11,6 +11,7 @@ // re-enable this when working on performance optimizations than have it block PRs. 
#![allow(clippy::large_enum_variant)] +pub mod borrow; pub mod inc_dec; pub mod ir; pub mod layout; diff --git a/compiler/mono/src/tail_recursion.rs b/compiler/mono/src/tail_recursion.rs index 7abdead12d..7458624172 100644 --- a/compiler/mono/src/tail_recursion.rs +++ b/compiler/mono/src/tail_recursion.rs @@ -1,19 +1,40 @@ -use crate::ir::{CallType, Env, Expr, JoinPointId, Param, Stmt}; +use crate::ir::{CallType, Expr, JoinPointId, Param, Stmt}; use crate::layout::Layout; use bumpalo::collections::Vec; use bumpalo::Bump; use roc_module::symbol::Symbol; +/// Make tail calls into loops (using join points) +/// +/// e.g. +/// +/// > factorial n accum = if n == 1 then accum else factorial (n - 1) (n * accum) +/// +/// becomes +/// +/// ```elm +/// factorial n1 accum1 = +/// let joinpoint j n accum = +/// if n == 1 then +/// accum +/// else +/// jump j (n - 1) (n * accum) +/// +/// in +/// jump j n1 accum1 +/// ``` +/// +/// This will effectively compile into a loop in llvm, and +/// won't grow the call stack for each iteration pub fn make_tail_recursive<'a>( - env: &mut Env<'a, '_>, + arena: &'a Bump, + id: JoinPointId, needle: Symbol, stmt: Stmt<'a>, args: &'a [(Layout<'a>, Symbol)], ) -> Stmt<'a> { - let id = JoinPointId(env.unique_symbol()); - - let alloced = env.arena.alloc(stmt); - match insert_jumps(env.arena, alloced, id, needle) { + let alloced = arena.alloc(stmt); + match insert_jumps(arena, alloced, id, needle) { None => alloced.clone(), Some(new) => { // jumps were inserted, we must now add a join point @@ -24,13 +45,14 @@ pub fn make_tail_recursive<'a>( layout: layout.clone(), borrow: true, }), - env.arena, + arena, ) .into_bump_slice(); - let args = Vec::from_iter_in(args.iter().map(|t| t.1), env.arena).into_bump_slice(); + // TODO could this be &[]? 
+ let args = Vec::from_iter_in(args.iter().map(|t| t.1), arena).into_bump_slice(); - let jump = env.arena.alloc(Stmt::Jump(id, args)); + let jump = arena.alloc(Stmt::Jump(id, args)); Stmt::Join { id, @@ -185,7 +207,6 @@ fn insert_jumps<'a>( None } } - Ret(_) => None, Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) { Some(cont) => Some(arena.alloc(Inc(*symbol, cont))), None => None, @@ -195,6 +216,7 @@ fn insert_jumps<'a>( None => None, }, + Ret(_) => None, Jump(_, _) => None, RuntimeError(_) => None, } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 08e35aa084..c3b8bb3311 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -66,17 +66,18 @@ mod test_mono { // let mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs); let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut LayoutCache::default()); - // apply inc/dec - let stmt = mono_env.arena.alloc(ir_expr); - let ir_expr = roc_mono::inc_dec::visit_declaration(mono_env.arena, stmt); - assert_eq!( procs.runtime_errors, roc_collections::all::MutMap::default() ); + let (procs, param_map) = procs.get_specialized_procs_help(mono_env.arena); + + // apply inc/dec + let stmt = mono_env.arena.alloc(ir_expr); + let ir_expr = roc_mono::inc_dec::visit_declaration(mono_env.arena, param_map, stmt); + let mut procs_string = procs - .get_specialized_procs(mono_env.arena) .values() .map(|proc| proc.to_pretty(200)) .collect::>(); @@ -94,6 +95,7 @@ mod test_mono { let result_lines = result.split("\n").collect::>(); assert_eq!(expected_lines, result_lines); + //assert_eq!(0, 1); } } @@ -380,27 +382,35 @@ mod test_mono { fn guard_pattern_true() { compiles_to_ir( r#" - when 2 is - 2 if False -> 42 - _ -> 0 + main = \{} -> + when 2 is + 2 if False -> 42 + _ -> 0 + + main {} "#, indoc!( r#" - let Test.0 = 2i64; - let Test.6 = true; - let Test.7 = 2i64; - let Test.10 = lowlevel Eq Test.7 Test.0; - let Test.8 = lowlevel And 
Test.10 Test.6; - let Test.3 = false; - jump Test.2 Test.3; - joinpoint Test.2 Test.9: - let Test.5 = lowlevel And Test.9 Test.8; - if Test.5 then - let Test.1 = 42i64; - ret Test.1; - else - let Test.4 = 0i64; - ret Test.4; + procedure Test.0 (Test.2): + let Test.5 = 2i64; + let Test.11 = true; + let Test.12 = 2i64; + let Test.15 = lowlevel Eq Test.12 Test.5; + let Test.13 = lowlevel And Test.15 Test.11; + let Test.8 = false; + jump Test.7 Test.8; + joinpoint Test.7 Test.14: + let Test.10 = lowlevel And Test.14 Test.13; + if Test.10 then + let Test.6 = 42i64; + ret Test.6; + else + let Test.9 = 0i64; + ret Test.9; + + let Test.4 = Struct {}; + let Test.3 = CallByName Test.0 Test.4; + ret Test.3; "# ), ) @@ -539,7 +549,6 @@ mod test_mono { let Test.6 = 2i64; let Test.4 = Array [Test.5, Test.6]; let Test.3 = CallByName Test.0 Test.4; - dec Test.4; ret Test.3; "# ), @@ -548,6 +557,8 @@ mod test_mono { #[test] fn list_append() { + // TODO this leaks at the moment + // ListAppend needs to decrement its arguments compiles_to_ir( r#" List.append [1] 2 @@ -562,7 +573,6 @@ mod test_mono { let Test.1 = Array [Test.3]; let Test.2 = 2i64; let Test.0 = CallByName List.5 Test.1 Test.2; - dec Test.1; ret Test.0; "# ), @@ -581,16 +591,16 @@ mod test_mono { indoc!( r#" procedure Num.14 (#Attr.2, #Attr.3): - let Test.13 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.13; + let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.11; procedure List.7 (#Attr.2): let Test.9 = lowlevel ListLen #Attr.2; ret Test.9; procedure List.7 (#Attr.2): - let Test.11 = lowlevel ListLen #Attr.2; - ret Test.11; + let Test.10 = lowlevel ListLen #Attr.2; + ret Test.10; let Test.8 = 1f64; let Test.1 = Array [Test.8]; @@ -613,35 +623,43 @@ mod test_mono { fn when_joinpoint() { compiles_to_ir( r#" - x : [ Red, White, Blue ] - x = Blue + main = \{} -> + x : [ Red, White, Blue ] + x = Blue - y = - when x is - Red -> 1 - White -> 2 - Blue -> 3 + y = + when x is + Red -> 1 + White -> 2 + Blue -> 3 - y + y + + 
main {} "#, indoc!( r#" - let Test.0 = 0u8; - switch Test.0: - case 1: - let Test.4 = 1i64; - jump Test.3 Test.4; + procedure Test.0 (Test.4): + let Test.2 = 0u8; + switch Test.2: + case 1: + let Test.9 = 1i64; + jump Test.8 Test.9; + + case 2: + let Test.10 = 2i64; + jump Test.8 Test.10; + + default: + let Test.11 = 3i64; + jump Test.8 Test.11; + + joinpoint Test.8 Test.3: + ret Test.3; - case 2: - let Test.5 = 2i64; - jump Test.3 Test.5; - - default: - let Test.6 = 3i64; - jump Test.3 Test.6; - - joinpoint Test.3 Test.1: - ret Test.1; + let Test.6 = Struct {}; + let Test.5 = CallByName Test.0 Test.6; + ret Test.5; "# ), ) @@ -704,43 +722,51 @@ mod test_mono { fn when_on_result() { compiles_to_ir( r#" - x : Result Int Int - x = Ok 2 + main = \{} -> + x : Result Int Int + x = Ok 2 - y = - when x is - Ok 3 -> 1 - Ok _ -> 2 - Err _ -> 3 - y + y = + when x is + Ok 3 -> 1 + Ok _ -> 2 + Err _ -> 3 + y + + main {} "#, indoc!( r#" - let Test.17 = 1i64; - let Test.18 = 2i64; - let Test.0 = Ok Test.17 Test.18; - let Test.13 = true; - let Test.15 = Index 0 Test.0; - let Test.14 = 1i64; - let Test.16 = lowlevel Eq Test.14 Test.15; - let Test.12 = lowlevel And Test.16 Test.13; - if Test.12 then - let Test.8 = true; - let Test.9 = 3i64; - let Test.10 = Index 0 Test.0; - let Test.11 = lowlevel Eq Test.9 Test.10; - let Test.7 = lowlevel And Test.11 Test.8; - if Test.7 then - let Test.4 = 1i64; - jump Test.3 Test.4; + procedure Test.0 (Test.4): + let Test.22 = 1i64; + let Test.23 = 2i64; + let Test.2 = Ok Test.22 Test.23; + let Test.18 = true; + let Test.20 = Index 0 Test.2; + let Test.19 = 1i64; + let Test.21 = lowlevel Eq Test.19 Test.20; + let Test.17 = lowlevel And Test.21 Test.18; + if Test.17 then + let Test.13 = true; + let Test.14 = 3i64; + let Test.15 = Index 0 Test.2; + let Test.16 = lowlevel Eq Test.14 Test.15; + let Test.12 = lowlevel And Test.16 Test.13; + if Test.12 then + let Test.9 = 1i64; + jump Test.8 Test.9; + else + let Test.10 = 2i64; + jump Test.8 Test.10; 
else - let Test.5 = 2i64; - jump Test.3 Test.5; - else - let Test.6 = 3i64; - jump Test.3 Test.6; - joinpoint Test.3 Test.1: - ret Test.1; + let Test.11 = 3i64; + jump Test.8 Test.11; + joinpoint Test.8 Test.3: + ret Test.3; + + let Test.6 = Struct {}; + let Test.5 = CallByName Test.0 Test.6; + ret Test.5; "# ), ) @@ -796,30 +822,38 @@ mod test_mono { compiles_to_ir( indoc!( r#" - when 10 is - x if x == 5 -> 0 - _ -> 42 + main = \{} -> + when 10 is + x if x == 5 -> 0 + _ -> 42 + + main {} "# ), indoc!( r#" - procedure Bool.5 (#Attr.2, #Attr.3): - let Test.10 = lowlevel Eq #Attr.2 #Attr.3; - ret Test.10; + procedure Test.0 (Test.3): + let Test.6 = 10i64; + let Test.14 = true; + let Test.10 = 5i64; + let Test.9 = CallByName Bool.5 Test.6 Test.10; + jump Test.8 Test.9; + joinpoint Test.8 Test.15: + let Test.13 = lowlevel And Test.15 Test.14; + if Test.13 then + let Test.7 = 0i64; + ret Test.7; + else + let Test.12 = 42i64; + ret Test.12; - let Test.1 = 10i64; - let Test.8 = true; - let Test.5 = 5i64; - let Test.4 = CallByName Bool.5 Test.1 Test.5; - jump Test.3 Test.4; - joinpoint Test.3 Test.9: - let Test.7 = lowlevel And Test.9 Test.8; - if Test.7 then - let Test.2 = 0i64; - ret Test.2; - else - let Test.6 = 42i64; - ret Test.6; + procedure Bool.5 (#Attr.2, #Attr.3): + let Test.11 = lowlevel Eq #Attr.2 #Attr.3; + ret Test.11; + + let Test.5 = Struct {}; + let Test.4 = CallByName Test.0 Test.5; + ret Test.4; "# ), ) @@ -905,12 +939,6 @@ mod test_mono { ), indoc!( r#" - procedure Test.1 (Test.3): - let Test.9 = 0i64; - let Test.10 = 0i64; - let Test.8 = CallByName List.4 Test.3 Test.9 Test.10; - ret Test.8; - procedure List.4 (#Attr.2, #Attr.3, #Attr.4): let Test.14 = lowlevel ListLen #Attr.2; let Test.12 = lowlevel NumLt #Attr.3 Test.14; @@ -920,12 +948,17 @@ mod test_mono { else ret #Attr.2; + procedure Test.1 (Test.3): + let Test.9 = 0i64; + let Test.10 = 0i64; + let Test.8 = CallByName List.4 Test.3 Test.9 Test.10; + ret Test.8; + let Test.5 = 1i64; let Test.6 = 
2i64; let Test.7 = 3i64; let Test.0 = Array [Test.5, Test.6, Test.7]; let Test.4 = CallByName Test.1 Test.0; - dec Test.0; ret Test.4; "# ), @@ -1066,7 +1099,8 @@ mod test_mono { ) } - #[allow(dead_code)] + #[ignore] + #[test] fn quicksort_help() { crate::helpers::with_larger_debug_stack(|| { compiles_to_ir( @@ -1094,7 +1128,8 @@ mod test_mono { }) } - #[allow(dead_code)] + #[ignore] + #[test] fn quicksort_partition_help() { crate::helpers::with_larger_debug_stack(|| { compiles_to_ir( @@ -1128,7 +1163,8 @@ mod test_mono { }) } - #[allow(dead_code)] + #[ignore] + #[test] fn quicksort_full() { crate::helpers::with_larger_debug_stack(|| { compiles_to_ir( @@ -1217,29 +1253,29 @@ mod test_mono { "#, indoc!( r#" + procedure Num.15 (#Attr.2, #Attr.3): + let Test.13 = lowlevel NumSub #Attr.2 #Attr.3; + ret Test.13; + procedure Test.0 (Test.2, Test.3): - jump Test.20 Test.2 Test.3; - joinpoint Test.20 Test.2 Test.3: - let Test.17 = true; - let Test.18 = 0i64; - let Test.19 = lowlevel Eq Test.18 Test.2; - let Test.16 = lowlevel And Test.19 Test.17; - if Test.16 then + jump Test.18 Test.2 Test.3; + joinpoint Test.18 Test.2 Test.3: + let Test.15 = true; + let Test.16 = 0i64; + let Test.17 = lowlevel Eq Test.16 Test.2; + let Test.14 = lowlevel And Test.17 Test.15; + if Test.14 then ret Test.3; else - let Test.13 = 1i64; - let Test.9 = CallByName Num.15 Test.2 Test.13; + let Test.12 = 1i64; + let Test.9 = CallByName Num.15 Test.2 Test.12; let Test.10 = CallByName Num.16 Test.2 Test.3; - jump Test.20 Test.9 Test.10; + jump Test.18 Test.9 Test.10; procedure Num.16 (#Attr.2, #Attr.3): let Test.11 = lowlevel NumMul #Attr.2 #Attr.3; ret Test.11; - procedure Num.15 (#Attr.2, #Attr.3): - let Test.14 = lowlevel NumSub #Attr.2 #Attr.3; - ret Test.14; - let Test.5 = 10i64; let Test.6 = 1i64; let Test.4 = CallByName Test.0 Test.5 Test.6; @@ -1248,4 +1284,248 @@ mod test_mono { ), ) } + + #[test] + #[ignore] + fn is_nil() { + compiles_to_ir( + r#" + ConsList a : [ Cons a (ConsList a), Nil ] 
+ + isNil : ConsList a -> Bool + isNil = \list -> + when list is + Nil -> True + Cons _ _ -> False + + isNil (Cons 0x2 Nil) + "#, + indoc!( + r#" + procedure Test.1 (Test.3): + let Test.13 = true; + let Test.15 = Index 0 Test.3; + let Test.14 = 1i64; + let Test.16 = lowlevel Eq Test.14 Test.15; + let Test.12 = lowlevel And Test.16 Test.13; + if Test.12 then + let Test.10 = true; + ret Test.10; + else + let Test.11 = false; + ret Test.11; + + let Test.6 = 0i64; + let Test.7 = 2i64; + let Test.9 = 1i64; + let Test.8 = Nil Test.9; + let Test.5 = Cons Test.6 Test.7 Test.8; + let Test.4 = CallByName Test.1 Test.5; + ret Test.4; + "# + ), + ) + } + + #[test] + #[ignore] + fn has_none() { + compiles_to_ir( + r#" + Maybe a : [ Just a, Nothing ] + ConsList a : [ Cons a (ConsList a), Nil ] + + hasNone : ConsList (Maybe a) -> Bool + hasNone = \list -> + when list is + Nil -> False + Cons Nothing _ -> True + Cons (Just _) xs -> hasNone xs + + hasNone (Cons (Just 3) Nil) + "#, + indoc!( + r#" + procedure Test.1 (Test.3): + let Test.13 = true; + let Test.15 = Index 0 Test.3; + let Test.14 = 1i64; + let Test.16 = lowlevel Eq Test.14 Test.15; + let Test.12 = lowlevel And Test.16 Test.13; + if Test.12 then + let Test.10 = true; + ret Test.10; + else + let Test.11 = false; + ret Test.11; + + let Test.6 = 0i64; + let Test.7 = 2i64; + let Test.9 = 1i64; + let Test.8 = Nil Test.9; + let Test.5 = Cons Test.6 Test.7 Test.8; + let Test.4 = CallByName Test.1 Test.5; + ret Test.4; + "# + ), + ) + } + + #[test] + fn mk_pair_of() { + compiles_to_ir( + r#" + mkPairOf = \x -> Pair x x + + mkPairOf [1,2,3] + "#, + indoc!( + r#" + procedure Test.0 (Test.2): + inc Test.2; + let Test.8 = Struct {Test.2, Test.2}; + ret Test.8; + + let Test.5 = 1i64; + let Test.6 = 2i64; + let Test.7 = 3i64; + let Test.4 = Array [Test.5, Test.6, Test.7]; + let Test.3 = CallByName Test.0 Test.4; + ret Test.3; + "# + ), + ) + } + + #[test] + fn fst() { + compiles_to_ir( + r#" + fst = \x, y -> x + + fst [1,2,3] [3,2,1] 
+ "#, + indoc!( + r#" + procedure Test.0 (Test.2, Test.3): + inc Test.2; + ret Test.2; + + let Test.10 = 1i64; + let Test.11 = 2i64; + let Test.12 = 3i64; + let Test.5 = Array [Test.10, Test.11, Test.12]; + let Test.7 = 3i64; + let Test.8 = 2i64; + let Test.9 = 1i64; + let Test.6 = Array [Test.7, Test.8, Test.9]; + let Test.4 = CallByName Test.0 Test.5 Test.6; + dec Test.6; + dec Test.5; + ret Test.4; + "# + ), + ) + } + + #[test] + fn list_cannot_update_inplace() { + compiles_to_ir( + indoc!( + r#" + x : List Int + x = [1,2,3] + + add : List Int -> List Int + add = \y -> List.set y 0 0 + + List.len (add x) + List.len x + "# + ), + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.19 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.19; + + procedure Test.1 (Test.3): + let Test.13 = 0i64; + let Test.14 = 0i64; + let Test.12 = CallByName List.4 Test.3 Test.13 Test.14; + ret Test.12; + + procedure List.4 (#Attr.2, #Attr.3, #Attr.4): + let Test.18 = lowlevel ListLen #Attr.2; + let Test.16 = lowlevel NumLt #Attr.3 Test.18; + if Test.16 then + let Test.17 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; + ret Test.17; + else + ret #Attr.2; + + procedure List.7 (#Attr.2): + let Test.11 = lowlevel ListLen #Attr.2; + ret Test.11; + + let Test.8 = 1i64; + let Test.9 = 2i64; + let Test.10 = 3i64; + let Test.0 = Array [Test.8, Test.9, Test.10]; + inc Test.0; + let Test.7 = CallByName Test.1 Test.0; + let Test.5 = CallByName List.7 Test.7; + dec Test.7; + let Test.6 = CallByName List.7 Test.0; + dec Test.0; + let Test.4 = CallByName Num.14 Test.5 Test.6; + ret Test.4; + "# + ), + ) + } + + #[test] + fn list_get() { + compiles_to_ir( + indoc!( + r#" + main = \{} -> + List.get [1,2,3] 0 + + main {} + "# + ), + indoc!( + r#" + procedure Test.0 (Test.2): + let Test.16 = 1i64; + let Test.17 = 2i64; + let Test.18 = 3i64; + let Test.6 = Array [Test.16, Test.17, Test.18]; + let Test.7 = 0i64; + let Test.5 = CallByName List.3 Test.6 Test.7; + dec Test.6; + ret Test.5; + + 
procedure List.3 (#Attr.2, #Attr.3): + let Test.15 = lowlevel ListLen #Attr.2; + let Test.11 = lowlevel NumLt #Attr.3 Test.15; + if Test.11 then + let Test.13 = 1i64; + let Test.14 = lowlevel ListGetUnsafe #Attr.2 #Attr.3; + let Test.12 = Ok Test.13 Test.14; + ret Test.12; + else + let Test.9 = 0i64; + let Test.10 = Struct {}; + let Test.8 = Err Test.9 Test.10; + ret Test.8; + + let Test.4 = Struct {}; + let Test.3 = CallByName Test.0 Test.4; + ret Test.3; + "# + ), + ) + } } diff --git a/examples/shared-quicksort/Quicksort.roc b/examples/shared-quicksort/Quicksort.roc new file mode 100644 index 0000000000..28859806af --- /dev/null +++ b/examples/shared-quicksort/Quicksort.roc @@ -0,0 +1,69 @@ +app Quicksort provides [ quicksort ] imports [] + +quicksort : List Int -> List Int +quicksort = \originalList -> helper originalList + +helper : List Int -> List Int +helper = \originalList -> + + quicksortHelp : List (Num a), Int, Int -> List (Num a) + quicksortHelp = \list, low, high -> + if low < high then + when partition low high list is + Pair partitionIndex partitioned -> + partitioned + |> quicksortHelp low (partitionIndex - 1) + |> quicksortHelp (partitionIndex + 1) high + else + list + + + swap : Int, Int, List a -> List a + swap = \i, j, list -> + when Pair (List.get list i) (List.get list j) is + Pair (Ok atI) (Ok atJ) -> + list + |> List.set i atJ + |> List.set j atI + + _ -> + [] + + partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + partition = \low, high, initialList -> + when List.get initialList high is + Ok pivot -> + when partitionHelp (low - 1) low initialList high pivot is + Pair newI newList -> + Pair (newI + 1) (swap (newI + 1) high newList) + + Err _ -> + Pair (low - 1) initialList + + + partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] + partitionHelp = \i, j, list, high, pivot -> + if j < high then + when List.get list j is + Ok value -> + if value <= pivot then + partitionHelp (i + 1) (j + 1) 
(swap (i + 1) j list) high pivot + else + partitionHelp i (j + 1) list high pivot + + Err _ -> + Pair i list + else + Pair i list + + + + result = quicksortHelp originalList 0 (List.len originalList - 1) + + if List.len originalList > 3 then + result + else + # Absolutely make the `originalList` Shared by using it again here + # but this branch is not evaluated, so should not affect performance + List.set originalList 0 (List.len originalList) + diff --git a/examples/shared-quicksort/host.rs b/examples/shared-quicksort/host.rs new file mode 100644 index 0000000000..08eaf722e4 --- /dev/null +++ b/examples/shared-quicksort/host.rs @@ -0,0 +1,47 @@ +use std::time::SystemTime; + +#[link(name = "roc_app", kind = "static")] +extern "C" { + #[allow(improper_ctypes)] + #[link_name = "quicksort#1"] + fn quicksort(list: &[i64]) -> Box<[i64]>; +} + +const NUM_NUMS: usize = 1_000_000; + +pub fn main() { + let nums = { + let mut nums = Vec::with_capacity(NUM_NUMS + 1); + + // give this list refcount 1 + nums.push((std::usize::MAX - 1) as i64); + + for index in 1..nums.capacity() { + let num = index as i64 % 12345; + + nums.push(num); + } + + nums + }; + + println!("Running Roc shared quicksort"); + let start_time = SystemTime::now(); + let answer = unsafe { quicksort(&nums[1..]) }; + let end_time = SystemTime::now(); + let duration = end_time.duration_since(start_time).unwrap(); + + println!( + "Roc quicksort took {:.4} ms to compute this answer: {:?}", + duration.as_secs_f64() * 1000.0, + // truncate the answer, so stdout is not swamped + // NOTE index 0 is the refcount! + &answer[1..20] + ); + + // the pointer is to the first _element_ of the list, + // but the refcount precedes it. Thus calling free() on + // this pointer would segfault/cause badness. Therefore, we + // leak it for now + Box::leak(answer); +}