refactor guards

This commit is contained in:
Folkert 2021-07-04 21:17:57 +02:00
parent b99f710c49
commit e7c88cac98
3 changed files with 230 additions and 12 deletions

View file

@ -933,6 +933,17 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
#[derive(Clone, Debug, PartialEq)]
enum Decider<'a, T> {
Leaf(T),
Guarded {
/// Symbol that stores a boolean
/// when true this branch is picked, otherwise skipped
symbol: Symbol,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
success: Box<Decider<'a, T>>,
failure: Box<Decider<'a, T>>,
},
Chain {
test_chain: Vec<(Vec<PathInstruction>, Test<'a>)>,
success: Box<Decider<'a, T>>,
@ -1532,6 +1543,62 @@ fn decide_to_branching<'a>(
Stmt::Jump(jumps[index].1, &[])
}
Leaf(Inline(expr)) => expr,
Guarded {
symbol,
id,
stmt,
success,
failure,
} => {
// the guard is the final thing that we check, so needs to be layered on first!
let test_symbol = env.unique_symbol();
let arena = env.arena;
let pass_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout,
ret_layout,
*success,
jumps,
);
let fail_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout,
ret_layout,
*failure,
jumps,
);
let decide = crate::ir::cond(
env,
test_symbol,
Layout::Builtin(Builtin::Int1),
pass_expr,
fail_expr,
ret_layout,
);
// calculate the guard value
let param = Param {
symbol: test_symbol,
layout: Layout::Builtin(Builtin::Int1),
borrow: false,
};
Stmt::Join {
id,
parameters: arena.alloc([param]),
remainder: arena.alloc(stmt),
body: arena.alloc(decide),
}
}
Chain {
test_chain,
success,
@ -1566,6 +1633,8 @@ fn decide_to_branching<'a>(
let (tests, guard) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
debug_assert!(guard.is_none());
let number_of_tests = tests.len() as i64 + guard.is_some() as i64;
debug_assert!(number_of_tests > 0);
@ -1792,23 +1861,75 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success,
failure: failure.clone(),
};
match opt_test {
Some(test) => Chain {
test_chain: vec![(path, *test)],
success: Box::new(guarded),
failure,
},
None => guarded,
}
} else {
to_chain(path, test, success_tree, failure_tree)
}
}
_ => {
let fallback = edges.remove(edges.len() - 1).1;
let fallback_tree = edges.remove(edges.len() - 1).1;
let fallback_decider = tree_to_decider(fallback_tree);
let necessary_tests = edges
.into_iter()
.map(|(test, decider)| (test, tree_to_decider(decider)))
.map(|(test, dectree)| {
let decider = tree_to_decider(dectree);
if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success: Box::new(decider),
failure: Box::new(fallback_decider.clone()),
};
match opt_test {
Some(test) => (*test, guarded),
None => todo!(),
}
} else {
(test, decider)
}
})
.collect();
FanOut {
path,
tests: necessary_tests,
fallback: Box::new(tree_to_decider(fallback)),
fallback: Box::new(fallback_decider),
}
}
},
@ -1821,6 +1942,32 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success,
failure: failure.clone(),
};
match opt_test {
Some(test) => Chain {
test_chain: vec![(path, *test)],
success: Box::new(guarded),
failure,
},
None => guarded,
}
} else {
to_chain(path, test, success_tree, failure_tree)
}
@ -1828,16 +1975,47 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
_ => {
let fallback = *last;
let fallback_decider = tree_to_decider(fallback);
// let necessary_tests = edges
// .into_iter()
// .map(|(test, decider)| (test, tree_to_decider(decider)))
// .collect();
let necessary_tests = edges
.into_iter()
.map(|(test, decider)| (test, tree_to_decider(decider)))
.map(|(test, dectree)| {
let decider = tree_to_decider(dectree);
if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success: Box::new(decider),
failure: Box::new(fallback_decider.clone()),
};
match opt_test {
Some(test) => (*test, guarded),
None => todo!(),
}
} else {
(test, decider)
}
})
.collect();
FanOut {
path,
tests: necessary_tests,
fallback: Box::new(tree_to_decider(fallback)),
fallback: Box::new(fallback_decider),
}
}
},
@ -1894,6 +2072,13 @@ fn count_targets(targets: &mut bumpalo::collections::Vec<u64>, initial: &Decider
targets[*target as usize] += 1;
}
Guarded {
success, failure, ..
} => {
stack.push(success);
stack.push(failure);
}
Chain {
success, failure, ..
} => {
@ -1942,6 +2127,20 @@ fn insert_choices<'a>(
Leaf(choice_dict[&target].clone())
}
Guarded {
symbol,
id,
stmt,
success,
failure,
} => Guarded {
symbol,
id,
stmt,
success: Box::new(insert_choices(choice_dict, *success)),
failure: Box::new(insert_choices(choice_dict, *failure)),
},
Chain {
test_chain,
success,