mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-30 07:14:46 +00:00
parity with old implementation
This commit is contained in:
parent
d875f8bfce
commit
e742b77e0b
6 changed files with 345 additions and 303 deletions
|
@ -3,6 +3,10 @@ use crate::expr::Expr;
|
|||
use crate::expr::Pattern;
|
||||
use roc_collections::all::{MutMap, MutSet};
|
||||
use roc_module::ident::TagName;
|
||||
use roc_module::symbol::Symbol;
|
||||
|
||||
use crate::layout::Builtin;
|
||||
use crate::layout::Layout;
|
||||
|
||||
/// COMPILE CASES
|
||||
|
||||
|
@ -12,7 +16,7 @@ type Label = u64;
|
|||
/// some normal branches and gives out a decision tree that has "labels" at all
|
||||
/// the leafs and a dictionary that maps these "labels" to the code that should
|
||||
/// run.
|
||||
pub fn compile<'a>(raw_branches: Vec<(Pattern<'a>, u64)>) -> DecisionTree {
|
||||
pub fn compile(raw_branches: Vec<(Pattern<'_>, u64)>) -> DecisionTree {
|
||||
let formatted = raw_branches
|
||||
.into_iter()
|
||||
.map(|(pattern, index)| Branch {
|
||||
|
@ -110,7 +114,7 @@ fn is_complete(tests: &[Test]) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn flatten_patterns<'a>(branch: Branch<'a>) -> Branch<'a> {
|
||||
fn flatten_patterns(branch: Branch) -> Branch {
|
||||
let mut result = Vec::with_capacity(branch.patterns.len());
|
||||
|
||||
for path_pattern in branch.patterns {
|
||||
|
@ -124,9 +128,7 @@ fn flatten_patterns<'a>(branch: Branch<'a>) -> Branch<'a> {
|
|||
|
||||
fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) {
|
||||
match &path_pattern.1 {
|
||||
Pattern::AppliedTag {
|
||||
union, arguments, ..
|
||||
} => {
|
||||
Pattern::AppliedTag { union, .. } => {
|
||||
if union.alternatives.len() == 1 {
|
||||
// case map dearg ctorArgs of
|
||||
// [arg] ->
|
||||
|
@ -162,8 +164,9 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
|
|||
/// variables to "how to get their value". So a pattern like (Just (x,_)) will give
|
||||
/// us something like ("x" => value.0.0)
|
||||
fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
|
||||
match branches.get(branches.len() - 1) {
|
||||
match branches.get(0) {
|
||||
Some(Branch { goal, patterns }) if patterns.iter().all(|(_, p)| !needs_tests(p)) => {
|
||||
println!("the expected case {:?} {:?}", goal, patterns);
|
||||
Some(*goal)
|
||||
}
|
||||
_ => None,
|
||||
|
@ -256,7 +259,7 @@ fn float_to_i64(float: f64) -> i64 {
|
|||
// We assume that `v` is normal (not Nan, Infinity, -Infinity)
|
||||
// those values cannot occur in patterns in Roc, so this code is safe
|
||||
debug_assert!(float.is_normal());
|
||||
unsafe { std::mem::transmute::<f64, i64>(float) }
|
||||
float.to_bits() as i64
|
||||
}
|
||||
|
||||
/// BUILD EDGES
|
||||
|
@ -278,7 +281,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
|
|||
match extract(path, branch.patterns.clone()) {
|
||||
Extract::NotFound => Some(branch),
|
||||
Extract::Found {
|
||||
start: mut start,
|
||||
mut start,
|
||||
found_pattern: pattern,
|
||||
end,
|
||||
} => match pattern {
|
||||
|
@ -306,7 +309,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
|
|||
todo!()
|
||||
}
|
||||
StrLiteral(string) => match test {
|
||||
IsStr(testStr) if string == *testStr => {
|
||||
IsStr(test_str) if string == *test_str => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
|
@ -317,7 +320,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
|
|||
},
|
||||
|
||||
IntLiteral(int) => match test {
|
||||
IsInt(testInt) if int == *testInt => {
|
||||
IsInt(is_int) if int == *is_int => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
|
@ -328,7 +331,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
|
|||
},
|
||||
|
||||
FloatLiteral(float) => match test {
|
||||
IsFloat(testFloat) if float_to_i64(float) == *testFloat => {
|
||||
IsFloat(test_float) if float_to_i64(float) == *test_float => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
|
@ -339,7 +342,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
|
|||
},
|
||||
|
||||
BitLiteral(bit) => match test {
|
||||
IsBit(testBit) if bit == *testBit => {
|
||||
IsBit(test_bit) if bit == *test_bit => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
|
@ -388,7 +391,7 @@ fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) ->
|
|||
start,
|
||||
found_pattern: current.1,
|
||||
end: {
|
||||
copy.drain(0..index);
|
||||
copy.drain(0..=index);
|
||||
copy
|
||||
},
|
||||
};
|
||||
|
@ -445,7 +448,7 @@ fn pick_path(branches: Vec<Branch>) -> Path {
|
|||
if by_small_defaults.len() == 1 {
|
||||
by_small_defaults.remove(0)
|
||||
} else {
|
||||
debug_assert!(by_small_defaults.len() > 0);
|
||||
debug_assert!(!by_small_defaults.is_empty());
|
||||
let mut result = bests_by_small_branching_factor(&branches, by_small_defaults.into_iter());
|
||||
|
||||
match result.pop() {
|
||||
|
@ -455,7 +458,7 @@ fn pick_path(branches: Vec<Branch>) -> Path {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_choice_path<'a>(path_and_pattern: (Path, Pattern<'a>)) -> Option<Path> {
|
||||
fn is_choice_path(path_and_pattern: (Path, Pattern<'_>)) -> Option<Path> {
|
||||
let (path, pattern) = path_and_pattern;
|
||||
|
||||
if needs_tests(&pattern) {
|
||||
|
@ -477,12 +480,18 @@ where
|
|||
|
||||
for path in all_paths {
|
||||
let weight = small_branching_factor(branches, &path);
|
||||
if weight == min_weight {
|
||||
min_paths.push(path.clone());
|
||||
} else if weight < min_weight {
|
||||
min_weight = weight;
|
||||
min_paths.clear();
|
||||
min_paths.push(path);
|
||||
|
||||
use std::cmp::Ordering;
|
||||
match weight.cmp(&min_weight) {
|
||||
Ordering::Equal => {
|
||||
min_paths.push(path.clone());
|
||||
}
|
||||
Ordering::Less => {
|
||||
min_weight = weight;
|
||||
min_paths.clear();
|
||||
min_paths.push(path);
|
||||
}
|
||||
Ordering::Greater => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,12 +512,18 @@ where
|
|||
|
||||
for path in all_paths {
|
||||
let weight = small_defaults(branches, &path);
|
||||
if weight == min_weight {
|
||||
min_paths.push(path.clone());
|
||||
} else if weight < min_weight {
|
||||
min_weight = weight;
|
||||
min_paths.clear();
|
||||
min_paths.push(path);
|
||||
|
||||
use std::cmp::Ordering;
|
||||
match weight.cmp(&min_weight) {
|
||||
Ordering::Equal => {
|
||||
min_paths.push(path.clone());
|
||||
}
|
||||
Ordering::Less => {
|
||||
min_weight = weight;
|
||||
min_paths.clear();
|
||||
min_paths.push(path);
|
||||
}
|
||||
Ordering::Greater => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -555,7 +570,13 @@ enum Choice<'a> {
|
|||
Jump(Label),
|
||||
}
|
||||
|
||||
fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)>) -> Expr<'a> {
|
||||
pub fn optimize_when<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
cond_symbol: Symbol,
|
||||
cond_layout: Layout<'a>,
|
||||
ret_layout: Layout<'a>,
|
||||
opt_branches: Vec<(Pattern<'a>, Expr<'a>)>,
|
||||
) -> Expr<'a> {
|
||||
let (patterns, _indexed_branches) = opt_branches
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
|
@ -564,7 +585,8 @@ fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)
|
|||
|
||||
let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches;
|
||||
|
||||
let decider = tree_to_decider(compile(patterns));
|
||||
let decision_tree = compile(patterns);
|
||||
let decider = tree_to_decider(decision_tree);
|
||||
let target_counts = count_targets(&decider);
|
||||
|
||||
let mut choices = MutMap::default();
|
||||
|
@ -582,7 +604,14 @@ fn optimize_when<'a>(env: Env<'a, '_>, opt_branches: Vec<(Pattern<'a>, Expr<'a>)
|
|||
|
||||
let choice_decider = insert_choices(&choices, decider);
|
||||
|
||||
decide_to_branching(env, choice_decider, jumps)
|
||||
decide_to_branching(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout,
|
||||
ret_layout,
|
||||
choice_decider,
|
||||
&jumps,
|
||||
)
|
||||
}
|
||||
/*
|
||||
*
|
||||
|
@ -606,16 +635,26 @@ enum Choice<'a> {
|
|||
}
|
||||
*/
|
||||
|
||||
fn path_to_expr<'a>(_env: &mut Env<'a, '_>, symbol: Symbol, path: &Path) -> Expr<'a> {
|
||||
match path {
|
||||
Path::Empty => Expr::Load(symbol),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn decide_to_branching<'a>(
|
||||
env: crate::expr::Env<'a, '_>,
|
||||
env: &mut Env<'a, '_>,
|
||||
cond_symbol: Symbol,
|
||||
cond_layout: Layout<'a>,
|
||||
ret_layout: Layout<'a>,
|
||||
decider: Decider<Choice<'a>>,
|
||||
jumps: Vec<(u64, Expr<'a>)>,
|
||||
jumps: &Vec<(u64, Expr<'a>)>,
|
||||
) -> Expr<'a> {
|
||||
use Choice::*;
|
||||
use Decider::*;
|
||||
|
||||
match decider {
|
||||
Leaf(Jump(label)) => todo!(),
|
||||
Leaf(Jump(_label)) => todo!(),
|
||||
Leaf(Inline(expr)) => expr,
|
||||
Chain {
|
||||
test_chain,
|
||||
|
@ -624,15 +663,148 @@ fn decide_to_branching<'a>(
|
|||
} => {
|
||||
// generate a switch based on the test chain
|
||||
|
||||
todo!()
|
||||
let mut tests = Vec::with_capacity(test_chain.len());
|
||||
|
||||
for (path, test) in test_chain {
|
||||
match test {
|
||||
Test::IsInt(test_int) => {
|
||||
let lhs = Expr::Int(test_int);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path);
|
||||
|
||||
let cond = env.arena.alloc(Expr::CallByName(
|
||||
Symbol::INT_EQ_I64,
|
||||
env.arena.alloc([
|
||||
(lhs, Layout::Builtin(Builtin::Int64)),
|
||||
(rhs, Layout::Builtin(Builtin::Int64)),
|
||||
]),
|
||||
));
|
||||
|
||||
tests.push(cond);
|
||||
}
|
||||
|
||||
Test::IsFloat(test_int) => {
|
||||
// TODO maybe we can actually use i64 comparison here?
|
||||
let test_float = f64::from_bits(test_int as u64);
|
||||
let lhs = Expr::Float(test_float);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path);
|
||||
|
||||
let cond = env.arena.alloc(Expr::CallByName(
|
||||
Symbol::FLOAT_EQ,
|
||||
env.arena.alloc([
|
||||
(lhs, Layout::Builtin(Builtin::Float64)),
|
||||
(rhs, Layout::Builtin(Builtin::Float64)),
|
||||
]),
|
||||
));
|
||||
|
||||
tests.push(cond);
|
||||
}
|
||||
|
||||
Test::IsByte {
|
||||
tag_id: test_byte,
|
||||
// num_alts: _,
|
||||
..
|
||||
} => {
|
||||
let lhs = Expr::Byte(test_byte);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path);
|
||||
|
||||
let fake = MutMap::default();
|
||||
|
||||
let cond = env.arena.alloc(Expr::CallByName(
|
||||
Symbol::INT_EQ_I8,
|
||||
env.arena.alloc([
|
||||
(lhs, Layout::Builtin(Builtin::Byte(fake.clone()))),
|
||||
(rhs, Layout::Builtin(Builtin::Byte(fake))),
|
||||
]),
|
||||
));
|
||||
|
||||
tests.push(cond);
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
let cond = tests.remove(0);
|
||||
|
||||
let pass = env.arena.alloc(decide_to_branching(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout.clone(),
|
||||
ret_layout.clone(),
|
||||
*success,
|
||||
jumps,
|
||||
));
|
||||
|
||||
let fail = env.arena.alloc(decide_to_branching(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout.clone(),
|
||||
ret_layout.clone(),
|
||||
*failure,
|
||||
jumps,
|
||||
));
|
||||
|
||||
let cond_layout = Layout::Builtin(Builtin::Bool(
|
||||
TagName::Global("False".into()),
|
||||
TagName::Global("True".into()),
|
||||
));
|
||||
|
||||
Expr::Cond {
|
||||
cond,
|
||||
cond_layout,
|
||||
pass,
|
||||
fail,
|
||||
ret_layout,
|
||||
}
|
||||
}
|
||||
FanOut {
|
||||
path,
|
||||
tests,
|
||||
fallback,
|
||||
} => {
|
||||
let cond = env.arena.alloc(path_to_expr(env, cond_symbol, &path));
|
||||
|
||||
let default_branch = env.arena.alloc(decide_to_branching(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout.clone(),
|
||||
ret_layout.clone(),
|
||||
*fallback,
|
||||
jumps,
|
||||
));
|
||||
|
||||
let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena);
|
||||
|
||||
for (test, decider) in tests {
|
||||
let branch = decide_to_branching(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout.clone(),
|
||||
ret_layout.clone(),
|
||||
decider,
|
||||
jumps,
|
||||
);
|
||||
|
||||
let tag = match test {
|
||||
Test::IsInt(v) => v as u64,
|
||||
Test::IsFloat(v) => v as u64,
|
||||
Test::IsBit(v) => v as u64,
|
||||
Test::IsByte { tag_id, .. } => tag_id as u64,
|
||||
_ => todo!(),
|
||||
};
|
||||
|
||||
branches.push((tag, branch));
|
||||
}
|
||||
|
||||
// make a jump table based on the tests
|
||||
todo!()
|
||||
Expr::Switch {
|
||||
cond,
|
||||
cond_layout,
|
||||
// branches: &'a [(u64, Expr<'a>)],
|
||||
branches: branches.into_bump_slice(),
|
||||
// default_branch: &'a Expr<'a>,
|
||||
default_branch,
|
||||
ret_layout,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -764,7 +936,9 @@ fn count_targets_help(decision_tree: &Decider<u64>, targets: &mut MutMap<u64, u6
|
|||
None => {
|
||||
targets.insert(*target, 1);
|
||||
}
|
||||
Some(current) => *current = *current + 1,
|
||||
Some(current) => {
|
||||
*current += 1;
|
||||
}
|
||||
},
|
||||
|
||||
Chain {
|
||||
|
@ -791,10 +965,13 @@ fn create_choices<'a>(
|
|||
target: u64,
|
||||
branch: Expr<'a>,
|
||||
) -> ((u64, Choice<'a>), Option<(u64, Expr<'a>)>) {
|
||||
if target_counts[&target] == 1 {
|
||||
((target, Choice::Inline(branch)), None)
|
||||
} else {
|
||||
((target, Choice::Jump(target)), Some((target, branch)))
|
||||
match target_counts.get(&target) {
|
||||
None => unreachable!(
|
||||
"this should never happen: {:?} not in {:?}",
|
||||
target, target_counts
|
||||
),
|
||||
Some(1) => ((target, Choice::Inline(branch)), None),
|
||||
Some(_) => ((target, Choice::Jump(target)), Some((target, branch))),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue