diff --git a/compiler/gen/src/crane/build.rs b/compiler/gen/src/crane/build.rs index 5349788285..844420f5fb 100644 --- a/compiler/gen/src/crane/build.rs +++ b/compiler/gen/src/crane/build.rs @@ -77,13 +77,17 @@ pub fn build_expr<'a, B: Backend>( Bool(val) => builder.ins().bconst(types::B1, *val), Byte(val) => builder.ins().iconst(types::I8, *val as i64), Cond { - cond, - pass, - fail, + branch_symbol, + pass: (pass_stores, pass_expr), + fail: (fail_stores, fail_expr), cond_layout, ret_layout, + .. } => { - let cond_value = load_symbol(env, scope, builder, *cond); + let cond_value = load_symbol(env, scope, builder, *branch_symbol); + + let pass = env.arena.alloc(Expr::Store(pass_stores, pass_expr)); + let fail = env.arena.alloc(Expr::Store(fail_stores, fail_expr)); let branch = Branch2 { cond: cond_value, @@ -98,15 +102,24 @@ pub fn build_expr<'a, B: Backend>( Switch { cond, branches, - default_branch, + default_branch: (default_stores, default_expr), ret_layout, cond_layout, } => { let ret_type = type_from_layout(env.cfg, &ret_layout); + + let default_branch = env.arena.alloc(Expr::Store(default_stores, default_expr)); + + let mut combined = Vec::with_capacity_in(branches.len(), env.arena); + + for (int, stores, expr) in branches.iter() { + combined.push((*int, Expr::Store(stores, expr))); + } + let switch_args = SwitchArgs { cond_layout, cond_expr: cond, - branches, + branches: combined.into_bump_slice(), default_branch, ret_type, }; diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index d5b03e06a9..e9d6a804e9 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -57,14 +57,17 @@ pub fn build_expr<'a, 'ctx, 'env>( Bool(b) => env.context.bool_type().const_int(*b as u64, false).into(), Byte(b) => env.context.i8_type().const_int(*b as u64, false).into(), Cond { - cond, - pass, - fail, + branch_symbol, + pass: (pass_stores, pass_expr), + fail: (fail_stores, fail_expr), ret_layout, .. } => { + let pass = env.arena.alloc(Expr::Store(pass_stores, pass_expr)); + let fail = env.arena.alloc(Expr::Store(fail_stores, fail_expr)); + let conditional = Branch2 { - cond, + cond: branch_symbol, pass, fail, ret_layout: ret_layout.clone(), @@ -75,16 +78,25 @@ pub fn build_expr<'a, 'ctx, 'env>( Switch { cond, branches, - default_branch, + default_branch: (default_stores, default_expr), ret_layout, cond_layout, } => { let ret_type = basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); + + let default_branch = env.arena.alloc(Expr::Store(default_stores, default_expr)); + + let mut combined = Vec::with_capacity_in(branches.len(), env.arena); + + for (int, stores, expr) in branches.iter() { + combined.push((*int, Expr::Store(stores, expr))); + } + let switch_args = SwitchArgs { cond_layout: cond_layout.clone(), cond_expr: cond, - branches, + branches: combined.into_bump_slice(), default_branch, ret_type, }; diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index bb1d64fe8a..2d39423d29 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -856,26 +856,31 @@ enum Decider<'a, T> { #[derive(Clone, Debug, PartialEq)] enum Choice<'a> { - Inline(Expr<'a>), + Inline(Stores<'a>, Expr<'a>), Jump(Label), } +type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; + pub fn optimize_when<'a>( env: &mut Env<'a, '_>, cond_symbol: Symbol, cond_layout: Layout<'a>, ret_layout: Layout<'a>, - opt_branches: Vec<(Pattern<'a>, Guard<'a>, Expr<'a>)>, + opt_branches: Vec<(Pattern<'a>, Guard<'a>, Stores<'a>, Expr<'a>)>, ) -> Expr<'a> { let (patterns, _indexed_branches) = opt_branches .into_iter() .enumerate() - .map(|(index, (pattern, guard, branch))| { - ((guard, pattern, index as u64), (index as u64, branch)) + .map(|(index, (pattern, guard, stores, branch))| { + ( + (guard, pattern, index as u64), + (index as u64, stores, branch), + ) }) .unzip(); - let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches; + let indexed_branches: Vec<(u64, Stores<'a>, Expr<'a>)> = _indexed_branches; let decision_tree = compile(patterns); let decider = tree_to_decider(decision_tree); @@ -884,8 +889,9 @@ pub fn optimize_when<'a>( let mut choices = MutMap::default(); let mut jumps = Vec::new(); - for (index, branch) in indexed_branches.into_iter() { - let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch); + for (index, stores, branch) in indexed_branches.into_iter() { + let ((branch_index, choice), opt_jump) = + create_choices(&target_counts, index, stores, branch); if let Some(jump) = opt_jump { jumps.push(jump); @@ -896,7 +902,7 @@ pub fn optimize_when<'a>( let choice_decider = insert_choices(&choices, decider); - let result = decide_to_branching( + let (stores, expr) = decide_to_branching( env, cond_symbol, cond_layout, @@ -908,7 +914,7 @@ pub fn optimize_when<'a>( // increase the jump counter by the number of jumps in this branching structure *env.jump_counter += jumps.len() as u64; - result + Expr::Store(stores, env.arena.alloc(expr)) } fn path_to_expr<'a>( @@ -1064,16 +1070,16 @@ fn decide_to_branching<'a>( cond_layout: Layout<'a>, ret_layout: Layout<'a>, decider: Decider<'a, Choice<'a>>, - jumps: &Vec<(u64, Expr<'a>)>, -) -> Expr<'a> { + jumps: &Vec<(u64, Stores<'a>, Expr<'a>)>, +) -> (Stores<'a>, Expr<'a>) { use Choice::*; use Decider::*; let jump_count = *env.jump_counter; match decider { - Leaf(Jump(label)) => Expr::Jump(label + jump_count), - Leaf(Inline(expr)) => expr, + Leaf(Jump(label)) => (&[], Expr::Jump(label + jump_count)), + Leaf(Inline(stores, expr)) => (stores, expr), Chain { test_chain, success, @@ -1087,51 +1093,48 @@ fn decide_to_branching<'a>( test_to_equality(env, cond_symbol, &cond_layout, &path, test, &mut tests); } - let pass = env.arena.alloc(decide_to_branching( + let (pass_stores, pass_expr) = decide_to_branching( env, cond_symbol, cond_layout.clone(), ret_layout.clone(), *success, jumps, - )); + ); - let fail = env.arena.alloc(decide_to_branching( + let (fail_stores, fail_expr) = decide_to_branching( env, cond_symbol, cond_layout.clone(), ret_layout.clone(), *failure, jumps, - )); + ); + + let fail = (fail_stores, &*env.arena.alloc(fail_expr)); + let pass = (pass_stores, &*env.arena.alloc(pass_expr)); let condition = boolean_all(env.arena, tests); + let branch_symbol = env.fresh_symbol(); + let stores = [(branch_symbol, Layout::Builtin(Builtin::Bool), condition)]; + let cond_layout = Layout::Builtin(Builtin::Bool); - if let Expr::Load(symbol) = condition { - Expr::Cond { - cond: symbol, - cond_layout, - pass, - fail, - ret_layout, - } - } else { - let cond_symbol = env.fresh_symbol(); - let stores = vec![(cond_symbol, cond_layout.clone(), condition)]; - + ( + env.arena.alloc(stores), Expr::Store( - env.arena.alloc(stores), + &[], env.arena.alloc(Expr::Cond { - cond: cond_symbol, + cond_symbol, + branch_symbol, cond_layout, pass, fail, ret_layout, }), - ) - } + ), + ) } FanOut { path, @@ -1142,19 +1145,20 @@ fn decide_to_branching<'a>( // switch on the tag discriminant (currently an i64 value) let (cond, cond_layout) = path_to_expr_help(env, cond_symbol, &path, cond_layout); - let default_branch = env.arena.alloc(decide_to_branching( + let (default_stores, default_expr) = decide_to_branching( env, cond_symbol, cond_layout.clone(), ret_layout.clone(), *fallback, jumps, - )); + ); + let default_branch = (default_stores, &*env.arena.alloc(default_expr)); let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena); for (test, decider) in tests { - let branch = decide_to_branching( + let (stores, branch) = decide_to_branching( env, cond_symbol, cond_layout.clone(), @@ -1172,17 +1176,20 @@ fn decide_to_branching<'a>( other => todo!("other {:?}", other), }; - branches.push((tag, branch)); + branches.push((tag, stores, branch)); } // make a jump table based on the tests - Expr::Switch { - cond: env.arena.alloc(cond), - cond_layout, - branches: branches.into_bump_slice(), - default_branch, - ret_layout, - } + ( + &[], + Expr::Switch { + cond: env.arena.alloc(cond), + cond_layout, + branches: branches.into_bump_slice(), + default_branch, + ret_layout, + }, + ) } } } @@ -1372,18 +1379,23 @@ fn count_targets_help(decision_tree: &Decider, targets: &mut MutMap( target_counts: &MutMap, target: u64, + stores: Stores<'a>, branch: Expr<'a>, -) -> ((u64, Choice<'a>), Option<(u64, Expr<'a>)>) { +) -> ((u64, Choice<'a>), Option<(u64, Stores<'a>, Expr<'a>)>) { match target_counts.get(&target) { None => unreachable!( "this should never happen: {:?} not in {:?}", target, target_counts ), - Some(1) => ((target, Choice::Inline(branch)), None), - Some(_) => ((target, Choice::Jump(target)), Some((target, branch))), + Some(1) => ((target, Choice::Inline(stores, branch)), None), + Some(_) => ( + (target, Choice::Jump(target)), + Some((target, stores, branch)), + ), } } diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index a7d1e96f3e..165ee915c7 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -122,6 +122,8 @@ impl<'a, 'i> Env<'a, 'i> { } } +pub type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; + #[derive(Clone, Debug, PartialEq)] pub enum Expr<'a> { // Literals @@ -152,11 +154,18 @@ pub enum Expr<'a> { // The left-hand side of the conditional comparison and the right-hand side. // These are stored separately because there are different machine instructions // for e.g. "compare float and jump" vs. "compare integer and jump" - cond: Symbol, + + // symbol storing the original expression that we branch on, e.g. `Ok 42` + // required for RC logic + cond_symbol: Symbol, + + // symbol storing the value that we branch on, e.g. `1` representing the `Ok` tag + branch_symbol: Symbol, + cond_layout: Layout<'a>, // What to do if the condition either passes or fails - pass: &'a Expr<'a>, - fail: &'a Expr<'a>, + pass: (Stores<'a>, &'a Expr<'a>), + fail: (Stores<'a>, &'a Expr<'a>), ret_layout: Layout<'a>, }, /// Conditional branches for integers. These are more efficient. @@ -166,9 +175,9 @@ pub enum Expr<'a> { cond_layout: Layout<'a>, /// The u64 in the tuple will be compared directly to the condition Expr. /// If they are equal, this branch will be taken. - branches: &'a [(u64, Expr<'a>)], + branches: &'a [(u64, Stores<'a>, Expr<'a>)], /// If no other branches pass, this default branch will be taken. - default_branch: &'a Expr<'a>, + default_branch: (Stores<'a>, &'a Expr<'a>), /// Each branch must return a value of this type. ret_layout: Layout<'a>, }, @@ -647,19 +656,20 @@ fn from_can<'a>( let cond = from_can(env, loc_cond.value, procs, None); let then = from_can(env, loc_then.value, procs, None); - let cond_symbol = env.fresh_symbol(); + let branch_symbol = env.fresh_symbol(); let cond_expr = Expr::Cond { - cond: cond_symbol, + cond_symbol: branch_symbol, + branch_symbol, cond_layout: cond_layout.clone(), - pass: env.arena.alloc(then), - fail: env.arena.alloc(expr), + pass: (&[], env.arena.alloc(then)), + fail: (&[], env.arena.alloc(expr)), ret_layout: ret_layout.clone(), }; expr = Expr::Store( env.arena - .alloc(vec![(cond_symbol, Layout::Builtin(Builtin::Bool), cond)]), + .alloc(vec![(branch_symbol, Layout::Builtin(Builtin::Bool), cond)]), env.arena.alloc(cond_expr), ); } @@ -1070,7 +1080,7 @@ fn from_can_when<'a>( let mut stores = Vec::with_capacity_in(1, env.arena); - let (mono_guard, expr_with_stores) = match store_pattern( + let (mono_guard, stores, expr) = match store_pattern( env, &mono_pattern, cond_symbol, @@ -1091,15 +1101,14 @@ fn from_can_when<'a>( stores: stores.into_bump_slice(), expr, }, + &[] as &[_], mono_expr.clone(), ) } else { ( crate::decision_tree::Guard::NoGuard, - Expr::Store( - stores.into_bump_slice(), - env.arena.alloc(mono_expr.clone()), - ), + stores.into_bump_slice(), + mono_expr.clone(), ) } } @@ -1111,19 +1120,21 @@ fn from_can_when<'a>( stores: &[], expr: Expr::RuntimeError(env.arena.alloc(message)), }, + &[] as &[_], // we can never hit this Expr::RuntimeError(&"invalid pattern with guard: unreachable"), ) } else { ( crate::decision_tree::Guard::NoGuard, + &[] as &[_], Expr::RuntimeError(env.arena.alloc(message)), ) } } }; - opt_branches.push((mono_pattern, mono_guard, expr_with_stores)); + opt_branches.push((mono_pattern, mono_guard, stores, expr)); } } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index c4c9b9a660..acaf0b3882 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -171,10 +171,11 @@ mod test_mono { Expr::Bool(true), )], &Cond { - cond: gen_symbol_0, + cond_symbol: gen_symbol_0, + branch_symbol: gen_symbol_0, cond_layout: Builtin(Bool), - pass: &Expr::Str("bar"), - fail: &Expr::Str("foo"), + pass: (&[] as &[_], &Expr::Str("bar")), + fail: (&[] as &[_], &Expr::Str("foo")), ret_layout: Builtin(Str), }, ) @@ -208,22 +209,27 @@ mod test_mono { Expr::Bool(true), )], &Cond { - cond: gen_symbol_0, + cond_symbol: gen_symbol_0, + branch_symbol: gen_symbol_0, cond_layout: Builtin(Bool), - pass: &Expr::Str("bar"), - fail: &Store( - &[( - gen_symbol_1, - Layout::Builtin(layout::Builtin::Bool), - Expr::Bool(false), - )], - &Cond { - cond: gen_symbol_1, - cond_layout: Builtin(Bool), - pass: &Expr::Str("foo"), - fail: &Expr::Str("baz"), - ret_layout: Builtin(Str), - }, + pass: (&[] as &[_], &Expr::Str("bar")), + fail: ( + &[] as &[_], + &Store( + &[( + gen_symbol_1, + Layout::Builtin(layout::Builtin::Bool), + Expr::Bool(false), + )], + &Cond { + cond_symbol: gen_symbol_1, + branch_symbol: gen_symbol_1, + cond_layout: Builtin(Bool), + pass: (&[] as &[_], &Expr::Str("foo")), + fail: (&[] as &[_], &Expr::Str("baz")), + ret_layout: Builtin(Str), + }, + ), ), ret_layout: Builtin(Str), }, @@ -261,10 +267,11 @@ mod test_mono { Expr::Bool(true), )], &Cond { - cond: gen_symbol_0, + cond_symbol: gen_symbol_0, + branch_symbol: gen_symbol_0, cond_layout: Builtin(Bool), - pass: &Expr::Str("bar"), - fail: &Expr::Str("foo"), + pass: (&[] as &[_], &Expr::Str("bar")), + fail: (&[] as &[_], &Expr::Str("foo")), ret_layout: Builtin(Str), }, ), diff --git a/compiler/mono/tests/test_opt.rs b/compiler/mono/tests/test_opt.rs index 2b8bffc974..fb2aaaf79d 100644 --- a/compiler/mono/tests/test_opt.rs +++ b/compiler/mono/tests/test_opt.rs @@ -114,14 +114,15 @@ mod test_opt { } Cond { - cond: _, + cond_symbol: _, + branch_symbol: _, cond_layout: _, pass, fail, ret_layout: _, } => { - extract_named_calls_help(pass, calls, unexpected_calls); - extract_named_calls_help(fail, calls, unexpected_calls); + extract_named_calls_help(pass.1, calls, unexpected_calls); + extract_named_calls_help(fail.1, calls, unexpected_calls); } Switch { cond, @@ -131,9 +132,9 @@ mod test_opt { ret_layout: _, } => { extract_named_calls_help(cond, calls, unexpected_calls); - extract_named_calls_help(default_branch, calls, unexpected_calls); + extract_named_calls_help(default_branch.1, calls, unexpected_calls); - for (_, branch_expr) in branches.iter() { + for (_, _, branch_expr) in branches.iter() { extract_named_calls_help(branch_expr, calls, unexpected_calls); } }