From ecad660e7fdd7fb2e9ae13c8c5b28fd2adc0995f Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Fri, 24 Mar 2023 18:49:46 -0500 Subject: [PATCH] Ensure that when jumping to a branch, all pattern symbols are loaded If we are jumping to a target branch, it is necessary that the target branch has all required pattern symbols loaded in it. Usually this is already the case, but there is an exception with guarded patterns. Guarded patterns have their patterns loaded only right before the guard is evaluated, which happens at some point further along the decision tree. As such, when a guarded pattern jumps to its target destination, it should append the loaded patterns as parameters on the target joinpoint. --- crates/compiler/mono/src/decision_tree.rs | 216 ++++++++++++++++------ 1 file changed, 162 insertions(+), 54 deletions(-) diff --git a/crates/compiler/mono/src/decision_tree.rs b/crates/compiler/mono/src/decision_tree.rs index 444565e5c3..7dcf9bb984 100644 --- a/crates/compiler/mono/src/decision_tree.rs +++ b/crates/compiler/mono/src/decision_tree.rs @@ -1,7 +1,7 @@ use crate::borrow::Ownership; use crate::ir::{ - build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, - ListIndex, Literal, Param, Pattern, Procs, Stmt, + build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType, DestructType, + Env, Expr, JoinPointId, ListIndex, Literal, Param, Pattern, Procs, Stmt, }; use crate::layout::{ Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType, @@ -57,6 +57,10 @@ impl<'a> Guard<'a> { fn is_none(&self) -> bool { self == &Guard::NoGuard } + + fn is_some(&self) -> bool { + !self.is_none() + } } type Edge<'a> = (GuardedTest<'a>, DecisionTree<'a>); @@ -82,10 +86,12 @@ enum GuardedTest<'a> { /// body stmt: Stmt<'a>, }, + // e.g. ` -> ...` TestNotGuarded { test: Test<'a>, }, - Placeholder, + // e.g. `_ -> ...` or `x -> ...` + PlaceholderWithGuard, } #[derive(Clone, Copy, Debug, PartialEq, Hash)] @@ -196,7 +202,7 @@ impl<'a> Hash for GuardedTest<'a> { state.write_u8(0); test.hash(state); } - GuardedTest::Placeholder => { + GuardedTest::PlaceholderWithGuard => { state.write_u8(2); } } @@ -264,6 +270,7 @@ fn to_decision_tree<'a>( let path = pick_path(&branches).clone(); let bs = branches.clone(); + let (edges, fallback) = gather_edges(interner, branches, &path); let mut decision_edges: Vec<_> = edges @@ -308,7 +315,7 @@ fn break_out_guard<'a>( ) -> DecisionTree<'a> { match edges .iter() - .position(|(t, _)| matches!(t, GuardedTest::Placeholder)) + .position(|(t, _)| matches!(t, GuardedTest::PlaceholderWithGuard)) { None => DecisionTree::Decision { path, @@ -347,7 +354,7 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool { .all(|t| matches!(t, GuardedTest::TestNotGuarded { .. })); match tests.last().unwrap() { - GuardedTest::Placeholder => false, + GuardedTest::PlaceholderWithGuard => false, GuardedTest::GuardedNoTest { .. } => false, GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length), } @@ -687,7 +694,7 @@ fn test_at_path<'a>( if let Guard::Guard { .. } = &branch.guard { // no tests for this pattern remain, but we cannot discard it yet // because it has a guard! - Some(GuardedTest::Placeholder) + Some(GuardedTest::PlaceholderWithGuard) } else { None } @@ -709,10 +716,33 @@ fn edges_for<'a>( // if we test for a guard, skip all branches until one that has a guard let it = match test { - GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => { + GuardedTest::GuardedNoTest { .. } => { let index = branches .iter() - .position(|b| !b.guard.is_none()) + .position(|b| b.guard.is_some()) + .expect("if testing for a guard, one branch must have a guard"); + + branches[index..].iter() + } + GuardedTest::PlaceholderWithGuard => { + // Skip all branches until we hit the one with a placeholder and a guard. + let index = branches + .iter() + .position(|b| { + if b.guard.is_none() { + return false; + } + + let (_, pattern) = b + .patterns + .iter() + .find(|(branch_path, _)| branch_path == path) + .expect( + "if testing for a placeholder with guard, must find a branch matching the path", + ); + + test_for_pattern(pattern).is_none() + }) .expect("if testing for a guard, one branch must have a guard"); branches[index..].iter() @@ -741,7 +771,7 @@ fn to_relevant_branch<'a>( found_pattern: pattern, end, } => match guarded_test { - GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { + GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => { // if there is no test, the pattern should not require any debug_assert!( matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,), @@ -1332,7 +1362,7 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz relevant_tests.len() + (if !fallbacks { 0 } else { 1 }) } -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] enum Decider<'a, T> { Leaf(T), Guarded { @@ -1364,6 +1394,17 @@ enum Choice<'a> { type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, InLayout<'a>, Expr<'a>)>; +struct JumpSpec<'a> { + target_index: u64, + id: JoinPointId, + /// Symbols, from the unpacked pattern, to add on when jumping to the target. + jump_pattern_param_symbols: &'a [Symbol], + + // Used to construct the joinpoint + join_params: &'a [Param<'a>], + join_body: Stmt<'a>, +} + pub fn optimize_when<'a>( env: &mut Env<'a, '_>, procs: &mut Procs<'a>, @@ -1373,11 +1414,11 @@ pub fn optimize_when<'a>( ret_layout: InLayout<'a>, opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>, ) -> Stmt<'a> { - let (patterns, _indexed_branches) = opt_branches + let (patterns, indexed_branches): (_, Vec<_>) = opt_branches .into_iter() .enumerate() .map(|(index, (pattern, guard, branch))| { - let has_guard = !guard.is_none(); + let has_guard = guard.is_some(); ( (guard, pattern.clone(), index as u64), (index as u64, branch, pattern, has_guard), @@ -1385,8 +1426,6 @@ pub fn optimize_when<'a>( }) .unzip(); - let indexed_branches: Vec<_> = _indexed_branches; - let decision_tree = compile(&layout_cache.interner, patterns); let decider = tree_to_decider(decision_tree); @@ -1397,19 +1436,95 @@ pub fn optimize_when<'a>( let mut choices = MutMap::default(); let mut jumps = Vec::new(); - for (index, mut branch, pattern, has_guard) in indexed_branches.into_iter() { - // bind the fields referenced in the pattern. For guards this happens separately, so - // the pattern variables are defined when evaluating the guard. - if !has_guard { - branch = - crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, branch); + for (target, mut branch, pattern, has_guard) in indexed_branches.into_iter() { + let should_inline = { + let target_counts = &target_counts; + match target_counts.get(target as usize) { + None => unreachable!( + "this should never happen: {:?} not in {:?}", + target, target_counts + ), + Some(count) => *count == 1, + } + }; + + let join_params: &'a [Param<'a>]; + let jump_pattern_param_symbols: &'a [Symbol]; + match (has_guard, should_inline) { + (false, _) => { + // Bind the fields referenced in the pattern. + branch = crate::ir::store_pattern( + env, + procs, + layout_cache, + &pattern, + cond_symbol, + branch, + ); + + join_params = env.arena.alloc([]); + jump_pattern_param_symbols = env.arena.alloc([]); + } + (true, true) => { + // Nothing more to do - the patterns will be bound when the guard is evaluated in + // `decide_to_branching`. + join_params = env.arena.alloc([]); + jump_pattern_param_symbols = env.arena.alloc([]); + } + (true, false) => { + // The patterns will be bound when the guard is evaluated, and then we need to get + // them back into the joinpoint here. + // + // So, figure out what symbols the pattern binds, and update the joinpoint + // parameter to take each symbol. Then, when the joinpoint is called, the unpacked + // symbols will be filled in. + // + // Since the joinpoint's parameters will be fresh symbols, the join body also needs + // updating. + let pattern_bindings = pattern.collect_symbols(cond_layout); + + let mut parameters_buf = + bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena); + let mut pattern_symbols_buf = + bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena); + + for &(pattern_symbol, layout) in pattern_bindings.iter() { + let param_symbol = env.unique_symbol(); + parameters_buf.push(Param { + symbol: param_symbol, + layout, + ownership: Ownership::Owned, + }); + pattern_symbols_buf.push(pattern_symbol); + } + + join_params = parameters_buf.into_bump_slice(); + jump_pattern_param_symbols = pattern_symbols_buf.into_bump_slice(); + + let substitutions = pattern_bindings + .iter() + .zip(join_params.iter()) + .map(|((pat, _), param)| (*pat, param.symbol)) + .collect(); + substitute_in_exprs_many(env.arena, &mut branch, substitutions); + } } - let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch); + let ((branch_index, choice), opt_jump) = if should_inline { + ((target, Choice::Inline(branch)), None) + } else { + ((target, Choice::Jump(target)), Some((target, branch))) + }; - if let Some((index, body)) = opt_jump { + if let Some((target_index, body)) = opt_jump { let id = JoinPointId(env.unique_symbol()); - jumps.push((index, id, body)); + jumps.push(JumpSpec { + target_index, + id, + jump_pattern_param_symbols, + join_params, + join_body: body, + }); } choices.insert(branch_index, choice); @@ -1428,11 +1543,18 @@ pub fn optimize_when<'a>( &jumps, ); - for (_, id, body) in jumps.into_iter() { + for JumpSpec { + target_index: _, + id, + jump_pattern_param_symbols: _, + join_params, + join_body, + } in jumps.into_iter() + { stmt = Stmt::Join { id, - parameters: &[], - body: env.arena.alloc(body), + parameters: join_params, + body: env.arena.alloc(join_body), remainder: env.arena.alloc(stmt), }; } @@ -1929,7 +2051,7 @@ fn decide_to_branching<'a>( cond_layout: InLayout<'a>, ret_layout: InLayout<'a>, decider: Decider<'a, Choice<'a>>, - jumps: &[(u64, JoinPointId, Stmt<'a>)], + jumps: &[JumpSpec<'a>], ) -> Stmt<'a> { use Choice::*; use Decider::*; @@ -1939,10 +2061,10 @@ fn decide_to_branching<'a>( match decider { Leaf(Jump(label)) => { let index = jumps - .binary_search_by_key(&label, |r| r.0) + .binary_search_by_key(&label, |r| r.target_index) .expect("jump not in list of jumps"); - Stmt::Jump(jumps[index].1, &[]) + Stmt::Jump(jumps[index].id, jumps[index].jump_pattern_param_symbols) } Leaf(Inline(expr)) => expr, Guarded { @@ -1997,8 +2119,8 @@ fn decide_to_branching<'a>( let join = Stmt::Join { id, parameters: arena.alloc([param]), - remainder: arena.alloc(stmt), body: arena.alloc(decide), + remainder: arena.alloc(stmt), }; crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, join) @@ -2282,15 +2404,17 @@ fn sort_edge_tests_by_priority(edges: &mut [Edge<'_>]) { edges.sort_by(|(t1, _), (t2, _)| match (t1, t2) { // Guarded takes priority (GuardedNoTest { .. }, GuardedNoTest { .. }) => Equal, - (GuardedNoTest { .. }, TestNotGuarded { .. }) | (GuardedNoTest { .. }, Placeholder) => Less, + (GuardedNoTest { .. }, TestNotGuarded { .. }) + | (GuardedNoTest { .. }, PlaceholderWithGuard) => Less, // Interesting case: what test do we pick? (TestNotGuarded { test: t1 }, TestNotGuarded { test: t2 }) => order_tests(t1, t2), // Otherwise we are between guarded and fall-backs (TestNotGuarded { .. }, GuardedNoTest { .. }) => Greater, - (TestNotGuarded { .. }, Placeholder) => Less, + (TestNotGuarded { .. }, PlaceholderWithGuard) => Less, // Placeholder is always last - (Placeholder, Placeholder) => Equal, - (Placeholder, GuardedNoTest { .. }) | (Placeholder, TestNotGuarded { .. }) => Greater, + (PlaceholderWithGuard, PlaceholderWithGuard) => Equal, + (PlaceholderWithGuard, GuardedNoTest { .. }) + | (PlaceholderWithGuard, TestNotGuarded { .. }) => Greater, }); fn order_tests(t1: &Test, t2: &Test) -> Ordering { @@ -2452,7 +2576,7 @@ fn fanout_decider_help<'a>( guarded_test: GuardedTest<'a>, ) -> (Test<'a>, Decider<'a, u64>) { match guarded_test { - GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { + GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => { unreachable!("this would not end up in a switch") } GuardedTest::TestNotGuarded { test } => { @@ -2478,7 +2602,7 @@ fn chain_decider<'a>( stmt, pattern, success, - failure: failure.clone(), + failure, } } GuardedTest::TestNotGuarded { test } => { @@ -2489,7 +2613,7 @@ fn chain_decider<'a>( } } - GuardedTest::Placeholder => { + GuardedTest::PlaceholderWithGuard => { // ? tree_to_decider(success_tree) } @@ -2572,22 +2696,6 @@ fn count_targets(targets: &mut bumpalo::collections::Vec, initial: &Decider } } -#[allow(clippy::type_complexity)] -fn create_choices<'a>( - target_counts: &bumpalo::collections::Vec<'a, u64>, - target: u64, - branch: Stmt<'a>, -) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) { - match target_counts.get(target as usize) { - None => unreachable!( - "this should never happen: {:?} not in {:?}", - target, target_counts - ), - Some(1) => ((target, Choice::Inline(branch)), None), - Some(_) => ((target, Choice::Jump(target)), Some((target, branch))), - } -} - fn insert_choices<'a>( choice_dict: &MutMap>, decider: Decider<'a, u64>,