diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 5f01331370..5a247827b9 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -173,12 +173,19 @@ pub fn construct_optimization_passes<'a>( } OptLevel::Optimize => { // this threshold seems to do what we want - pmb.set_inliner_with_threshold(2); + pmb.set_inliner_with_threshold(0); // TODO figure out which of these actually help // function passes + fpm.add_cfg_simplification_pass(); + mpm.add_cfg_simplification_pass(); + + fpm.add_jump_threading_pass(); + mpm.add_jump_threading_pass(); + + //fpm.add_ind_var_simplify_pass(); fpm.add_memcpy_optimize_pass(); // this one is very important // In my testing, these don't do much for quicksort @@ -631,7 +638,17 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( result } - Ret(symbol) => load_symbol(env, scope, symbol), + Ret(symbol) => { + let value = load_symbol(env, scope, symbol); + + if let Some(block) = env.builder.get_insert_block() { + if block.get_terminator().is_none() { + env.builder.build_return(Some(&value)); + } + } + + value + } Cond { branching_symbol, @@ -659,7 +676,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( &dyn inkwell::values::BasicValue<'_>, inkwell::basic_block::BasicBlock<'_>, )> = std::vec::Vec::with_capacity(2); - let cont_block = context.append_basic_block(parent, "branchcont"); + let cont_block = context.append_basic_block(parent, "condbranchcont"); builder.build_conditional_branch(value, then_block, else_block); @@ -1025,7 +1042,7 @@ fn decrement_refcount_list<'a, 'ctx, 'env>( // build blocks let then_block = ctx.append_basic_block(parent, "then"); let else_block = ctx.append_basic_block(parent, "else"); - let cont_block = ctx.append_basic_block(parent, "branchcont"); + let cont_block = ctx.append_basic_block(parent, "dec_ref_branchcont"); builder.build_conditional_branch(comparison, then_block, else_block); @@ -1283,7 +1300,7 @@ where // build blocks let then_block = context.append_basic_block(parent, "then"); let else_block = context.append_basic_block(parent, "else"); - let cont_block = context.append_basic_block(parent, "branchcont"); + let cont_block = context.append_basic_block(parent, "phi2_branchcont"); builder.build_conditional_branch(comparison, then_block, else_block); @@ -1410,7 +1427,12 @@ pub fn build_proc<'a, 'ctx, 'env>( let body = build_exp_stmt(env, layout_ids, &mut scope, fn_val, &proc.body); - builder.build_return(Some(&body)); + // only add a return if codegen did not already add one + if let Some(block) = builder.get_insert_block() { + if block.get_terminator().is_none() { + builder.build_return(Some(&body)); + } + } } pub fn verify_fn(fn_val: FunctionValue<'_>) { diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 839136a089..da0dffcfaf 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -170,7 +170,8 @@ pub fn helper_without_uniqueness<'a>( builder.position_at_end(basic_block); - let ret = roc_gen::llvm::build::build_exp_stmt( + // builds the function body (return statement included) + roc_gen::llvm::build::build_exp_stmt( &env, &mut layout_ids, &mut Scope::default(), @@ -178,8 +179,6 @@ pub fn helper_without_uniqueness<'a>( &main_body, ); - builder.build_return(Some(&ret)); - // Uncomment this to see the module's un-optimized LLVM instruction output: // env.module.print_to_stderr(); @@ -361,7 +360,8 @@ pub fn helper_with_uniqueness<'a>( builder.position_at_end(basic_block); - let ret = roc_gen::llvm::build::build_exp_stmt( + // builds the function body (return statement included) + roc_gen::llvm::build::build_exp_stmt( &env, &mut layout_ids, &mut Scope::default(), @@ -369,8 +369,6 @@ pub fn helper_with_uniqueness<'a>( &main_body, ); - builder.build_return(Some(&ret)); - // you're in the version with uniqueness! // Uncomment this to see the module's un-optimized LLVM instruction output: diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index c3a52c3349..9d321e73da 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -999,9 +999,10 @@ fn specialize<'a>( debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_))); - let ret_symbol = env.unique_symbol(); - let hole = env.arena.alloc(Stmt::Ret(ret_symbol)); - let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole); + //let ret_symbol = env.unique_symbol(); + //let hole = env.arena.alloc(Stmt::Ret(ret_symbol)); + //let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole); + let specialized_body = from_can(env, body, procs, layout_cache); // reset subs, so we don't get type errors when specializing for a different signature env.subs.rollback_to(snapshot); @@ -1449,68 +1450,119 @@ pub fn with_hole<'a>( .from_var(env.arena, cond_var, env.subs) .expect("invalid cond_layout"); - let assigned_in_jump = env.unique_symbol(); - let id = JoinPointId(env.unique_symbol()); - let jump = env - .arena - .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); + // if the hole is a return, then we don't need to merge the two + // branches together again, we can just immediately return + let is_terminated = matches!(hole, Stmt::Ret(_)); - let mut stmt = with_hole( - env, - final_else.value, - procs, - layout_cache, - assigned_in_jump, - jump, - ); + if is_terminated { + let terminator = hole; - for (loc_cond, loc_then) in branches.into_iter().rev() { - let branching_symbol = env.unique_symbol(); - let then = with_hole( + let mut stmt = with_hole( env, - loc_then.value, + final_else.value, + procs, + layout_cache, + assigned, + terminator, + ); + + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); + let then = with_hole( + env, + loc_then.value, + procs, + layout_cache, + assigned, + terminator, + ); + + stmt = Stmt::Cond { + cond_symbol: branching_symbol, + branching_symbol, + cond_layout: cond_layout.clone(), + branching_layout: cond_layout.clone(), + pass: env.arena.alloc(then), + fail: env.arena.alloc(stmt), + ret_layout: ret_layout.clone(), + }; + + // add condition + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + stmt + } else { + let assigned_in_jump = env.unique_symbol(); + let id = JoinPointId(env.unique_symbol()); + + let terminator = env + .arena + .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); + + let mut stmt = with_hole( + env, + final_else.value, procs, layout_cache, assigned_in_jump, - jump, + terminator, ); - stmt = Stmt::Cond { - cond_symbol: branching_symbol, - branching_symbol, - cond_layout: cond_layout.clone(), - branching_layout: cond_layout.clone(), - pass: env.arena.alloc(then), - fail: env.arena.alloc(stmt), - ret_layout: ret_layout.clone(), + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); + let then = with_hole( + env, + loc_then.value, + procs, + layout_cache, + assigned_in_jump, + terminator, + ); + + stmt = Stmt::Cond { + cond_symbol: branching_symbol, + branching_symbol, + cond_layout: cond_layout.clone(), + branching_layout: cond_layout.clone(), + pass: env.arena.alloc(then), + fail: env.arena.alloc(stmt), + ret_layout: ret_layout.clone(), + }; + + // add condition + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + + let layout = layout_cache + .from_var(env.arena, branch_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let param = Param { + symbol: assigned, + layout, + borrow: false, }; - // add condition - stmt = with_hole( - env, - loc_cond.value, - procs, - layout_cache, - branching_symbol, - env.arena.alloc(stmt), - ); - } - - let layout = layout_cache - .from_var(env.arena, branch_var, env.subs) - .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); - - let param = Param { - symbol: assigned, - layout, - borrow: false, - }; - - Stmt::Join { - id, - parameters: env.arena.alloc([param]), - remainder: env.arena.alloc(stmt), - continuation: hole, + Stmt::Join { + id, + parameters: env.arena.alloc([param]), + remainder: env.arena.alloc(stmt), + continuation: hole, + } } } @@ -1883,6 +1935,89 @@ pub fn from_can<'a>( use roc_can::expr::Expr::*; match can_expr { + When { + cond_var, + expr_var, + region, + loc_cond, + branches, + } => { + let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value { + symbol + } else { + env.unique_symbol() + }; + + let mut stmt = from_can_when( + env, + cond_var, + expr_var, + region, + cond_symbol, + branches, + layout_cache, + procs, + None, + ); + + // define the `when` condition + if let roc_can::expr::Expr::Var(_) = loc_cond.value { + // do nothing + } else { + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + cond_symbol, + env.arena.alloc(stmt), + ); + }; + + stmt + } + If { + cond_var, + branch_var, + branches, + final_else, + } => { + let ret_layout = layout_cache + .from_var(env.arena, branch_var, env.subs) + .expect("invalid ret_layout"); + let cond_layout = layout_cache + .from_var(env.arena, cond_var, env.subs) + .expect("invalid cond_layout"); + + let mut stmt = from_can(env, final_else.value, procs, layout_cache); + + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); + let then = from_can(env, loc_then.value, procs, layout_cache); + + stmt = Stmt::Cond { + cond_symbol: branching_symbol, + branching_symbol, + cond_layout: cond_layout.clone(), + branching_layout: cond_layout.clone(), + pass: env.arena.alloc(then), + fail: env.arena.alloc(stmt), + ret_layout: ret_layout.clone(), + }; + + // add condition + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + + stmt + } LetRec(defs, cont, _, _) => { // because Roc is strict, only functions can be recursive! for def in defs.into_iter() { diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 273740e130..8688bdc975 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -712,7 +712,7 @@ mod test_mono { x : [ Red, White, Blue ] x = Blue - y = + y = when x is Red -> 1 White -> 2 @@ -749,7 +749,7 @@ mod test_mono { r#" if True then 1 - else + else 2 "#, indoc!( @@ -776,7 +776,7 @@ mod test_mono { 1 else if False then 2 - else + else 3 "#, indoc!( @@ -809,7 +809,7 @@ mod test_mono { x : Result Int Int x = Ok 2 - y = + y = when x is Ok 3 -> 1 Ok _ -> 2 @@ -1039,4 +1039,139 @@ mod test_mono { ), ) } + + #[allow(dead_code)] + fn quicksort_help() { + with_larger_debug_stack(|| { + compiles_to_ir( + indoc!( + r#" + quicksortHelp : List (Num a), Int, Int -> List (Num a) + quicksortHelp = \list, low, high -> + if low < high then + (Pair partitionIndex partitioned) = Pair 0 [] + + partitioned + |> quicksortHelp low (partitionIndex - 1) + |> quicksortHelp (partitionIndex + 1) high + else + list + + quicksortHelp [] 0 0 + "# + ), + indoc!( + r#" + "# + ), + ) + }) + } + + #[allow(dead_code)] + fn quicksort_partition_help() { + with_larger_debug_stack(|| { + compiles_to_ir( + indoc!( + r#" + partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] + partitionHelp = \i, j, list, high, pivot -> + if j < high then + when List.get list j is + Ok value -> + if value <= pivot then + partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + else + partitionHelp i (j + 1) list high pivot + + Err _ -> + Pair i list + else + Pair i list + + + + partitionHelp 0 0 [] 0 0 + "# + ), + indoc!( + r#" + "# + ), + ) + }) + } + + #[allow(dead_code)] + fn quicksort_full() { + with_larger_debug_stack(|| { + compiles_to_ir( + indoc!( + r#" + quicksort = \originalList -> + quicksortHelp : List (Num a), Int, Int -> List (Num a) + quicksortHelp = \list, low, high -> + if low < high then + (Pair partitionIndex partitioned) = partition low high list + + partitioned + |> quicksortHelp low (partitionIndex - 1) + |> quicksortHelp (partitionIndex + 1) high + else + list + + + swap : Int, Int, List a -> List a + swap = \i, j, list -> + when Pair (List.get list i) (List.get list j) is + Pair (Ok atI) (Ok atJ) -> + list + |> List.set i atJ + |> List.set j atI + + _ -> + [] + + partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] + partition = \low, high, initialList -> + when List.get initialList high is + Ok pivot -> + when partitionHelp (low - 1) low initialList high pivot is + Pair newI newList -> + Pair (newI + 1) (swap (newI + 1) high newList) + + Err _ -> + Pair (low - 1) initialList + + + partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ] + partitionHelp = \i, j, list, high, pivot -> + if j < high then + when List.get list j is + Ok value -> + if value <= pivot then + partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot + else + partitionHelp i (j + 1) list high pivot + + Err _ -> + Pair i list + else + Pair i list + + + + n = List.len originalList + quicksortHelp originalList 0 (n - 1) + + quicksort [1,2,3] + "# + ), + indoc!( + r#" + "# + ), + ) + }) + } } diff --git a/examples/quicksort/host.rs b/examples/quicksort/host.rs index c055466df6..3aa0efb138 100644 --- a/examples/quicksort/host.rs +++ b/examples/quicksort/host.rs @@ -75,7 +75,7 @@ pub fn main() { }; // TODO FIXME don't truncate! This is just for testing. - nums.truncate(1_000_00); + nums.truncate(1_000_000); let nums: Box<[i64]> = nums.into();