parity with old implementation

This commit is contained in:
Folkert 2020-03-16 16:19:44 +01:00
parent d875f8bfce
commit e742b77e0b
6 changed files with 345 additions and 303 deletions

View file

@ -3,6 +3,10 @@ use crate::expr::Expr;
use crate::expr::Pattern;
use roc_collections::all::{MutMap, MutSet};
use roc_module::ident::TagName;
use roc_module::symbol::Symbol;
use crate::layout::Builtin;
use crate::layout::Layout;
/// COMPILE CASES
@ -12,7 +16,7 @@ type Label = u64;
/// some normal branches and gives out a decision tree that has "labels" at all
/// the leafs and a dictionary that maps these "labels" to the code that should
/// run.
pub fn compile<'a>(raw_branches: Vec<(Pattern<'a>, u64)>) -> DecisionTree {
pub fn compile(raw_branches: Vec<(Pattern<'_>, u64)>) -> DecisionTree {
let formatted = raw_branches
.into_iter()
.map(|(pattern, index)| Branch {
@ -110,7 +114,7 @@ fn is_complete(tests: &[Test]) -> bool {
}
}
fn flatten_patterns<'a>(branch: Branch<'a>) -> Branch<'a> {
fn flatten_patterns(branch: Branch) -> Branch {
let mut result = Vec::with_capacity(branch.patterns.len());
for path_pattern in branch.patterns {
@ -124,9 +128,7 @@ fn flatten_patterns<'a>(branch: Branch<'a>) -> Branch<'a> {
fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) {
match &path_pattern.1 {
Pattern::AppliedTag {
union, arguments, ..
} => {
Pattern::AppliedTag { union, .. } => {
if union.alternatives.len() == 1 {
// case map dearg ctorArgs of
// [arg] ->
@ -162,8 +164,9 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
/// variables to "how to get their value". So a pattern like (Just (x,_)) will give
/// us something like ("x" => value.0.0)
fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
match branches.get(branches.len() - 1) {
match branches.get(0) {
Some(Branch { goal, patterns }) if patterns.iter().all(|(_, p)| !needs_tests(p)) => {
println!("the expected case {:?} {:?}", goal, patterns);
Some(*goal)
}
_ => None,
@ -256,7 +259,7 @@ fn float_to_i64(float: f64) -> i64 {
// We assume that `v` is normal (not Nan, Infinity, -Infinity)
// those values cannot occur in patterns in Roc, so this code is safe
debug_assert!(float.is_normal());
unsafe { std::mem::transmute::<f64, i64>(float) }
float.to_bits() as i64
}
/// BUILD EDGES
@ -278,7 +281,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
match extract(path, branch.patterns.clone()) {
Extract::NotFound => Some(branch),
Extract::Found {
start: mut start,
mut start,
found_pattern: pattern,
end,
} => match pattern {
@ -306,7 +309,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
todo!()
}
StrLiteral(string) => match test {
IsStr(testStr) if string == *testStr => {
IsStr(test_str) if string == *test_str => {
start.extend(end);
Some(Branch {
goal: branch.goal,
@ -317,7 +320,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
},
IntLiteral(int) => match test {
IsInt(testInt) if int == *testInt => {
IsInt(is_int) if int == *is_int => {
start.extend(end);
Some(Branch {
goal: branch.goal,
@ -328,7 +331,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
},
FloatLiteral(float) => match test {
IsFloat(testFloat) if float_to_i64(float) == *testFloat => {
IsFloat(test_float) if float_to_i64(float) == *test_float => {
start.extend(end);
Some(Branch {
goal: branch.goal,
@ -339,7 +342,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
},
BitLiteral(bit) => match test {
IsBit(testBit) if bit == *testBit => {
IsBit(test_bit) if bit == *test_bit => {
start.extend(end);
Some(Branch {
goal: branch.goal,
@ -388,7 +391,7 @@ fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) ->
start,
found_pattern: current.1,
end: {
copy.drain(0..index);
copy.drain(0..=index);
copy
},
};
@ -445,7 +448,7 @@ fn pick_path(branches: Vec<Branch>) -> Path {
if by_small_defaults.len() == 1 {
by_small_defaults.remove(0)
} else {
debug_assert!(by_small_defaults.len() > 0);
debug_assert!(!by_small_defaults.is_empty());
let mut result = bests_by_small_branching_factor(&branches, by_small_defaults.into_iter());
match result.pop() {
@ -455,7 +458,7 @@ fn pick_path(branches: Vec<Branch>) -> Path {
}
}
fn is_choice_path<'a>(path_and_pattern: (Path, Pattern<'a>)) -> Option<Path> {
fn is_choice_path(path_and_pattern: (Path, Pattern<'_>)) -> Option<Path> {
let (path, pattern) = path_and_pattern;
if needs_tests(&pattern) {
@ -477,12 +480,18 @@ where
for path in all_paths {
let weight = small_branching_factor(branches, &path);
if weight == min_weight {
min_paths.push(path.clone());
} else if weight < min_weight {
min_weight = weight;
min_paths.clear();
min_paths.push(path);
use std::cmp::Ordering;
match weight.cmp(&min_weight) {
Ordering::Equal => {
min_paths.push(path.clone());
}
Ordering::Less => {
min_weight = weight;
min_paths.clear();
min_paths.push(path);
}
Ordering::Greater => {}
}
}
@ -503,12 +512,18 @@ where
for path in all_paths {
let weight = small_defaults(branches, &path);
if weight == min_weight {
min_paths.push(path.clone());
} else if weight < min_weight {
min_weight = weight;
min_paths.clear();
min_paths.push(path);
use std::cmp::Ordering;
match weight.cmp(&min_weight) {
Ordering::Equal => {
min_paths.push(path.clone());
}
Ordering::Less => {
min_weight = weight;
min_paths.clear();
min_paths.push(path);
}
Ordering::Greater => {}
}
}
@ -555,7 +570,13 @@ enum Choice<'a> {
Jump(Label),
}
fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)>) -> Expr<'a> {
pub fn optimize_when<'a>(
env: &mut Env<'a, '_>,
cond_symbol: Symbol,
cond_layout: Layout<'a>,
ret_layout: Layout<'a>,
opt_branches: Vec<(Pattern<'a>, Expr<'a>)>,
) -> Expr<'a> {
let (patterns, _indexed_branches) = opt_branches
.into_iter()
.enumerate()
@ -564,7 +585,8 @@ fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)
let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches;
let decider = tree_to_decider(compile(patterns));
let decision_tree = compile(patterns);
let decider = tree_to_decider(decision_tree);
let target_counts = count_targets(&decider);
let mut choices = MutMap::default();
@ -582,7 +604,14 @@ fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)
let choice_decider = insert_choices(&choices, decider);
decide_to_branching(env, choice_decider, jumps)
decide_to_branching(
env,
cond_symbol,
cond_layout,
ret_layout,
choice_decider,
&jumps,
)
}
/*
*
@ -606,16 +635,26 @@ enum Choice<'a> {
}
*/
fn path_to_expr<'a>(_env: &mut Env<'a, '_>, symbol: Symbol, path: &Path) -> Expr<'a> {
match path {
Path::Empty => Expr::Load(symbol),
_ => todo!(),
}
}
fn decide_to_branching<'a>(
env: crate::expr::Env<'a, '_>,
env: &mut Env<'a, '_>,
cond_symbol: Symbol,
cond_layout: Layout<'a>,
ret_layout: Layout<'a>,
decider: Decider<Choice<'a>>,
jumps: Vec<(u64, Expr<'a>)>,
jumps: &Vec<(u64, Expr<'a>)>,
) -> Expr<'a> {
use Choice::*;
use Decider::*;
match decider {
Leaf(Jump(label)) => todo!(),
Leaf(Jump(_label)) => todo!(),
Leaf(Inline(expr)) => expr,
Chain {
test_chain,
@ -624,15 +663,148 @@ fn decide_to_branching<'a>(
} => {
// generate a switch based on the test chain
todo!()
let mut tests = Vec::with_capacity(test_chain.len());
for (path, test) in test_chain {
match test {
Test::IsInt(test_int) => {
let lhs = Expr::Int(test_int);
let rhs = path_to_expr(env, cond_symbol, &path);
let cond = env.arena.alloc(Expr::CallByName(
Symbol::INT_EQ_I64,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Int64)),
(rhs, Layout::Builtin(Builtin::Int64)),
]),
));
tests.push(cond);
}
Test::IsFloat(test_int) => {
// TODO maybe we can actually use i64 comparison here?
let test_float = f64::from_bits(test_int as u64);
let lhs = Expr::Float(test_float);
let rhs = path_to_expr(env, cond_symbol, &path);
let cond = env.arena.alloc(Expr::CallByName(
Symbol::FLOAT_EQ,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Float64)),
(rhs, Layout::Builtin(Builtin::Float64)),
]),
));
tests.push(cond);
}
Test::IsByte {
tag_id: test_byte,
// num_alts: _,
..
} => {
let lhs = Expr::Byte(test_byte);
let rhs = path_to_expr(env, cond_symbol, &path);
let fake = MutMap::default();
let cond = env.arena.alloc(Expr::CallByName(
Symbol::INT_EQ_I8,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Byte(fake.clone()))),
(rhs, Layout::Builtin(Builtin::Byte(fake))),
]),
));
tests.push(cond);
}
_ => todo!(),
}
}
let cond = tests.remove(0);
let pass = env.arena.alloc(decide_to_branching(
env,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*success,
jumps,
));
let fail = env.arena.alloc(decide_to_branching(
env,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*failure,
jumps,
));
let cond_layout = Layout::Builtin(Builtin::Bool(
TagName::Global("False".into()),
TagName::Global("True".into()),
));
Expr::Cond {
cond,
cond_layout,
pass,
fail,
ret_layout,
}
}
FanOut {
path,
tests,
fallback,
} => {
let cond = env.arena.alloc(path_to_expr(env, cond_symbol, &path));
let default_branch = env.arena.alloc(decide_to_branching(
env,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*fallback,
jumps,
));
let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena);
for (test, decider) in tests {
let branch = decide_to_branching(
env,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
decider,
jumps,
);
let tag = match test {
Test::IsInt(v) => v as u64,
Test::IsFloat(v) => v as u64,
Test::IsBit(v) => v as u64,
Test::IsByte { tag_id, .. } => tag_id as u64,
_ => todo!(),
};
branches.push((tag, branch));
}
// make a jump table based on the tests
todo!()
Expr::Switch {
cond,
cond_layout,
// branches: &'a [(u64, Expr<'a>)],
branches: branches.into_bump_slice(),
// default_branch: &'a Expr<'a>,
default_branch,
ret_layout,
}
}
}
}
@ -764,7 +936,9 @@ fn count_targets_help(decision_tree: &Decider<u64>, targets: &mut MutMap<u64, u6
None => {
targets.insert(*target, 1);
}
Some(current) => *current = *current + 1,
Some(current) => {
*current += 1;
}
},
Chain {
@ -791,10 +965,13 @@ fn create_choices<'a>(
target: u64,
branch: Expr<'a>,
) -> ((u64, Choice<'a>), Option<(u64, Expr<'a>)>) {
if target_counts[&target] == 1 {
((target, Choice::Inline(branch)), None)
} else {
((target, Choice::Jump(target)), Some((target, branch)))
match target_counts.get(&target) {
None => unreachable!(
"this should never happen: {:?} not in {:?}",
target, target_counts
),
Some(1) => ((target, Choice::Inline(branch)), None),
Some(_) => ((target, Choice::Jump(target)), Some((target, branch))),
}
}