diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 7f5a19275d..59b5760dda 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -2158,4 +2158,61 @@ mod gen_primitives { i64 ); } + + #[test] + fn switch_fuse_rc_non_exhaustive() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Foo : [ A I64 Foo, B I64 Foo, C I64 Foo, Empty ] + + sum : Foo, I64 -> I64 + sum = \foo, accum -> + when foo is + A x resta -> sum resta (x + accum) + B x restb -> sum restb (x + accum) + # Empty -> accum + # C x restc -> sum restc (x + accum) + _ -> accum + + main : I64 + main = + A 1 (B 2 (C 3 Empty)) + |> sum 0 + "# + ), + 3, + i64 + ); + } + + #[test] + fn switch_fuse_rc_exhaustive() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Foo : [ A I64 Foo, B I64 Foo, C I64 Foo, Empty ] + + sum : Foo, I64 -> I64 + sum = \foo, accum -> + when foo is + A x resta -> sum resta (x + accum) + B x restb -> sum restb (x + accum) + C x restc -> sum restc (x + accum) + Empty -> accum + + main : I64 + main = + A 1 (B 2 (C 3 Empty)) + |> sum 0 + "# + ), + 6, + i64 + ); + } } diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 275d7278c4..4e2388101a 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -1654,6 +1654,9 @@ fn decide_to_branching<'a>( let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena); + let mut tag_id_sum: i64 = (0..tests.len() as i64 + 1).sum(); + let mut union_size: i64 = -1; + for (test, decider) in tests { let branch = decide_to_branching( env, @@ -1675,16 +1678,42 @@ fn decide_to_branching<'a>( other => todo!("other {:?}", other), }; - branches.push((tag, BranchInfo::None, branch)); + // branch info is only useful for refcounted values + let branch_info = if let Test::IsCtor { tag_id, union, .. } = test { + tag_id_sum -= tag_id as i64; + union_size = union.alternatives.len() as i64; + + BranchInfo::Constructor { + scrutinee: inner_cond_symbol, + layout: inner_cond_layout.clone(), + tag_id, + } + } else { + tag_id_sum = -1; + BranchInfo::None + }; + + branches.push((tag, branch_info, branch)); } + // determine if the switch is exhaustive + let default_branch_info = if tag_id_sum > 0 && union_size > 0 { + BranchInfo::Constructor { + scrutinee: inner_cond_symbol, + layout: inner_cond_layout.clone(), + tag_id: tag_id_sum as u8, + } + } else { + BranchInfo::None + }; + // We have learned more about the exact layout of the cond (based on the path) // but tests are still relative to the original cond symbol let mut switch = Stmt::Switch { cond_layout: inner_cond_layout, cond_symbol: inner_cond_symbol, branches: branches.into_bump_slice(), - default_branch: (BranchInfo::None, env.arena.alloc(default_branch)), + default_branch: (default_branch_info, env.arena.alloc(default_branch)), ret_layout, };