mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 14:54:47 +00:00
working decision tree gen for all tests
This commit is contained in:
parent
7abfca4388
commit
e05753afd8
3 changed files with 176 additions and 66 deletions
|
@ -70,6 +70,7 @@ enum GuardedTest<'a> {
|
||||||
TestNotGuarded {
|
TestNotGuarded {
|
||||||
test: Test<'a>,
|
test: Test<'a>,
|
||||||
},
|
},
|
||||||
|
Placeholder,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
|
@ -130,13 +131,17 @@ impl<'a> Hash for Test<'a> {
|
||||||
impl<'a> Hash for GuardedTest<'a> {
|
impl<'a> Hash for GuardedTest<'a> {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
match self {
|
match self {
|
||||||
GuardedTest::GuardedNoTest { .. } => {
|
GuardedTest::GuardedNoTest { id, .. } => {
|
||||||
state.write_u8(1);
|
state.write_u8(1);
|
||||||
|
id.hash(state);
|
||||||
}
|
}
|
||||||
GuardedTest::TestNotGuarded { test } => {
|
GuardedTest::TestNotGuarded { test } => {
|
||||||
state.write_u8(0);
|
state.write_u8(0);
|
||||||
test.hash(state);
|
test.hash(state);
|
||||||
}
|
}
|
||||||
|
GuardedTest::Placeholder => {
|
||||||
|
state.write_u8(2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -159,30 +164,31 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
|
||||||
Match::Exact(goal) => DecisionTree::Match(goal),
|
Match::Exact(goal) => DecisionTree::Match(goal),
|
||||||
|
|
||||||
Match::GuardOnly => {
|
Match::GuardOnly => {
|
||||||
// the guard test does not have a path
|
// the first branch has no more tests to do, but it has an if-guard
|
||||||
let path = vec![];
|
|
||||||
|
|
||||||
let mut branches = branches;
|
let mut branches = branches;
|
||||||
let first = branches.remove(0);
|
let first = branches.remove(0);
|
||||||
|
|
||||||
let guarded_test = {
|
match first.guard {
|
||||||
|
Guard::NoGuard => unreachable!(),
|
||||||
|
|
||||||
|
Guard::Guard {
|
||||||
|
symbol: _,
|
||||||
|
id,
|
||||||
|
stmt,
|
||||||
|
} => {
|
||||||
|
let guarded_test = GuardedTest::GuardedNoTest { id, stmt };
|
||||||
|
|
||||||
|
// the guard test does not have a path
|
||||||
|
let path = vec![];
|
||||||
|
|
||||||
// we expect none of the patterns need tests, those decisions should have been made already
|
// we expect none of the patterns need tests, those decisions should have been made already
|
||||||
debug_assert!(first
|
debug_assert!(first
|
||||||
.patterns
|
.patterns
|
||||||
.iter()
|
.iter()
|
||||||
.all(|(_, pattern)| !needs_tests(pattern)));
|
.all(|(_, pattern)| !needs_tests(pattern)));
|
||||||
|
|
||||||
match first.guard {
|
let default = if branches.is_empty() {
|
||||||
Guard::NoGuard => unreachable!("first test must have a guard"),
|
|
||||||
Guard::Guard {
|
|
||||||
symbol: _,
|
|
||||||
id,
|
|
||||||
stmt,
|
|
||||||
} => GuardedTest::GuardedNoTest { id, stmt },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let rest = if branches.is_empty() {
|
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(Box::new(to_decision_tree(branches)))
|
Some(Box::new(to_decision_tree(branches)))
|
||||||
|
@ -191,7 +197,9 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
|
||||||
DecisionTree::Decision {
|
DecisionTree::Decision {
|
||||||
path,
|
path,
|
||||||
edges: vec![(guarded_test, DecisionTree::Match(first.goal))],
|
edges: vec![(guarded_test, DecisionTree::Match(first.goal))],
|
||||||
default: rest,
|
default,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,11 +207,18 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
|
||||||
// must clone here to release the borrow on `branches`
|
// must clone here to release the borrow on `branches`
|
||||||
let path = pick_path(&branches).clone();
|
let path = pick_path(&branches).clone();
|
||||||
|
|
||||||
|
let bs = branches.clone();
|
||||||
let (edges, fallback) = gather_edges(branches, &path);
|
let (edges, fallback) = gather_edges(branches, &path);
|
||||||
|
|
||||||
let mut decision_edges: Vec<_> = edges
|
let mut decision_edges: Vec<_> = edges
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(a, b)| (a, to_decision_tree(b)))
|
.map(|(test, branches)| {
|
||||||
|
if bs == branches {
|
||||||
|
panic!();
|
||||||
|
} else {
|
||||||
|
(test, to_decision_tree(branches))
|
||||||
|
}
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
match (decision_edges.as_slice(), fallback.as_slice()) {
|
match (decision_edges.as_slice(), fallback.as_slice()) {
|
||||||
|
@ -213,21 +228,54 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
|
||||||
// get the `_decision_tree` without cloning
|
// get the `_decision_tree` without cloning
|
||||||
decision_edges.pop().unwrap().1
|
decision_edges.pop().unwrap().1
|
||||||
}
|
}
|
||||||
(_, []) => DecisionTree::Decision {
|
(_, []) => helper(path, decision_edges, None),
|
||||||
path,
|
|
||||||
edges: decision_edges,
|
|
||||||
default: None,
|
|
||||||
},
|
|
||||||
([], _) => {
|
([], _) => {
|
||||||
// should be guaranteed by the patterns
|
// should be guaranteed by the patterns
|
||||||
debug_assert!(!fallback.is_empty());
|
debug_assert!(!fallback.is_empty());
|
||||||
to_decision_tree(fallback)
|
to_decision_tree(fallback)
|
||||||
}
|
}
|
||||||
(_, _) => DecisionTree::Decision {
|
(_, _) => helper(
|
||||||
path,
|
path,
|
||||||
edges: decision_edges,
|
decision_edges,
|
||||||
default: Some(Box::new(to_decision_tree(fallback))),
|
Some(Box::new(to_decision_tree(fallback))),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn helper<'a>(
|
||||||
|
path: Vec<PathInstruction>,
|
||||||
|
mut edges: Vec<(GuardedTest<'a>, DecisionTree<'a>)>,
|
||||||
|
default: Option<Box<DecisionTree<'a>>>,
|
||||||
|
) -> DecisionTree<'a> {
|
||||||
|
match edges
|
||||||
|
.iter()
|
||||||
|
.position(|(t, _)| matches!(t, GuardedTest::Placeholder))
|
||||||
|
{
|
||||||
|
None => DecisionTree::Decision {
|
||||||
|
path,
|
||||||
|
edges,
|
||||||
|
default,
|
||||||
},
|
},
|
||||||
|
Some(index) => {
|
||||||
|
let (a, b) = edges.split_at_mut(index + 1);
|
||||||
|
|
||||||
|
let new_default = helper(path.clone(), b.to_vec(), default);
|
||||||
|
|
||||||
|
let mut left = a.to_vec();
|
||||||
|
let guard = left.pop().unwrap();
|
||||||
|
|
||||||
|
let help = DecisionTree::Decision {
|
||||||
|
path: path.clone(),
|
||||||
|
edges: vec![guard],
|
||||||
|
default: Some(Box::new(new_default)),
|
||||||
|
};
|
||||||
|
|
||||||
|
DecisionTree::Decision {
|
||||||
|
path,
|
||||||
|
edges: left,
|
||||||
|
default: Some(Box::new(help)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,9 +285,14 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
|
||||||
let length = tests.len();
|
let length = tests.len();
|
||||||
debug_assert!(length > 0);
|
debug_assert!(length > 0);
|
||||||
|
|
||||||
|
let no_guard = tests
|
||||||
|
.iter()
|
||||||
|
.all(|t| matches!(t, GuardedTest::TestNotGuarded { .. }));
|
||||||
|
|
||||||
match tests.last().unwrap() {
|
match tests.last().unwrap() {
|
||||||
|
GuardedTest::Placeholder => false,
|
||||||
GuardedTest::GuardedNoTest { .. } => false,
|
GuardedTest::GuardedNoTest { .. } => false,
|
||||||
GuardedTest::TestNotGuarded { test } => tests_are_complete_help(test, length),
|
GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,8 +385,8 @@ fn check_for_match(branches: &[Branch]) -> Match {
|
||||||
match branches.get(0) {
|
match branches.get(0) {
|
||||||
Some(Branch {
|
Some(Branch {
|
||||||
goal,
|
goal,
|
||||||
guard,
|
|
||||||
patterns,
|
patterns,
|
||||||
|
guard,
|
||||||
}) if patterns.iter().all(|(_, pattern)| !needs_tests(pattern)) => {
|
}) if patterns.iter().all(|(_, pattern)| !needs_tests(pattern)) => {
|
||||||
if guard.is_none() {
|
if guard.is_none() {
|
||||||
Match::Exact(*goal)
|
Match::Exact(*goal)
|
||||||
|
@ -347,6 +400,7 @@ fn check_for_match(branches: &[Branch]) -> Match {
|
||||||
|
|
||||||
/// GATHER OUTGOING EDGES
|
/// GATHER OUTGOING EDGES
|
||||||
|
|
||||||
|
// my understanding: branches that we could jump to based on the pattern at the current path
|
||||||
fn gather_edges<'a>(
|
fn gather_edges<'a>(
|
||||||
branches: Vec<Branch<'a>>,
|
branches: Vec<Branch<'a>>,
|
||||||
path: &[PathInstruction],
|
path: &[PathInstruction],
|
||||||
|
@ -428,11 +482,10 @@ fn test_at_path<'a>(
|
||||||
Some((_, pattern)) => {
|
Some((_, pattern)) => {
|
||||||
let test = match pattern {
|
let test = match pattern {
|
||||||
Identifier(_) | Underscore => {
|
Identifier(_) | Underscore => {
|
||||||
if let Guard::Guard { id, stmt, .. } = &branch.guard {
|
if let Guard::Guard { .. } = &branch.guard {
|
||||||
return Some(GuardedTest::GuardedNoTest {
|
// no tests for this pattern remain, but we cannot discard it yet
|
||||||
stmt: stmt.clone(),
|
// because it has a guard!
|
||||||
id: *id,
|
return Some(GuardedTest::Placeholder);
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
@ -516,6 +569,7 @@ fn test_at_path<'a>(
|
||||||
|
|
||||||
/// BUILD EDGES
|
/// BUILD EDGES
|
||||||
|
|
||||||
|
// understanding: if the test is successful, where could we go?
|
||||||
fn edges_for<'a>(
|
fn edges_for<'a>(
|
||||||
path: &[PathInstruction],
|
path: &[PathInstruction],
|
||||||
branches: Vec<Branch<'a>>,
|
branches: Vec<Branch<'a>>,
|
||||||
|
@ -523,8 +577,22 @@ fn edges_for<'a>(
|
||||||
) -> (GuardedTest<'a>, Vec<Branch<'a>>) {
|
) -> (GuardedTest<'a>, Vec<Branch<'a>>) {
|
||||||
let mut new_branches = Vec::new();
|
let mut new_branches = Vec::new();
|
||||||
|
|
||||||
for branch in branches.iter() {
|
// if we test for a guard, skip all branches until one that has a guard
|
||||||
to_relevant_branch(&test, path, branch, &mut new_branches);
|
|
||||||
|
let it = match test {
|
||||||
|
GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => {
|
||||||
|
let index = branches
|
||||||
|
.iter()
|
||||||
|
.position(|b| !b.guard.is_none())
|
||||||
|
.expect("if testing for a guard, one branch must have a guard");
|
||||||
|
|
||||||
|
branches[index..].iter()
|
||||||
|
}
|
||||||
|
GuardedTest::TestNotGuarded { .. } => branches.iter(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for branch in it {
|
||||||
|
new_branches.extend(to_relevant_branch(&test, path, branch));
|
||||||
}
|
}
|
||||||
|
|
||||||
(test, new_branches)
|
(test, new_branches)
|
||||||
|
@ -534,27 +602,27 @@ fn to_relevant_branch<'a>(
|
||||||
guarded_test: &GuardedTest<'a>,
|
guarded_test: &GuardedTest<'a>,
|
||||||
path: &[PathInstruction],
|
path: &[PathInstruction],
|
||||||
branch: &Branch<'a>,
|
branch: &Branch<'a>,
|
||||||
new_branches: &mut Vec<Branch<'a>>,
|
) -> Option<Branch<'a>> {
|
||||||
) {
|
|
||||||
// TODO remove clone
|
// TODO remove clone
|
||||||
match extract(path, branch.patterns.clone()) {
|
match extract(path, branch.patterns.clone()) {
|
||||||
Extract::NotFound => {
|
Extract::NotFound => Some(branch.clone()),
|
||||||
new_branches.push(branch.clone());
|
|
||||||
}
|
|
||||||
Extract::Found {
|
Extract::Found {
|
||||||
start,
|
start,
|
||||||
found_pattern: pattern,
|
found_pattern: pattern,
|
||||||
end,
|
end,
|
||||||
} => match guarded_test {
|
} => match guarded_test {
|
||||||
GuardedTest::GuardedNoTest { .. } => {
|
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
|
||||||
new_branches.push(branch.clone());
|
// if there is no test, the pattern should not require any
|
||||||
|
debug_assert!(
|
||||||
|
matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,),
|
||||||
|
"{:?}",
|
||||||
|
pattern,
|
||||||
|
);
|
||||||
|
|
||||||
|
Some(branch.clone())
|
||||||
}
|
}
|
||||||
GuardedTest::TestNotGuarded { test } => {
|
GuardedTest::TestNotGuarded { test } => {
|
||||||
if let Some(new_branch) =
|
|
||||||
to_relevant_branch_help(test, path, start, end, branch, pattern)
|
to_relevant_branch_help(test, path, start, end, branch, pattern)
|
||||||
{
|
|
||||||
new_branches.push(new_branch);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -855,7 +923,9 @@ fn pick_path<'a>(branches: &'a [Branch]) -> &'a Vec<PathInstruction> {
|
||||||
// is choice path
|
// is choice path
|
||||||
for branch in branches {
|
for branch in branches {
|
||||||
for (path, pattern) in &branch.patterns {
|
for (path, pattern) in &branch.patterns {
|
||||||
if !branch.guard.is_none() || needs_tests(&pattern) {
|
// NOTE we no longer check for the guard here
|
||||||
|
// if !branch.guard.is_none() || needs_tests(&pattern) {
|
||||||
|
if needs_tests(&pattern) {
|
||||||
all_paths.push(path);
|
all_paths.push(path);
|
||||||
} else {
|
} else {
|
||||||
// do nothing
|
// do nothing
|
||||||
|
@ -1861,7 +1931,7 @@ fn fanout_decider_help<'a>(
|
||||||
guarded_test: GuardedTest<'a>,
|
guarded_test: GuardedTest<'a>,
|
||||||
) -> (Test<'a>, Decider<'a, u64>) {
|
) -> (Test<'a>, Decider<'a, u64>) {
|
||||||
match guarded_test {
|
match guarded_test {
|
||||||
GuardedTest::GuardedNoTest { .. } => {
|
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
|
||||||
unreachable!("this would not end up in a switch")
|
unreachable!("this would not end up in a switch")
|
||||||
}
|
}
|
||||||
GuardedTest::TestNotGuarded { test } => {
|
GuardedTest::TestNotGuarded { test } => {
|
||||||
|
@ -1896,6 +1966,11 @@ fn chain_decider<'a>(
|
||||||
to_chain(path, test, success_tree, failure_tree)
|
to_chain(path, test, success_tree, failure_tree)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GuardedTest::Placeholder => {
|
||||||
|
// ?
|
||||||
|
tree_to_decider(success_tree)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ use roc_types::subs::{Content, FlatType, Subs, Variable};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
|
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
|
||||||
|
|
||||||
pub const PRETTY_PRINT_IR_SYMBOLS: bool = true;
|
pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
|
||||||
|
|
||||||
macro_rules! return_on_layout_error {
|
macro_rules! return_on_layout_error {
|
||||||
($env:expr, $layout_result:expr) => {
|
($env:expr, $layout_result:expr) => {
|
||||||
|
|
|
@ -501,8 +501,43 @@ fn when_on_single_value_tag() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore]
|
||||||
|
fn if_guard_multiple() {
|
||||||
|
assert_evals_to!(
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
f = \n ->
|
||||||
|
when Identity n 0 is
|
||||||
|
Identity x _ if x == 0 -> x + 0
|
||||||
|
Identity x _ if x == 1 -> x + 0
|
||||||
|
Identity x _ if x == 2 -> x + 0
|
||||||
|
Identity x _ -> x - x
|
||||||
|
|
||||||
|
{ a: f 0, b: f 1, c: f 2, d: f 4 }
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
(0, 1, 2, 0),
|
||||||
|
(i64, i64, i64, i64)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn if_guard_constructor_switch() {
|
fn if_guard_constructor_switch() {
|
||||||
|
assert_evals_to!(
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
when Identity 32 0 is
|
||||||
|
Identity 41 _ -> 0
|
||||||
|
Identity s 0 if s == 32 -> 3
|
||||||
|
# Identity s 0 -> s
|
||||||
|
Identity z _ -> z
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
3,
|
||||||
|
i64
|
||||||
|
);
|
||||||
|
|
||||||
assert_evals_to!(
|
assert_evals_to!(
|
||||||
indoc!(
|
indoc!(
|
||||||
r#"
|
r#"
|
||||||
|
@ -531,13 +566,13 @@ fn if_guard_constructor_switch() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
|
||||||
fn if_guard_constructor_chain() {
|
fn if_guard_constructor_chain() {
|
||||||
assert_evals_to!(
|
assert_evals_to!(
|
||||||
indoc!(
|
indoc!(
|
||||||
r#"
|
r#"
|
||||||
when Identity 43 "" is
|
when Identity 43 0 is
|
||||||
Identity 42 _ if 3 == 3 -> 1
|
Identity 42 _ if 3 == 3 -> 43
|
||||||
|
# Identity 42 _ -> 1
|
||||||
Identity z _ -> z
|
Identity z _ -> z
|
||||||
"#
|
"#
|
||||||
),
|
),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue