Handle guards that appear multiple times in a compiled decision tree

Suppose we have a when expression

```
15 if foo -> <b1>
b  if bar -> <b2>
_         -> <b3>
```

that may have a decision tree like

```
15?
  \true => foo?
              \true  => <b1>
              \false => bar?
                           \true  => <b2>
                           \false => <b3>
  \false => bar?
               \true  => <b2>
               \false => <b3>
```

In this case, the guard "bar?" appears twice in the compiled decision
tree. We need to materialize the guard expression in both locations in
the compiled tree, which means we cannot as-is stamp a compiled `bar?`
twice in each location. The reason is that

- the compiled joinpoint for each `bar?` guard needs to have a unique ID
- the guard expression might have call which needs unique call spec IDs,
  or other joins that need unique joinpoint IDs.

So, save the expression as we build up the decision tree and materialize
the guard each time we need it. In practice the guards should be quite
small, so duplicating should be fine. We could avoid duplication, but
it's not clear to me how to do that exactly since the branches after the
guard might end up being different.
This commit is contained in:
Ayaz Hafiz 2023-03-25 13:48:40 -05:00
parent f3ddc254c1
commit dd55be6142
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 160 additions and 43 deletions

View file

@ -1,7 +1,8 @@
use crate::borrow::Ownership;
use crate::ir::{
build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType, DestructType,
Env, Expr, JoinPointId, ListIndex, Literal, Param, Pattern, Procs, Stmt,
build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType,
CompiledGuardStmt, DestructType, Env, Expr, GuardStmtSpec, JoinPointId, ListIndex, Literal,
Param, Pattern, Procs, Stmt,
};
use crate::layout::{
Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType,
@ -42,14 +43,13 @@ fn compile<'a>(
}
#[derive(Clone, Debug, PartialEq)]
pub enum Guard<'a> {
pub(crate) enum Guard<'a> {
NoGuard,
Guard {
/// pattern
pattern: Pattern<'a>,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
/// How to compile the guard statement.
stmt_spec: GuardStmtSpec,
},
}
@ -81,10 +81,8 @@ enum GuardedTest<'a> {
GuardedNoTest {
/// pattern
pattern: Pattern<'a>,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
/// body
stmt: Stmt<'a>,
/// How to compile the guard body.
stmt_spec: GuardStmtSpec,
},
// e.g. `<pattern> -> ...`
TestNotGuarded {
@ -194,9 +192,9 @@ impl<'a> Hash for Test<'a> {
impl<'a> Hash for GuardedTest<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
GuardedTest::GuardedNoTest { id, .. } => {
GuardedTest::GuardedNoTest { stmt_spec, .. } => {
state.write_u8(1);
id.hash(state);
stmt_spec.hash(state);
}
GuardedTest::TestNotGuarded { test } => {
state.write_u8(0);
@ -238,8 +236,8 @@ fn to_decision_tree<'a>(
match first.guard {
Guard::NoGuard => unreachable!(),
Guard::Guard { id, stmt, pattern } => {
let guarded_test = GuardedTest::GuardedNoTest { id, stmt, pattern };
Guard::Guard { pattern, stmt_spec } => {
let guarded_test = GuardedTest::GuardedNoTest { pattern, stmt_spec };
// the guard test does not have a path
let path = vec![];
@ -1366,11 +1364,11 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
enum Decider<'a, T> {
Leaf(T),
Guarded {
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
pattern: Pattern<'a>,
/// The guard expression and how to compile it.
stmt_spec: GuardStmtSpec,
success: Box<Decider<'a, T>>,
failure: Box<Decider<'a, T>>,
},
@ -1405,7 +1403,7 @@ struct JumpSpec<'a> {
join_body: Stmt<'a>,
}
pub fn optimize_when<'a>(
pub(crate) fn optimize_when<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
@ -2068,9 +2066,8 @@ fn decide_to_branching<'a>(
}
Leaf(Inline(expr)) => expr,
Guarded {
id,
stmt,
pattern,
stmt_spec,
success,
failure,
} => {
@ -2116,8 +2113,13 @@ fn decide_to_branching<'a>(
ownership: Ownership::Owned,
};
let CompiledGuardStmt {
join_point_id,
stmt,
} = stmt_spec.generate_guard_and_join(env, procs, layout_cache);
let join = Stmt::Join {
id,
id: join_point_id,
parameters: arena.alloc([param]),
body: arena.alloc(decide),
remainder: arena.alloc(stmt),
@ -2593,14 +2595,13 @@ fn chain_decider<'a>(
success_tree: DecisionTree<'a>,
) -> Decider<'a, u64> {
match guarded_test {
GuardedTest::GuardedNoTest { id, stmt, pattern } => {
GuardedTest::GuardedNoTest { pattern, stmt_spec } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
Decider::Guarded {
id,
stmt,
pattern,
stmt_spec,
success,
failure,
}
@ -2709,15 +2710,13 @@ fn insert_choices<'a>(
}
Guarded {
id,
stmt,
pattern,
stmt_spec,
success,
failure,
} => Guarded {
id,
stmt,
pattern,
stmt_spec,
success: Box::new(insert_choices(choice_dict, *success)),
failure: Box::new(insert_choices(choice_dict, *failure)),
},

View file

@ -7030,26 +7030,16 @@ fn from_can_when<'a>(
use crate::decision_tree::Guard;
let result = if let Some(loc_expr) = opt_guard {
let id = JoinPointId(env.unique_symbol());
let symbol = env.unique_symbol();
let jump = env.arena.alloc(Stmt::Jump(id, env.arena.alloc([symbol])));
let guard_stmt = with_hole(
env,
loc_expr.value,
Variable::BOOL,
procs,
layout_cache,
symbol,
jump,
);
let guard_spec = GuardStmtSpec {
guard_expr: loc_expr.value,
identity: env.next_call_specialization_id(),
};
(
pattern.clone(),
Guard::Guard {
id,
pattern,
stmt: guard_stmt,
stmt_spec: guard_spec,
},
branch_stmt,
)
@ -7077,6 +7067,83 @@ fn from_can_when<'a>(
)
}
/// A functor to generate IR for a guard under a `when` branch.
/// Used in the decision tree compiler, after building a decision tree and converting into IR.
///
/// A guard might appear more than once in various places in the compiled decision tree, so the
/// functor here may be called more than once. As such, it implements clone, which duplicates the
/// guard AST for subsequent IR-regeneration. This is a bit wasteful, but in practice, guard ASTs
/// are quite small. Moreoever, they must be generated on a per-case basis, since the guard may
/// have calls or joins, whose specialization IDs and joinpoint IDs, respectively, must be unique.
#[derive(Debug, Clone)]
pub(crate) struct GuardStmtSpec {
guard_expr: roc_can::expr::Expr,
/// Unique id to indentity identical guard statements, even across clones.
/// Needed so that we can implement [PartialEq] on this type. Re-uses call specialization IDs,
/// since the identity is kind of irrelevant.
identity: CallSpecId,
}
impl PartialEq for GuardStmtSpec {
fn eq(&self, other: &Self) -> bool {
self.identity == other.identity
}
}
impl std::hash::Hash for GuardStmtSpec {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.identity.id.hash(state);
}
}
impl GuardStmtSpec {
/// Generates IR for the guard, and the joinpoint that the guard will jump to with the
/// calculated guard boolean value.
///
/// The caller should create a joinpoint with the given joinpoint ID and decide how to branch
/// after the guard has been evaluated.
///
/// The compiled guard statement expects the pattern before the guard to be destructed before the
/// returned statement. The caller should layer on the pattern destructuring, as bound from the
/// `when` condition value.
pub(crate) fn generate_guard_and_join<'a>(
self,
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
) -> CompiledGuardStmt<'a> {
let Self {
guard_expr,
identity: _,
} = self;
let join_point_id = JoinPointId(env.unique_symbol());
let symbol = env.unique_symbol();
let jump = env.arena.alloc(Stmt::Jump(join_point_id, env.arena.alloc([symbol])));
let stmt = with_hole(
env,
guard_expr,
Variable::BOOL,
procs,
layout_cache,
symbol,
jump,
);
CompiledGuardStmt {
join_point_id,
stmt,
}
}
}
pub(crate) struct CompiledGuardStmt<'a> {
pub join_point_id: JoinPointId,
pub stmt: Stmt<'a>,
}
fn substitute(substitutions: &BumpMap<Symbol, Symbol>, s: Symbol) -> Option<Symbol> {
match substitutions.get(&s) {
Some(new) => {