handle guards in a first-class way

This commit is contained in:
Folkert 2021-07-04 22:35:00 +02:00
parent 8f0c13ecc1
commit 2c0aa8a5a1

View file

@ -17,7 +17,7 @@ const RECORD_TAG_NAME: &str = "#Record";
/// some normal branches and gives out a decision tree that has "labels" at all /// some normal branches and gives out a decision tree that has "labels" at all
/// the leafs and a dictionary that maps these "labels" to the code that should /// the leafs and a dictionary that maps these "labels" to the code that should
/// run. /// run.
pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> { fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> {
let formatted = raw_branches let formatted = raw_branches
.into_iter() .into_iter()
.map(|(guard, pattern, index)| Branch { .map(|(guard, pattern, index)| Branch {
@ -49,15 +49,35 @@ impl<'a> Guard<'a> {
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum DecisionTree<'a> { enum DecisionTree<'a> {
Match(Label), Match(Label),
Decision { Decision {
path: Vec<PathInstruction>, path: Vec<PathInstruction>,
edges: Vec<(Test<'a>, DecisionTree<'a>)>, edges: Vec<(GuardedTest<'a>, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>, default: Option<Box<DecisionTree<'a>>>,
}, },
} }
/// A pattern test combined with (optional) guard information.
///
/// Values of this type label the edges of a `DecisionTree::Decision`.
/// Each variant records whether there is a test, a guard, or both.
#[derive(Clone, Debug, PartialEq)]
pub enum GuardedTest<'a> {
/// A pattern that needs a test and whose branch also carries a guard.
TestGuarded {
/// the test for the pattern itself
test: Test<'a>,
/// after assigning to symbol, the stmt jumps to this label
/// NOTE(review): the guard's `symbol` is not stored in this variant;
/// presumably `stmt` is what assigns it -- confirm.
id: JoinPointId,
/// statement copied from the branch's `Guard::Guard`
stmt: Stmt<'a>,
},
// e.g. `_ if True -> ...`
/// A guard on a pattern that produces no test of its own
/// (an `Identifier` or `Underscore` pattern).
GuardedNoTest {
/// after assigning to symbol, the stmt jumps to this label
/// NOTE(review): the guard's `symbol` is not stored in this variant;
/// presumably `stmt` is what assigns it -- confirm.
id: JoinPointId,
/// statement copied from the branch's `Guard::Guard`
stmt: Stmt<'a>,
},
/// A test for a pattern whose branch has no guard.
TestNotGuarded {
/// the test for the pattern itself
test: Test<'a>,
},
}
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Test<'a> { pub enum Test<'a> {
IsCtor { IsCtor {
@ -75,16 +95,6 @@ pub enum Test<'a> {
tag_id: u8, tag_id: u8,
num_alts: usize, num_alts: usize,
}, },
// A pattern that always succeeds (like `_`) can still have a guard
Guarded {
opt_test: Option<Box<Test<'a>>>,
/// Symbol that stores a boolean
/// when true this branch is picked, otherwise skipped
symbol: Symbol,
/// after assigning to symbol, the stmt jumps to this label
id: JoinPointId,
stmt: Stmt<'a>,
},
} }
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
impl<'a> Hash for Test<'a> { impl<'a> Hash for Test<'a> {
@ -118,15 +128,23 @@ impl<'a> Hash for Test<'a> {
tag_id.hash(state); tag_id.hash(state);
num_alts.hash(state); num_alts.hash(state);
} }
Guarded { opt_test: None, .. } => {
state.write_u8(6);
} }
Guarded { }
opt_test: Some(nested), }
..
} => { impl<'a> Hash for GuardedTest<'a> {
state.write_u8(7); fn hash<H: Hasher>(&self, state: &mut H) {
nested.hash(state); match self {
GuardedTest::TestGuarded { test, .. } => {
state.write_u8(0);
test.hash(state);
}
GuardedTest::GuardedNoTest { id, stmt } => {
state.write_u8(1);
}
GuardedTest::TestNotGuarded { test } => {
state.write_u8(2);
test.hash(state);
} }
} }
} }
@ -182,20 +200,32 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
} }
} }
fn is_complete(tests: &[Test]) -> bool { fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
let length = tests.len(); let length = tests.len();
debug_assert!(length > 0); debug_assert!(length > 0);
match tests.last() {
None => unreachable!("should never happen"), match tests.last().unwrap() {
Some(v) => match v { GuardedTest::TestGuarded { .. } => false,
Test::IsCtor { union, .. } => length == union.alternatives.len(), GuardedTest::GuardedNoTest { .. } => false,
Test::IsByte { num_alts, .. } => length == *num_alts, GuardedTest::TestNotGuarded { test } => tests_are_complete_help(test, length),
Test::IsBit(_) => length == 2, }
}
fn tests_are_complete(tests: &[Test]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
tests_are_complete_help(tests.last().unwrap(), length)
}
fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool {
match last_test {
Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(),
Test::IsByte { num_alts, .. } => number_of_tests == *num_alts,
Test::IsBit(_) => number_of_tests == 2,
Test::IsInt(_) => false, Test::IsInt(_) => false,
Test::IsFloat(_) => false, Test::IsFloat(_) => false,
Test::IsStr(_) => false, Test::IsStr(_) => false,
Test::Guarded { .. } => false,
},
} }
} }
@ -293,10 +323,10 @@ fn check_for_match(branches: &[Branch]) -> Option<Label> {
fn gather_edges<'a>( fn gather_edges<'a>(
branches: Vec<Branch<'a>>, branches: Vec<Branch<'a>>,
path: &[PathInstruction], path: &[PathInstruction],
) -> (Vec<(Test<'a>, Vec<Branch<'a>>)>, Vec<Branch<'a>>) { ) -> (Vec<(GuardedTest<'a>, Vec<Branch<'a>>)>, Vec<Branch<'a>>) {
let relevant_tests = tests_at_path(path, &branches); let relevant_tests = tests_at_path(path, &branches);
let check = is_complete(&relevant_tests); let check = guarded_tests_are_complete(&relevant_tests);
// TODO remove clone // TODO remove clone
let all_edges = relevant_tests let all_edges = relevant_tests
@ -318,7 +348,10 @@ fn gather_edges<'a>(
/// FIND RELEVANT TESTS /// FIND RELEVANT TESTS
fn tests_at_path<'a>(selected_path: &[PathInstruction], branches: &[Branch<'a>]) -> Vec<Test<'a>> { fn tests_at_path<'a>(
selected_path: &[PathInstruction],
branches: &[Branch<'a>],
) -> Vec<GuardedTest<'a>> {
// NOTE the ordering of the result is important! // NOTE the ordering of the result is important!
let mut all_tests = Vec::new(); let mut all_tests = Vec::new();
@ -355,7 +388,7 @@ fn tests_at_path<'a>(selected_path: &[PathInstruction], branches: &[Branch<'a>])
fn test_at_path<'a>( fn test_at_path<'a>(
selected_path: &[PathInstruction], selected_path: &[PathInstruction],
branch: &Branch<'a>, branch: &Branch<'a>,
all_tests: &mut Vec<Test<'a>>, guarded_tests: &mut Vec<GuardedTest<'a>>,
) { ) {
use Pattern::*; use Pattern::*;
use Test::*; use Test::*;
@ -367,30 +400,16 @@ fn test_at_path<'a>(
{ {
None => {} None => {}
Some((_, guard, pattern)) => { Some((_, guard, pattern)) => {
let guarded = |test| { let test = match pattern {
if let Guard::Guard { symbol, id, stmt } = guard {
Guarded {
opt_test: Some(Box::new(test)),
stmt: stmt.clone(),
symbol: *symbol,
id: *id,
}
} else {
test
}
};
match pattern {
// TODO use guard!
Identifier(_) | Underscore => { Identifier(_) | Underscore => {
if let Guard::Guard { symbol, id, stmt } = guard { if let Guard::Guard { id, stmt, .. } = guard {
all_tests.push(Guarded { guarded_tests.push(GuardedTest::GuardedNoTest {
opt_test: None,
stmt: stmt.clone(), stmt: stmt.clone(),
symbol: *symbol,
id: *id, id: *id,
}); });
} }
return;
} }
RecordDestructure(destructs, _) => { RecordDestructure(destructs, _) => {
@ -417,12 +436,12 @@ fn test_at_path<'a>(
} }
} }
all_tests.push(IsCtor { IsCtor {
tag_id: 0, tag_id: 0,
tag_name: TagName::Global(RECORD_TAG_NAME.into()), tag_name: TagName::Global(RECORD_TAG_NAME.into()),
union, union,
arguments, arguments,
}); }
} }
NewtypeDestructure { NewtypeDestructure {
@ -432,12 +451,12 @@ fn test_at_path<'a>(
let tag_id = 0; let tag_id = 0;
let union = Union::newtype_wrapper(tag_name.clone(), arguments.len()); let union = Union::newtype_wrapper(tag_name.clone(), arguments.len());
all_tests.push(IsCtor { IsCtor {
tag_id, tag_id,
tag_name: tag_name.clone(), tag_name: tag_name.clone(),
union, union,
arguments: arguments.to_vec(), arguments: arguments.to_vec(),
}); }
} }
AppliedTag { AppliedTag {
@ -446,33 +465,33 @@ fn test_at_path<'a>(
arguments, arguments,
union, union,
.. ..
} => { } => IsCtor {
all_tests.push(IsCtor {
tag_id: *tag_id, tag_id: *tag_id,
tag_name: tag_name.clone(), tag_name: tag_name.clone(),
union: union.clone(), union: union.clone(),
arguments: arguments.to_vec(), arguments: arguments.to_vec(),
}); },
} BitLiteral { value, .. } => IsBit(*value),
BitLiteral { value, .. } => { EnumLiteral { tag_id, union, .. } => IsByte {
all_tests.push(IsBit(*value));
}
EnumLiteral { tag_id, union, .. } => {
all_tests.push(IsByte {
tag_id: *tag_id, tag_id: *tag_id,
num_alts: union.alternatives.len(), num_alts: union.alternatives.len(),
}); },
} IntLiteral(v) => IsInt(*v),
IntLiteral(v) => { FloatLiteral(v) => IsFloat(*v),
all_tests.push(guarded(IsInt(*v))); StrLiteral(v) => IsStr(v.clone()),
}
FloatLiteral(v) => {
all_tests.push(IsFloat(*v));
}
StrLiteral(v) => {
all_tests.push(IsStr(v.clone()));
}
}; };
let guarded_test = if let Guard::Guard { symbol, id, stmt } = guard {
GuardedTest::TestGuarded {
test,
stmt: stmt.clone(),
id: *id,
}
} else {
GuardedTest::TestNotGuarded { test }
};
guarded_tests.push(guarded_test);
} }
} }
} }
@ -482,8 +501,8 @@ fn test_at_path<'a>(
fn edges_for<'a>( fn edges_for<'a>(
path: &[PathInstruction], path: &[PathInstruction],
branches: Vec<Branch<'a>>, branches: Vec<Branch<'a>>,
test: Test<'a>, test: GuardedTest<'a>,
) -> (Test<'a>, Vec<Branch<'a>>) { ) -> (GuardedTest<'a>, Vec<Branch<'a>>) {
let mut new_branches = Vec::new(); let mut new_branches = Vec::new();
for branch in branches.iter() { for branch in branches.iter() {
@ -494,7 +513,7 @@ fn edges_for<'a>(
} }
fn to_relevant_branch<'a>( fn to_relevant_branch<'a>(
test: &Test<'a>, guarded_test: &GuardedTest<'a>,
path: &[PathInstruction], path: &[PathInstruction],
branch: &Branch<'a>, branch: &Branch<'a>,
new_branches: &mut Vec<Branch<'a>>, new_branches: &mut Vec<Branch<'a>>,
@ -509,12 +528,22 @@ fn to_relevant_branch<'a>(
found_pattern: (guard, pattern), found_pattern: (guard, pattern),
end, end,
} => { } => {
let actual_test = match test { let actual_test = match guarded_test {
Test::Guarded { GuardedTest::TestGuarded { test, .. } => test,
opt_test: Some(box_test), GuardedTest::GuardedNoTest { .. } => {
.. let mut new_branch = branch.clone();
} => box_test,
_ => test, // guards can/should only occur at the top level. When we recurse on these
// branches, the guard is not relevant any more. Not setthing the guard to None
// leads to infinite recursion.
new_branch.patterns.iter_mut().for_each(|(_, guard, _)| {
*guard = Guard::NoGuard;
});
new_branches.push(new_branch);
return;
}
GuardedTest::TestNotGuarded { test } => test,
}; };
if let Some(mut new_branch) = if let Some(mut new_branch) =
@ -934,9 +963,6 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
enum Decider<'a, T> { enum Decider<'a, T> {
Leaf(T), Leaf(T),
Guarded { Guarded {
/// Symbol that stores a boolean
/// when true this branch is picked, otherwise skipped
symbol: Symbol,
/// after assigning to symbol, the stmt jumps to this label /// after assigning to symbol, the stmt jumps to this label
id: JoinPointId, id: JoinPointId,
stmt: Stmt<'a>, stmt: Stmt<'a>,
@ -1230,8 +1256,6 @@ fn test_to_equality<'a>(
None, None,
) )
} }
Test::Guarded { .. } => unreachable!("should be handled elsewhere"),
} }
} }
@ -1248,46 +1272,21 @@ fn stores_and_condition<'a>(
cond_symbol: Symbol, cond_symbol: Symbol,
cond_layout: &Layout<'a>, cond_layout: &Layout<'a>,
test_chain: Vec<(Vec<PathInstruction>, Test<'a>)>, test_chain: Vec<(Vec<PathInstruction>, Test<'a>)>,
) -> (Tests<'a>, Option<(Symbol, JoinPointId, Stmt<'a>)>) { ) -> Tests<'a> {
let mut tests: Tests = Vec::with_capacity(test_chain.len()); let mut tests: Tests = Vec::with_capacity(test_chain.len());
let mut guard = None;
// Assumption: there is at most 1 guard, and it is the outer layer. // Assumption: there is at most 1 guard, and it is the outer layer.
for (path, test) in test_chain { for (path, test) in test_chain {
match test {
Test::Guarded {
opt_test,
id,
symbol,
stmt,
} => {
if let Some(nested) = opt_test {
tests.push(test_to_equality( tests.push(test_to_equality(
env,
cond_symbol,
&cond_layout,
&path,
*nested,
));
}
// let (stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout);
guard = Some((symbol, id, stmt));
}
_ => tests.push(test_to_equality(
env, env,
cond_symbol, cond_symbol,
&cond_layout, &cond_layout,
&path, &path,
test, test,
)), ))
}
} }
(tests, guard) tests
} }
fn compile_guard<'a>( fn compile_guard<'a>(
@ -1538,7 +1537,6 @@ fn decide_to_branching<'a>(
} }
Leaf(Inline(expr)) => expr, Leaf(Inline(expr)) => expr,
Guarded { Guarded {
symbol,
id, id,
stmt, stmt,
success, success,
@ -1625,7 +1623,7 @@ fn decide_to_branching<'a>(
let chain_branch_info = let chain_branch_info =
ConstructorKnown::from_test_chain(cond_symbol, &cond_layout, &test_chain); ConstructorKnown::from_test_chain(cond_symbol, &cond_layout, &test_chain);
let (tests, _) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain); let tests = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let number_of_tests = tests.len() as i64; let number_of_tests = tests.len() as i64;
@ -1821,6 +1819,14 @@ fn test_always_succeeds(test: &Test) -> bool {
} }
} }
/// Reports whether a guarded test is known to always succeed.
///
/// Both guarded variants (`TestGuarded` and `GuardedNoTest`) always report
/// `false`; an unguarded test defers to `test_always_succeeds`.
fn guarded_test_always_succeeds(test: &GuardedTest) -> bool {
    match test {
        // The payload of the guarded variants plays no role in the answer, so
        // it is ignored with `..` instead of being bound to unused names.
        GuardedTest::TestGuarded { .. } | GuardedTest::GuardedNoTest { .. } => false,
        GuardedTest::TestNotGuarded { test } => test_always_succeeds(test),
    }
}
fn tree_to_decider(tree: DecisionTree) -> Decider<u64> { fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
use Decider::*; use Decider::*;
use DecisionTree::*; use DecisionTree::*;
@ -1842,40 +1848,46 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
} }
2 => { 2 => {
let (_, failure_tree) = edges.remove(1); let (_, failure_tree) = edges.remove(1);
let (test, success_tree) = edges.remove(0); let (guarded_test, success_tree) = edges.remove(0);
if test_always_succeeds(&test) { match guarded_test {
tree_to_decider(success_tree) GuardedTest::TestGuarded { test, id, stmt } => {
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let failure = Box::new(tree_to_decider(failure_tree)); let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree)); let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded { let guarded = Decider::Guarded {
symbol,
id, id,
stmt, stmt,
success, success,
failure: failure.clone(), failure: failure.clone(),
}; };
match opt_test { Chain {
Some(test) => Chain { test_chain: vec![(path, test)],
test_chain: vec![(path, *test)],
success: Box::new(guarded), success: Box::new(guarded),
failure, failure,
},
None => guarded,
} }
}
GuardedTest::GuardedNoTest { id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
Decider::Guarded {
id,
stmt,
success,
failure: failure.clone(),
}
}
GuardedTest::TestNotGuarded { test } => {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else { } else {
to_chain(path, test, success_tree, failure_tree) to_chain(path, test, success_tree, failure_tree)
} }
} }
}
}
_ => { _ => {
let fallback_tree = edges.remove(edges.len() - 1).1; let fallback_tree = edges.remove(edges.len() - 1).1;
@ -1883,30 +1895,24 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
let necessary_tests = edges let necessary_tests = edges
.into_iter() .into_iter()
.map(|(test, dectree)| { .map(|(guarded_test, dectree)| {
let decider = tree_to_decider(dectree); let decider = tree_to_decider(dectree);
if let Test::Guarded { match guarded_test {
symbol, GuardedTest::TestGuarded { test, id, stmt } => {
id,
stmt,
opt_test,
} = test
{
let guarded = Decider::Guarded { let guarded = Decider::Guarded {
symbol,
id, id,
stmt, stmt,
success: Box::new(decider), success: Box::new(decider),
failure: Box::new(fallback_decider.clone()), failure: Box::new(fallback_decider.clone()),
}; };
match opt_test { (test, guarded)
Some(test) => (*test, guarded),
None => todo!(),
} }
} else { GuardedTest::GuardedNoTest { id, stmt } => {
(test, decider) unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => (test, decider),
} }
}) })
.collect(); .collect();
@ -1923,76 +1929,71 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
0 => tree_to_decider(*last), 0 => tree_to_decider(*last),
1 => { 1 => {
let failure_tree = *last; let failure_tree = *last;
let (test, success_tree) = edges.remove(0); let (guarded_test, success_tree) = edges.remove(0);
if test_always_succeeds(&test) { match guarded_test {
tree_to_decider(success_tree) GuardedTest::TestGuarded { test, id, stmt } => {
} else if let Test::Guarded {
symbol,
id,
stmt,
opt_test,
} = test
{
let failure = Box::new(tree_to_decider(failure_tree)); let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree)); let success = Box::new(tree_to_decider(success_tree));
let guarded = Decider::Guarded { let guarded = Decider::Guarded {
symbol,
id, id,
stmt, stmt,
success, success,
failure: failure.clone(), failure: failure.clone(),
}; };
match opt_test { Chain {
Some(test) => Chain { test_chain: vec![(path, test)],
test_chain: vec![(path, *test)],
success: Box::new(guarded), success: Box::new(guarded),
failure, failure,
},
None => guarded,
} }
}
GuardedTest::GuardedNoTest { id, stmt } => {
let failure = Box::new(tree_to_decider(failure_tree));
let success = Box::new(tree_to_decider(success_tree));
Decider::Guarded {
id,
stmt,
success,
failure: failure.clone(),
}
}
GuardedTest::TestNotGuarded { test } => {
if test_always_succeeds(&test) {
tree_to_decider(success_tree)
} else { } else {
to_chain(path, test, success_tree, failure_tree) to_chain(path, test, success_tree, failure_tree)
} }
} }
}
}
_ => { _ => {
let fallback = *last; let fallback = *last;
let fallback_decider = tree_to_decider(fallback); let fallback_decider = tree_to_decider(fallback);
// let necessary_tests = edges
// .into_iter()
// .map(|(test, decider)| (test, tree_to_decider(decider)))
// .collect();
let necessary_tests = edges let necessary_tests = edges
.into_iter() .into_iter()
.map(|(test, dectree)| { .map(|(guarded_test, dectree)| {
let decider = tree_to_decider(dectree); let decider = tree_to_decider(dectree);
if let Test::Guarded { match guarded_test {
symbol, GuardedTest::TestGuarded { test, id, stmt } => {
id,
stmt,
opt_test,
} = test
{
let guarded = Decider::Guarded { let guarded = Decider::Guarded {
symbol,
id, id,
stmt, stmt,
success: Box::new(decider), success: Box::new(decider),
failure: Box::new(fallback_decider.clone()), failure: Box::new(fallback_decider.clone()),
}; };
match opt_test { (test, guarded)
Some(test) => (*test, guarded),
None => todo!(),
} }
} else { GuardedTest::GuardedNoTest { id, stmt } => {
(test, decider) unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => (test, decider),
} }
}) })
.collect(); .collect();
@ -2113,13 +2114,11 @@ fn insert_choices<'a>(
} }
Guarded { Guarded {
symbol,
id, id,
stmt, stmt,
success, success,
failure, failure,
} => Guarded { } => Guarded {
symbol,
id, id,
stmt, stmt,
success: Box::new(insert_choices(choice_dict, *success)), success: Box::new(insert_choices(choice_dict, *success)),