From 2c0aa8a5a1ab1ac337e312a92ecf95fbab09aef6 Mon Sep 17 00:00:00 2001 From: Folkert Date: Sun, 4 Jul 2021 22:35:00 +0200 Subject: [PATCH] handle guards in a first-class way --- compiler/mono/src/decision_tree.rs | 469 ++++++++++++++--------------- 1 file changed, 234 insertions(+), 235 deletions(-) diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 3bcddaba45..e1987709c1 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -17,7 +17,7 @@ const RECORD_TAG_NAME: &str = "#Record"; /// some normal branches and gives out a decision tree that has "labels" at all /// the leafs and a dictionary that maps these "labels" to the code that should /// run. -pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> { +fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> { let formatted = raw_branches .into_iter() .map(|(guard, pattern, index)| Branch { @@ -49,15 +49,35 @@ impl<'a> Guard<'a> { } #[derive(Clone, Debug, PartialEq)] -pub enum DecisionTree<'a> { +enum DecisionTree<'a> { Match(Label), Decision { path: Vec, - edges: Vec<(Test<'a>, DecisionTree<'a>)>, + edges: Vec<(GuardedTest<'a>, DecisionTree<'a>)>, default: Option>>, }, } +#[derive(Clone, Debug, PartialEq)] +pub enum GuardedTest<'a> { + TestGuarded { + test: Test<'a>, + + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, + }, + // e.g. `_ if True -> ...` + GuardedNoTest { + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, + }, + TestNotGuarded { + test: Test<'a>, + }, +} + #[derive(Clone, Debug, PartialEq)] pub enum Test<'a> { IsCtor { @@ -75,16 +95,6 @@ pub enum Test<'a> { tag_id: u8, num_alts: usize, }, - // A pattern that always succeeds (like `_`) can still have a guard - Guarded { - opt_test: Option>>, - /// Symbol that stores a boolean - /// when true this branch is picked, otherwise skipped - symbol: Symbol, - /// after assigning to symbol, the stmt jumps to this label - id: JoinPointId, - stmt: Stmt<'a>, - }, } use std::hash::{Hash, Hasher}; impl<'a> Hash for Test<'a> { @@ -118,15 +128,23 @@ impl<'a> Hash for Test<'a> { tag_id.hash(state); num_alts.hash(state); } - Guarded { opt_test: None, .. } => { - state.write_u8(6); + } + } +} + +impl<'a> Hash for GuardedTest<'a> { + fn hash(&self, state: &mut H) { + match self { + GuardedTest::TestGuarded { test, .. } => { + state.write_u8(0); + test.hash(state); } - Guarded { - opt_test: Some(nested), - .. - } => { - state.write_u8(7); - nested.hash(state); + GuardedTest::GuardedNoTest { id, stmt } => { + state.write_u8(1); + } + GuardedTest::TestNotGuarded { test } => { + state.write_u8(2); + test.hash(state); } } } @@ -182,20 +200,32 @@ fn to_decision_tree(raw_branches: Vec) -> DecisionTree { } } -fn is_complete(tests: &[Test]) -> bool { +fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool { let length = tests.len(); debug_assert!(length > 0); - match tests.last() { - None => unreachable!("should never happen"), - Some(v) => match v { - Test::IsCtor { union, .. } => length == union.alternatives.len(), - Test::IsByte { num_alts, .. } => length == *num_alts, - Test::IsBit(_) => length == 2, - Test::IsInt(_) => false, - Test::IsFloat(_) => false, - Test::IsStr(_) => false, - Test::Guarded { .. } => false, - }, + + match tests.last().unwrap() { + GuardedTest::TestGuarded { .. } => false, + GuardedTest::GuardedNoTest { .. } => false, + GuardedTest::TestNotGuarded { test } => tests_are_complete_help(test, length), + } +} + +fn tests_are_complete(tests: &[Test]) -> bool { + let length = tests.len(); + debug_assert!(length > 0); + + tests_are_complete_help(tests.last().unwrap(), length) +} + +fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool { + match last_test { + Test::IsCtor { union, .. } => number_of_tests == union.alternatives.len(), + Test::IsByte { num_alts, .. } => number_of_tests == *num_alts, + Test::IsBit(_) => number_of_tests == 2, + Test::IsInt(_) => false, + Test::IsFloat(_) => false, + Test::IsStr(_) => false, } } @@ -293,10 +323,10 @@ fn check_for_match(branches: &[Branch]) -> Option