handle guards in a first-class way

This commit is contained in:
Folkert 2021-07-04 22:35:00 +02:00
parent 8f0c13ecc1
commit 2c0aa8a5a1

View file

@ -17,7 +17,7 @@ const RECORD_TAG_NAME: &str = "#Record";
/// some normal branches and gives out a decision tree that has "labels" at all
/// the leafs and a dictionary that maps these "labels" to the code that should
/// run.
pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> {
fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> {
let formatted = raw_branches
.into_iter()
.map(|(guard, pattern, index)| Branch {
@ -49,15 +49,35 @@ impl<'a> Guard<'a> {
}
#[derive(Clone, Debug, PartialEq)]
pub enum DecisionTree<'a> {
enum DecisionTree<'a> {
Match(Label),
Decision {
path: Vec<PathInstruction>,
edges: Vec<(Test<'a>, DecisionTree<'a>)>,
edges: Vec<(GuardedTest<'a>, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>,
},
}
#[derive(Clone, Debug, PartialEq)]
pub enum GuardedTest<'a> {
TestGuarded {
test: Test<'a>,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
},
// e.g. `_ if True -> ...`
GuardedNoTest {
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
},
TestNotGuarded {
test: Test<'a>,
},
}
#[derive(Clone, Debug, PartialEq)]
pub enum Test<'a> {
IsCtor {
@ -75,16 +95,6 @@ pub enum Test<'a> {
tag_id: u8,
num_alts: usize,
},
// A pattern that always succeeds (like `_`) can still have a guard
Guarded {
opt_test: Option<Box<Test<'a>>>,
/// Symbol that stores a boolean
/// when true this branch is picked, otherwise skipped
symbol: Symbol,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
},
}
use std::hash::{Hash, Hasher};
impl<'a> Hash for Test<'a> {
@ -118,15 +128,23 @@ impl<'a> Hash for Test<'a> {
tag_id.hash(state);
num_alts.hash(state);
}
Guarded { opt_test: None, .. } => {
state.write_u8(6);
}
Guarded {
opt_test: Some(nested),
..
} => {
state.write_u8(7);
nested.hash(state);
}
}
impl<'a> Hash for GuardedTest<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
GuardedTest::TestGuarded { test, .. } => {
state.write_u8(0);
test.hash(state);
}
GuardedTest::GuardedNoTest { id, stmt } => {
state.write_u8(1);
}
GuardedTest::TestNotGuarded { test } => {
state.write_u8(2);
test.hash(state);
}
}
}
@ -182,20 +200,32 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
}
}
fn is_complete(tests: &[Test]) -> bool {
fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
match tests.last() {
None => unreachable!("should never happen"),
Some(v) => match v {
Test::IsCtor { union, .. } => length == union.alternatives.len(),
Test::IsByte { num_alts, .. } => length == *num_alts,
Test::IsBit(_) => length == 2,
match tests.last().unwrap() {
GuardedTest::TestGuarded { .. } => false,
GuardedTest::GuardedNoTest { .. } => false,
GuardedTest::TestNotGuarded { test } => tests_are_complete_help(test, length),
}
}
fn tests_are_complete(tests: &[Test]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
tests_are_complete_help(tests.last().unwrap(), length)
}
fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool {
match last_test {
Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(),
Test::IsByte { num_alts, .. } => number_of_tests == *num_alts,
Test::IsBit(_) => number_of_tests == 2,
Test::IsInt(_) => false,
Test::IsFloat(_) => false,
Test::IsStr(_) => false,
Test::Guarded { .. } => false,
},
}
}
@ -293,10 +323,10 @@ fn check_for_match(branches: &[Branch]) -> Option<Label> {
fn gather_edges<'a>(
branches: Vec<Branch<'a>>,
path: &[PathInstruction],
) -> (Vec<(Test<'a>, Vec<Branch<'a>>)>, Vec<Branch<'a>>) {
) -> (Vec<(GuardedTest<'a>, Vec<Branch<'a>>)>, Vec<Branch<'a>>) {
let relevant_tests = tests_at_path(path, &branches);
let check = is_complete(&relevant_tests);
let check = guarded_tests_are_complete(&relevant_tests);
// TODO remove clone
let all_edges = relevant_tests
@ -318,7 +348,10 @@ fn gather_edges<'a>(
/// FIND RELEVANT TESTS
fn tests_at_path<'a>(selected_path: &[PathInstruction], branches: &[Branch<'a>]) -> Vec<Test<'a>> {
fn tests_at_path<'a>(
selected_path: &[PathInstruction],
branches: &[Branch<'a>],
) -> Vec<GuardedTest<'a>> {
// NOTE the ordering of the result is important!
let mut all_tests = Vec::new();
@ -355,7 +388,7 @@ fn tests_at_path<'a>(selected_path: &[PathInstruction], branches: &[Branch<'a>])
fn test_at_path<'a>(
selected_path: &[PathInstruction],
branch: &Branch<'a>,
all_tests: &mut Vec<Test<'a>>,
guarded_tests: &mut Vec<GuardedTest<'a>>,
) {
use Pattern::*;
use Test::*;
@ -367,30 +400,16 @@ fn test_at_path<'a>(
{
None => {}
Some((_, guard, pattern)) => {
let guarded = |test| {
if let Guard::Guard { symbol, id, stmt } = guard {
Guarded {
opt_test: Some(Box::new(test)),
stmt: stmt.clone(),
symbol: *symbol,
id: *id,
}
} else {
test
}
};
match pattern {
// TODO use guard!
let test = match pattern {
Identifier(_) | Underscore => {
if let Guard::Guard { symbol, id, stmt } = guard {
all_tests.push(Guarded {
opt_test: None,
if let Guard::Guard { id, stmt, .. } = guard {
guarded_tests.push(GuardedTest::GuardedNoTest {
stmt: stmt.clone(),
symbol: *symbol,
id: *id,
});
}
return;
}
RecordDestructure(destructs, _) => {
@ -417,12 +436,12 @@ fn test_at_path<'a>(
}
}
all_tests.push(IsCtor {
IsCtor {
tag_id: 0,
tag_name: TagName::Global(RECORD_TAG_NAME.into()),
union,
arguments,
});
}
}
NewtypeDestructure {
@ -432,12 +451,12 @@ fn test_at_path<'a>(
let tag_id = 0;
let union = Union::newtype_wrapper(tag_name.clone(), arguments.len());
all_tests.push(IsCtor {
IsCtor {
tag_id,
tag_name: tag_name.clone(),
union,
arguments: arguments.to_vec(),
});
}
}
AppliedTag {
@ -446,33 +465,33 @@ fn test_at_path<'a>(
arguments,
union,
..
} => {
all_tests.push(IsCtor {
} => IsCtor {
tag_id: *tag_id,
tag_name: tag_name.clone(),
union: union.clone(),
arguments: arguments.to_vec(),
});
}
BitLiteral { value, .. } => {
all_tests.push(IsBit(*value));
}
EnumLiteral { tag_id, union, .. } => {
all_tests.push(IsByte {
},
BitLiteral { value, .. } => IsBit(*value),
EnumLiteral { tag_id, union, .. } => IsByte {
tag_id: *tag_id,
num_alts: union.alternatives.len(),
});
}
IntLiteral(v) => {
all_tests.push(guarded(IsInt(*v)));
}
FloatLiteral(v) => {
all_tests.push(IsFloat(*v));
}
StrLiteral(v) => {
all_tests.push(IsStr(v.clone()));
}
},
IntLiteral(v) => IsInt(*v),
FloatLiteral(v) => IsFloat(*v),
StrLiteral(v) => IsStr(v.clone()),
};
let guarded_test = if let Guard::Guard { symbol, id, stmt } = guard {
GuardedTest::TestGuarded {
test,
stmt: stmt.clone(),
id: *id,
}
} else {
GuardedTest::TestNotGuarded { test }
};
guarded_tests.push(guarded_test);
}
}
}
@ -482,8 +501,8 @@ fn test_at_path<'a>(
fn edges_for<'a>(
path: &[PathInstruction],
branches: Vec<Branch<'a>>,
test: Test<'a>,
) -> (Test<'a>, Vec<Branch<'a>>) {
test: GuardedTest<'a>,
) -> (GuardedTest<'a>, Vec<Branch<'a>>) {
let mut new_branches = Vec::new();
for branch in branches.iter() {
@ -494,7 +513,7 @@ fn edges_for<'a>(
}
fn to_relevant_branch<'a>(
test: &Test<'a>,
guarded_test: &GuardedTest<'a>,
path: &[PathInstruction],
branch: &Branch<'a>,
new_branches: &mut Vec<Branch<'a>>,
@ -509,12 +528,22 @@ fn to_relevant_branch<'a>(
found_pattern: (guard, pattern),
end,
} => {
let actual_test = match test {
Test::Guarded {
opt_test: Some(box_test),
..
} => box_test,
_ => test,
let actual_test = match guarded_test {
GuardedTest::TestGuarded { test, .. } => test,
GuardedTest::GuardedNoTest { .. } => {
let mut new_branch = branch.clone();
// guards can/should only occur at the top level. When we recurse on these
// branches, the guard is not relevant any more. Not setthing the guard to None
// leads to infinite recursion.
new_branch.patterns.iter_mut().for_each(|(_, guard, _)| {
*guard = Guard::NoGuard;
});
new_branches.push(new_branch);
return;
}
GuardedTest::TestNotGuarded { test } => test,
};
if let Some(mut new_branch) =
@ -934,9 +963,6 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
enum Decider<'a, T> {
Leaf(T),
Guarded {
/// Symbol that stores a boolean
/// when true this branch is picked, otherwise skipped
symbol: Symbol,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
@ -1230,8 +1256,6 @@ fn test_to_equality<'a>(
None,
)
}
Test::Guarded { .. } => unreachable!("should be handled elsewhere"),
}
}
@ -1248,46 +1272,21 @@ fn stores_and_condition<'a>(
cond_symbol: Symbol,
cond_layout: &Layout<'a>,
test_chain: Vec<(Vec<PathInstruction>, Test<'a>)>,
) -> (Tests<'a>, Option<(Symbol, JoinPointId, Stmt<'a>)>) {
) -> Tests<'a> {
let mut tests: Tests = Vec::with_capacity(test_chain.len());
let mut guard = None;
// Assumption: there is at most 1 guard, and it is the outer layer.
for (path, test) in test_chain {
match test {
Test::Guarded {
opt_test,
id,
symbol,
stmt,
} => {
if let Some(nested) = opt_test {
tests.push(test_to_equality(
env,
cond_symbol,
&cond_layout,
&path,
*nested,
));
}
// let (stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout);
guard = Some((symbol, id, stmt));
}
_ => tests.push(test_to_equality(
env,
cond_symbol,
&cond_layout,
&path,
test,
)),
}
))
}
(tests, guard)
tests
}
fn compile_guard<'a>(
@ -1538,7 +1537,6 @@ fn decide_to_branching<'a>(
}
Leaf(Inline(expr)) => expr,
Guarded {
symbol,
id,
stmt,
success,
@ -1625,7 +1623,7 @@ fn decide_to_branching<'a>(
let chain_branch_info =
ConstructorKnown::from_test_chain(cond_symbol, &cond_layout, &test_chain);
let (tests, _) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let tests = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let number_of_tests = tests.len() as i64;
@ -1821,6 +1819,14 @@ fn test_always_succeeds(test: &Test) -> bool {
}
}
fn guarded_test_always_succeeds(test: &GuardedTest) -> bool {
match test {
GuardedTest::TestGuarded { test, id, stmt } => false,
GuardedTest::GuardedNoTest { id, stmt } => false,
GuardedTest::TestNotGuarded { test } => test_always_succeeds(test),
}
}
fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
use Decider::*;
use DecisionTree::*;
@ -1842,40 +1848,46 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
}
2 => {
let (_, failure_tree) = edges.remove(1);
let (test, success_tree) = edges.remove(0);
let (guarded_test, success_tree) = edges.remove(0);
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
match guarded_test {
GuardedTest::TestGuarded { test, id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success,
failure: failure.clone(),
};
match opt_test {
Some(test) => Chain {
test_chain: vec![(path, *test)],
Chain {
test_chain: vec![(path, test)],
success: Box::new(guarded),
failure,
},
None => guarded,
}
}
GuardedTest::GuardedNoTest { id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
Decider::Guarded {
id,
stmt,
success,
failure: failure.clone(),
}
}
GuardedTest::TestNotGuarded { test } => {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else {
to_chain(path, test, success_tree, failure_tree)
}
}
}
}
_ => {
let fallback_tree = edges.remove(edges.len() - 1).1;
@ -1883,30 +1895,24 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
let necessary_tests = edges
.into_iter()
.map(|(test, dectree)| {
.map(|(guarded_test, dectree)| {
let decider = tree_to_decider(dectree);
if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
match guarded_test {
GuardedTest::TestGuarded { test, id, stmt } => {
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success: Box::new(decider),
failure: Box::new(fallback_decider.clone()),
};
match opt_test {
Some(test) => (*test, guarded),
None => todo!(),
(test, guarded)
}
} else {
(test, decider)
GuardedTest::GuardedNoTest { id, stmt } => {
unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => (test, decider),
}
})
.collect();
@ -1923,76 +1929,71 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
0 => tree_to_decider(*last),
1 => {
let failure_tree = *last;
let (test, success_tree) = edges.remove(0);
let (guarded_test, success_tree) = edges.remove(0);
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
match guarded_test {
GuardedTest::TestGuarded { test, id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success,
failure: failure.clone(),
};
match opt_test {
Some(test) => Chain {
test_chain: vec![(path, *test)],
Chain {
test_chain: vec![(path, test)],
success: Box::new(guarded),
failure,
},
None => guarded,
}
}
GuardedTest::GuardedNoTest { id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
Decider::Guarded {
id,
stmt,
success,
failure: failure.clone(),
}
}
GuardedTest::TestNotGuarded { test } => {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else {
to_chain(path, test, success_tree, failure_tree)
}
}
}
}
_ => {
let fallback = *last;
let fallback_decider = tree_to_decider(fallback);
// let necessary_tests = edges
// .into_iter()
// .map(|(test, decider)| (test, tree_to_decider(decider)))
// .collect();
let necessary_tests = edges
.into_iter()
.map(|(test, dectree)| {
.map(|(guarded_test, dectree)| {
let decider = tree_to_decider(dectree);
if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
match guarded_test {
GuardedTest::TestGuarded { test, id, stmt } => {
let guarded = Decider::Guarded {
symbol,
id,
stmt,
success: Box::new(decider),
failure: Box::new(fallback_decider.clone()),
};
match opt_test {
Some(test) => (*test, guarded),
None => todo!(),
(test, guarded)
}
} else {
(test, decider)
GuardedTest::GuardedNoTest { id, stmt } => {
unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => (test, decider),
}
})
.collect();
@ -2113,13 +2114,11 @@ fn insert_choices<'a>(
}
Guarded {
symbol,
id,
stmt,
success,
failure,
} => Guarded {
symbol,
id,
stmt,
success: Box::new(insert_choices(choice_dict, *success)),