working decision tree gen for all tests

This commit is contained in:
Folkert 2021-07-10 21:28:32 +02:00
parent 7abfca4388
commit e05753afd8
3 changed files with 176 additions and 66 deletions

View file

@ -70,6 +70,7 @@ enum GuardedTest<'a> {
TestNotGuarded {
test: Test<'a>,
},
Placeholder,
}
#[derive(Clone, Debug, PartialEq)]
@ -130,13 +131,17 @@ impl<'a> Hash for Test<'a> {
impl<'a> Hash for GuardedTest<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
GuardedTest::GuardedNoTest { .. } => {
GuardedTest::GuardedNoTest { id, .. } => {
state.write_u8(1);
id.hash(state);
}
GuardedTest::TestNotGuarded { test } => {
state.write_u8(0);
test.hash(state);
}
GuardedTest::Placeholder => {
state.write_u8(2);
}
}
}
}
@ -159,30 +164,31 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
Match::Exact(goal) => DecisionTree::Match(goal),
Match::GuardOnly => {
// the guard test does not have a path
let path = vec![];
// the first branch has no more tests to do, but it has an if-guard
let mut branches = branches;
let first = branches.remove(0);
let guarded_test = {
match first.guard {
Guard::NoGuard => unreachable!(),
Guard::Guard {
symbol: _,
id,
stmt,
} => {
let guarded_test = GuardedTest::GuardedNoTest { id, stmt };
// the guard test does not have a path
let path = vec![];
// we expect none of the patterns need tests, those decisions should have been made already
debug_assert!(first
.patterns
.iter()
.all(|(_, pattern)| !needs_tests(pattern)));
match first.guard {
Guard::NoGuard => unreachable!("first test must have a guard"),
Guard::Guard {
symbol: _,
id,
stmt,
} => GuardedTest::GuardedNoTest { id, stmt },
}
};
let rest = if branches.is_empty() {
let default = if branches.is_empty() {
None
} else {
Some(Box::new(to_decision_tree(branches)))
@ -191,7 +197,9 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
DecisionTree::Decision {
path,
edges: vec![(guarded_test, DecisionTree::Match(first.goal))],
default: rest,
default,
}
}
}
}
@ -199,11 +207,18 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
// must clone here to release the borrow on `branches`
let path = pick_path(&branches).clone();
let bs = branches.clone();
let (edges, fallback) = gather_edges(branches, &path);
let mut decision_edges: Vec<_> = edges
.into_iter()
.map(|(a, b)| (a, to_decision_tree(b)))
.map(|(test, branches)| {
if bs == branches {
panic!();
} else {
(test, to_decision_tree(branches))
}
})
.collect();
match (decision_edges.as_slice(), fallback.as_slice()) {
@ -213,21 +228,54 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
// get the `_decision_tree` without cloning
decision_edges.pop().unwrap().1
}
(_, []) => DecisionTree::Decision {
path,
edges: decision_edges,
default: None,
},
(_, []) => helper(path, decision_edges, None),
([], _) => {
// should be guaranteed by the patterns
debug_assert!(!fallback.is_empty());
to_decision_tree(fallback)
}
(_, _) => DecisionTree::Decision {
(_, _) => helper(
path,
edges: decision_edges,
default: Some(Box::new(to_decision_tree(fallback))),
decision_edges,
Some(Box::new(to_decision_tree(fallback))),
),
}
}
}
}
fn helper<'a>(
path: Vec<PathInstruction>,
mut edges: Vec<(GuardedTest<'a>, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>,
) -> DecisionTree<'a> {
match edges
.iter()
.position(|(t, _)| matches!(t, GuardedTest::Placeholder))
{
None => DecisionTree::Decision {
path,
edges,
default,
},
Some(index) => {
let (a, b) = edges.split_at_mut(index + 1);
let new_default = helper(path.clone(), b.to_vec(), default);
let mut left = a.to_vec();
let guard = left.pop().unwrap();
let help = DecisionTree::Decision {
path: path.clone(),
edges: vec![guard],
default: Some(Box::new(new_default)),
};
DecisionTree::Decision {
path,
edges: left,
default: Some(Box::new(help)),
}
}
}
@ -237,9 +285,14 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
let no_guard = tests
.iter()
.all(|t| matches!(t, GuardedTest::TestNotGuarded { .. }));
match tests.last().unwrap() {
GuardedTest::Placeholder => false,
GuardedTest::GuardedNoTest { .. } => false,
GuardedTest::TestNotGuarded { test } => tests_are_complete_help(test, length),
GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length),
}
}
@ -332,8 +385,8 @@ fn check_for_match(branches: &[Branch]) -> Match {
match branches.get(0) {
Some(Branch {
goal,
guard,
patterns,
guard,
}) if patterns.iter().all(|(_, pattern)| !needs_tests(pattern)) => {
if guard.is_none() {
Match::Exact(*goal)
@ -347,6 +400,7 @@ fn check_for_match(branches: &[Branch]) -> Match {
/// GATHER OUTGOING EDGES
// my understanding: branches that we could jump to based on the pattern at the current path
fn gather_edges<'a>(
branches: Vec<Branch<'a>>,
path: &[PathInstruction],
@ -428,11 +482,10 @@ fn test_at_path<'a>(
Some((_, pattern)) => {
let test = match pattern {
Identifier(_) | Underscore => {
if let Guard::Guard { id, stmt, .. } = &branch.guard {
return Some(GuardedTest::GuardedNoTest {
stmt: stmt.clone(),
id: *id,
});
if let Guard::Guard { .. } = &branch.guard {
// no tests for this pattern remain, but we cannot discard it yet
// because it has a guard!
return Some(GuardedTest::Placeholder);
} else {
return None;
}
@ -516,6 +569,7 @@ fn test_at_path<'a>(
/// BUILD EDGES
// understanding: if the test is successful, where could we go?
fn edges_for<'a>(
path: &[PathInstruction],
branches: Vec<Branch<'a>>,
@ -523,8 +577,22 @@ fn edges_for<'a>(
) -> (GuardedTest<'a>, Vec<Branch<'a>>) {
let mut new_branches = Vec::new();
for branch in branches.iter() {
to_relevant_branch(&test, path, branch, &mut new_branches);
// if we test for a guard, skip all branches until one that has a guard
let it = match test {
GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => {
let index = branches
.iter()
.position(|b| !b.guard.is_none())
.expect("if testing for a guard, one branch must have a guard");
branches[index..].iter()
}
GuardedTest::TestNotGuarded { .. } => branches.iter(),
};
for branch in it {
new_branches.extend(to_relevant_branch(&test, path, branch));
}
(test, new_branches)
@ -534,27 +602,27 @@ fn to_relevant_branch<'a>(
guarded_test: &GuardedTest<'a>,
path: &[PathInstruction],
branch: &Branch<'a>,
new_branches: &mut Vec<Branch<'a>>,
) {
) -> Option<Branch<'a>> {
// TODO remove clone
match extract(path, branch.patterns.clone()) {
Extract::NotFound => {
new_branches.push(branch.clone());
}
Extract::NotFound => Some(branch.clone()),
Extract::Found {
start,
found_pattern: pattern,
end,
} => match guarded_test {
GuardedTest::GuardedNoTest { .. } => {
new_branches.push(branch.clone());
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
// if there is no test, the pattern should not require any
debug_assert!(
matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,),
"{:?}",
pattern,
);
Some(branch.clone())
}
GuardedTest::TestNotGuarded { test } => {
if let Some(new_branch) =
to_relevant_branch_help(test, path, start, end, branch, pattern)
{
new_branches.push(new_branch);
}
}
},
}
@ -855,7 +923,9 @@ fn pick_path<'a>(branches: &'a [Branch]) -> &'a Vec<PathInstruction> {
// is choice path
for branch in branches {
for (path, pattern) in &branch.patterns {
if !branch.guard.is_none() || needs_tests(&pattern) {
// NOTE we no longer check for the guard here
// if !branch.guard.is_none() || needs_tests(&pattern) {
if needs_tests(&pattern) {
all_paths.push(path);
} else {
// do nothing
@ -1861,7 +1931,7 @@ fn fanout_decider_help<'a>(
guarded_test: GuardedTest<'a>,
) -> (Test<'a>, Decider<'a, u64>) {
match guarded_test {
GuardedTest::GuardedNoTest { .. } => {
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => {
@ -1896,6 +1966,11 @@ fn chain_decider<'a>(
to_chain(path, test, success_tree, failure_tree)
}
}
GuardedTest::Placeholder => {
// ?
tree_to_decider(success_tree)
}
}
}

View file

@ -19,7 +19,7 @@ use roc_types::subs::{Content, FlatType, Subs, Variable};
use std::collections::HashMap;
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
pub const PRETTY_PRINT_IR_SYMBOLS: bool = true;
pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
macro_rules! return_on_layout_error {
($env:expr, $layout_result:expr) => {

View file

@ -501,8 +501,43 @@ fn when_on_single_value_tag() {
);
}
#[test]
#[ignore]
fn if_guard_multiple() {
assert_evals_to!(
indoc!(
r#"
f = \n ->
when Identity n 0 is
Identity x _ if x == 0 -> x + 0
Identity x _ if x == 1 -> x + 0
Identity x _ if x == 2 -> x + 0
Identity x _ -> x - x
{ a: f 0, b: f 1, c: f 2, d: f 4 }
"#
),
(0, 1, 2, 0),
(i64, i64, i64, i64)
);
}
#[test]
fn if_guard_constructor_switch() {
assert_evals_to!(
indoc!(
r#"
when Identity 32 0 is
Identity 41 _ -> 0
Identity s 0 if s == 32 -> 3
# Identity s 0 -> s
Identity z _ -> z
"#
),
3,
i64
);
assert_evals_to!(
indoc!(
r#"
@ -531,13 +566,13 @@ fn if_guard_constructor_switch() {
}
#[test]
#[ignore]
fn if_guard_constructor_chain() {
assert_evals_to!(
indoc!(
r#"
when Identity 43 "" is
Identity 42 _ if 3 == 3 -> 1
when Identity 43 0 is
Identity 42 _ if 3 == 3 -> 43
# Identity 42 _ -> 1
Identity z _ -> z
"#
),