prepare for Tag code gen

This commit is contained in:
Folkert 2020-03-16 20:13:46 +01:00
parent e742b77e0b
commit c9a90c32e3
4 changed files with 281 additions and 110 deletions

View file

@ -1113,6 +1113,24 @@ mod test_gen {
);
}
#[test]
fn when_on_result() {
assert_evals_to!(
indoc!(
r#"
x : Result Int Int
x = Ok 42
when x is
Ok _ -> 0
Err _ -> 4
"#
),
1,
i64
);
}
#[test]
fn basic_record() {
assert_evals_to!(

View file

@ -29,24 +29,32 @@ pub fn compile(raw_branches: Vec<(Pattern<'_>, u64)>) -> DecisionTree {
}
#[derive(Clone, Debug, PartialEq)]
pub enum DecisionTree {
pub enum DecisionTree<'a> {
Match(Label),
Decision {
path: Path,
edges: Vec<(Test, DecisionTree)>,
default: Option<Box<DecisionTree>>,
edges: Vec<(Test<'a>, DecisionTree<'a>)>,
default: Option<Box<DecisionTree<'a>>>,
},
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Test {
IsCtor { tag_name: TagName, num_alts: usize },
pub enum Test<'a> {
IsCtor {
tag_id: u8,
tag_name: TagName,
union: crate::pattern::Union,
arguments: Vec<Pattern<'a>>,
},
IsInt(i64),
// float patterns are stored as i64 so they are comparable/hashable
IsFloat(i64),
// float patterns are stored as u64 so they are comparable/hashable
IsFloat(u64),
IsStr(Box<str>),
IsBit(bool),
IsByte { tag_id: u8, num_alts: usize },
IsByte {
tag_id: u8,
num_alts: usize,
},
}
#[derive(Clone, Debug, PartialEq)]
@ -104,13 +112,16 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
fn is_complete(tests: &[Test]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
match tests[length - 1] {
Test::IsCtor { num_alts, .. } => length == num_alts,
Test::IsByte { num_alts, .. } => length == num_alts,
match tests.get(length - 1) {
None => unreachable!("should never happen"),
Some(v) => match v {
Test::IsCtor { union, .. } => length == union.alternatives.len(),
Test::IsByte { num_alts, .. } => length == *num_alts,
Test::IsBit(_) => length == 2,
Test::IsInt(_) => false,
Test::IsFloat(_) => false,
Test::IsStr(_) => false,
},
}
}
@ -166,7 +177,6 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
match branches.get(0) {
Some(Branch { goal, patterns }) if patterns.iter().all(|(_, p)| !needs_tests(p)) => {
println!("the expected case {:?} {:?}", goal, patterns);
Some(*goal)
}
_ => None,
@ -178,7 +188,7 @@ fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
fn gather_edges<'a>(
branches: Vec<Branch<'a>>,
path: &Path,
) -> (Vec<(Test, Vec<Branch<'a>>)>, Vec<Branch<'a>>) {
) -> (Vec<(Test<'a>, Vec<Branch<'a>>)>, Vec<Branch<'a>>) {
// TODO remove clone
let relevant_tests = tests_at_path(path, branches.clone());
@ -204,7 +214,7 @@ fn gather_edges<'a>(
/// FIND RELEVANT TESTS
fn tests_at_path(selected_path: &Path, branches: Vec<Branch>) -> Vec<Test> {
fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Test<'a>> {
// NOTE the ordering of the result is important!
let mut visited = MutSet::default();
@ -224,7 +234,7 @@ fn tests_at_path(selected_path: &Path, branches: Vec<Branch>) -> Vec<Test> {
unique
}
fn test_at_path(selected_path: &Path, branch: Branch) -> Option<Test> {
fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>) -> Option<Test<'a>> {
use Pattern::*;
use Test::*;
@ -241,30 +251,37 @@ fn test_at_path(selected_path: &Path, branch: Branch) -> Option<Test> {
| Shadowed(_, _)
| UnsupportedPattern(_) => None,
AppliedTag { .. } => todo!(),
AppliedTag {
tag_name,
tag_id,
arguments,
union,
..
} => Some(IsCtor {
tag_id: *tag_id,
tag_name: tag_name.clone(),
union: union.clone(),
arguments: arguments.clone().into_iter().collect(),
}),
BitLiteral(v) => Some(IsBit(*v)),
EnumLiteral { tag_id, enum_size } => Some(IsByte {
tag_id: *tag_id,
num_alts: *enum_size as usize,
}),
IntLiteral(v) => Some(IsInt(*v)),
FloatLiteral(v) => Some(IsFloat(float_to_i64(*v))),
FloatLiteral(v) => Some(IsFloat(*v)),
StrLiteral(v) => Some(IsStr(v.clone())),
},
}
}
fn float_to_i64(float: f64) -> i64 {
// To make a float hashable (for storate in a MutMap), transmute to i64
// We assume that `v` is normal (not Nan, Infinity, -Infinity)
// those values cannot occur in patterns in Roc, so this code is safe
debug_assert!(float.is_normal());
float.to_bits() as i64
}
/// BUILD EDGES
fn edges_for<'a>(path: &Path, branches: Vec<Branch<'a>>, test: Test) -> (Test, Vec<Branch<'a>>) {
fn edges_for<'a>(
path: &Path,
branches: Vec<Branch<'a>>,
test: Test<'a>,
) -> (Test<'a>, Vec<Branch<'a>>) {
let new_branches = branches
.into_iter()
.filter_map(|b| to_relevant_branch(&test, path, b))
@ -273,7 +290,7 @@ fn edges_for<'a>(path: &Path, branches: Vec<Branch<'a>>, test: Test) -> (Test, V
(test, new_branches)
}
fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Option<Branch<'a>> {
fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> Option<Branch<'a>> {
use Pattern::*;
use Test::*;
@ -290,8 +307,49 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
| Underscore
| Shadowed(_, _)
| UnsupportedPattern(_) => Some(branch),
AppliedTag { .. } => {
AppliedTag {
union,
tag_name,
mut arguments,
..
} => {
match test {
IsCtor {
tag_name: test_name,
..
} if &tag_name == test_name => {
// TODO can't we unbox whenever there is just one alternative, even if
// there are multiple arguments?
if arguments.len() == 1 && union.alternatives.len() == 1 {
let arg = arguments.remove(0);
{
start.push((Path::Unbox(Box::new(path.clone())), arg));
start.extend(end);
}
} else {
let sub_positions =
arguments.into_iter().enumerate().map(|(index, pattern)| {
(
Path::Index {
index: index as u64,
path: Box::new(path.clone()),
},
pattern,
)
});
start.extend(sub_positions);
start.extend(end);
}
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
}
/*
*
Can.PCtor _ _ (Can.Union _ _ numAlts _) name _ ctorArgs ->
case test of
IsCtor _ testName _ _ _ | name == testName ->
@ -306,7 +364,6 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
_ ->
Nothing
*/
todo!()
}
StrLiteral(string) => match test {
IsStr(test_str) if string == *test_str => {
@ -331,7 +388,7 @@ fn to_relevant_branch<'a>(test: &Test, path: &Path, branch: Branch<'a>) -> Optio
},
FloatLiteral(float) => match test {
IsFloat(test_float) if float_to_i64(float) == *test_float => {
IsFloat(test_float) if float == *test_float => {
start.extend(end);
Some(Branch {
goal: branch.goal,
@ -550,17 +607,17 @@ fn small_branching_factor(branches: &Vec<Branch>, path: &Path) -> usize {
}
#[derive(Clone, Debug, PartialEq)]
enum Decider<T> {
enum Decider<'a, T> {
Leaf(T),
Chain {
test_chain: Vec<(Path, Test)>,
success: Box<Decider<T>>,
failure: Box<Decider<T>>,
test_chain: Vec<(Path, Test<'a>)>,
success: Box<Decider<'a, T>>,
failure: Box<Decider<'a, T>>,
},
FanOut {
path: Path,
tests: Vec<(Test, Decider<T>)>,
fallback: Box<Decider<T>>,
tests: Vec<(Test<'a>, Decider<'a, T>)>,
fallback: Box<Decider<'a, T>>,
},
}
@ -589,6 +646,8 @@ pub fn optimize_when<'a>(
let decider = tree_to_decider(decision_tree);
let target_counts = count_targets(&decider);
dbg!(&target_counts);
let mut choices = MutMap::default();
let mut jumps = Vec::new();
@ -604,40 +663,31 @@ pub fn optimize_when<'a>(
let choice_decider = insert_choices(&choices, decider);
decide_to_branching(
let result = decide_to_branching(
env,
cond_symbol,
cond_layout,
ret_layout,
choice_decider,
&jumps,
)
}
/*
*
Leaf(T),
Chain {
test_chain: Vec<(Path, Test)>,
success: Box<Decider<T>>,
failure: Box<Decider<T>>,
},
FanOut {
path: Path,
tests: Vec<(Test, Decider<T>)>,
fallback: Box<Decider<T>>,
},
}
);
#[derive(Clone, Debug, PartialEq)]
enum Choice<'a> {
Inline(Expr<'a>),
Jump(Label),
// increase the jump counter by the number of jumps in this branching structure
*env.jump_counter += jumps.len() as u64;
result
}
*/
fn path_to_expr<'a>(_env: &mut Env<'a, '_>, symbol: Symbol, path: &Path) -> Expr<'a> {
match path {
match dbg!(path) {
Path::Empty => Expr::Load(symbol),
Path::Index {
index,
path: nested,
} => {
//
Expr::Load(symbol)
}
_ => todo!(),
}
}
@ -653,8 +703,10 @@ fn decide_to_branching<'a>(
use Choice::*;
use Decider::*;
let jump_count = *env.jump_counter;
match decider {
Leaf(Jump(_label)) => todo!(),
Leaf(Jump(label)) => Expr::Jump(label + jump_count),
Leaf(Inline(expr)) => expr,
Chain {
test_chain,
@ -667,6 +719,22 @@ fn decide_to_branching<'a>(
for (path, test) in test_chain {
match test {
Test::IsCtor { tag_id, .. } => {
let lhs = Expr::Byte(tag_id);
let rhs = path_to_expr(env, cond_symbol, &path);
let fake = MutMap::default();
let cond = env.arena.alloc(Expr::CallByName(
Symbol::INT_EQ_I8,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Byte(fake.clone()))),
(rhs, Layout::Builtin(Builtin::Byte(fake))),
]),
));
tests.push(cond);
}
Test::IsInt(test_int) => {
let lhs = Expr::Int(test_int);
let rhs = path_to_expr(env, cond_symbol, &path);
@ -884,12 +952,12 @@ fn tree_to_decider(tree: DecisionTree) -> Decider<u64> {
}
}
fn to_chain(
fn to_chain<'a>(
path: Path,
test: Test,
success_tree: DecisionTree,
failure_tree: DecisionTree,
) -> Decider<u64> {
test: Test<'a>,
success_tree: DecisionTree<'a>,
failure_tree: DecisionTree<'a>,
) -> Decider<'a, u64> {
use Decider::*;
let failure = tree_to_decider(failure_tree);
@ -977,8 +1045,8 @@ fn create_choices<'a>(
fn insert_choices<'a>(
choice_dict: &MutMap<u64, Choice<'a>>,
decider: Decider<u64>,
) -> Decider<Choice<'a>> {
decider: Decider<'a, u64>,
) -> Decider<'a, Choice<'a>> {
use Decider::*;
match decider {
Leaf(target) => {

View file

@ -7,7 +7,7 @@ use roc_module::ident::{Ident, Lowercase, TagName};
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
use roc_region::all::{Located, Region};
use roc_types::subs::{Content, ContentHash, FlatType, Subs, Variable};
use std::hash::{Hash, Hasher};
#[derive(Clone, Debug, PartialEq, Default)]
pub struct Procs<'a> {
user_defined: MutMap<Symbol, PartialProc<'a>>,
@ -198,6 +198,9 @@ pub enum Expr<'a> {
elems: &'a [Expr<'a>],
},
Label(u64, &'a Expr<'a>),
Jump(u64),
RuntimeError(&'a str),
}
@ -822,28 +825,27 @@ fn from_can_when<'a>(
// We can't know what to return!
panic!("TODO compile a 0-branch when-expression to a RuntimeError");
}
// don't do this for now, see if the decision_tree method can do it
// 1 => {
// // A when-expression with exactly 1 branch is essentially a LetNonRec.
// // As such, we can compile it direcly to a Store.
// let arena = env.arena;
// let mut stored = Vec::with_capacity_in(1, arena);
// let (loc_when_pattern, loc_branch) = branches.into_iter().next().unwrap();
//
// let mono_pattern = from_can_pattern(env, &loc_when_pattern.value);
// store_pattern(
// env,
// mono_pattern,
// loc_cond.value,
// cond_var,
// procs,
// &mut stored,
// );
//
// let ret = from_can(env, loc_branch.value, procs, None);
//
// Expr::Store(stored.into_bump_slice(), arena.alloc(ret))
// }
1 => {
// A when-expression with exactly 1 branch is essentially a LetNonRec.
// As such, we can compile it direcly to a Store.
let arena = env.arena;
let mut stored = Vec::with_capacity_in(1, arena);
let (loc_when_pattern, loc_branch) = branches.into_iter().next().unwrap();
let mono_pattern = from_can_pattern(env, &loc_when_pattern.value);
store_pattern(
env,
mono_pattern,
loc_cond.value,
cond_var,
procs,
&mut stored,
);
let ret = from_can(env, loc_branch.value, procs, None);
Expr::Store(stored.into_bump_slice(), arena.alloc(ret))
}
_ => {
let mut loc_branches = std::vec::Vec::new();
let mut opt_branches = std::vec::Vec::new();
@ -1022,25 +1024,28 @@ fn specialize_proc_body<'a>(
/// A pattern, including possible problems (e.g. shadowing) so that
/// codegen can generate a runtime error if this pattern is reached.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Pattern<'a> {
Identifier(Symbol),
AppliedTag {
tag_name: TagName,
arguments: Vec<'a, Pattern<'a>>,
layout: Layout<'a>,
union: crate::pattern::Union,
},
Underscore,
IntLiteral(i64),
FloatLiteral(u64),
BitLiteral(bool),
EnumLiteral {
tag_id: u8,
enum_size: u8,
},
IntLiteral(i64),
FloatLiteral(f64),
StrLiteral(Box<str>),
RecordDestructure(Vec<'a, RecordDestruct<'a>>, Layout<'a>),
Underscore,
AppliedTag {
tag_name: TagName,
tag_id: u8,
arguments: Vec<'a, Pattern<'a>>,
layout: Layout<'a>,
union: crate::pattern::Union,
},
// Runtime Exceptions
Shadowed(Region, Located<Ident>),
@ -1048,13 +1053,81 @@ pub enum Pattern<'a> {
UnsupportedPattern(Region),
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RecordDestruct<'a> {
pub label: Lowercase,
pub symbol: Symbol,
pub guard: Option<Pattern<'a>>,
}
impl<'a> Hash for Pattern<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
use Pattern::*;
match self {
Identifier(symbol) => {
state.write_u8(0);
symbol.hash(state);
}
Underscore => {
state.write_u8(1);
}
IntLiteral(v) => {
state.write_u8(2);
v.hash(state);
}
FloatLiteral(v) => {
state.write_u8(3);
v.hash(state);
}
BitLiteral(v) => {
state.write_u8(4);
v.hash(state);
}
EnumLiteral { tag_id, enum_size } => {
state.write_u8(5);
tag_id.hash(state);
enum_size.hash(state);
}
StrLiteral(v) => {
state.write_u8(6);
v.hash(state);
}
RecordDestructure(fields, _layout) => {
state.write_u8(7);
fields.hash(state);
// layout is ignored!
}
AppliedTag {
tag_name,
arguments,
union,
..
} => {
state.write_u8(8);
tag_name.hash(state);
arguments.hash(state);
union.hash(state);
// layout is ignored!
}
Shadowed(region, ident) => {
state.write_u8(9);
region.hash(state);
ident.hash(state);
}
UnsupportedPattern(region) => {
state.write_u8(10);
region.hash(state);
}
}
}
}
fn from_can_pattern<'a>(
env: &mut Env<'a, '_>,
can_pattern: &roc_can::pattern::Pattern,
@ -1064,14 +1137,14 @@ fn from_can_pattern<'a>(
Underscore => Pattern::Underscore,
Identifier(symbol) => Pattern::Identifier(*symbol),
IntLiteral(v) => Pattern::IntLiteral(*v),
FloatLiteral(v) => Pattern::FloatLiteral(*v),
FloatLiteral(v) => Pattern::FloatLiteral(f64::to_bits(*v)),
StrLiteral(v) => Pattern::StrLiteral(v.clone()),
Shadowed(region, ident) => Pattern::Shadowed(*region, ident.clone()),
UnsupportedPattern(region) => Pattern::UnsupportedPattern(*region),
NumLiteral(var, num) => match to_int_or_float(env.subs, *var) {
IntOrFloat::IntType => Pattern::IntLiteral(*num),
IntOrFloat::FloatType => Pattern::FloatLiteral(*num as f64),
IntOrFloat::FloatType => Pattern::FloatLiteral(*num as u64),
},
AppliedTag {
@ -1118,8 +1191,20 @@ fn from_can_pattern<'a>(
Err(content) => panic!("invalid content in ext_var: {:?}", content),
};
use crate::pattern::Ctor;
let mut opt_tag_id = None;
for (index, Ctor { name, .. }) in union.alternatives.iter().enumerate() {
if name == tag_name {
opt_tag_id = Some(index as u8);
break;
}
}
let tag_id = opt_tag_id.expect("Tag must be in its own type");
Pattern::AppliedTag {
tag_name: tag_name.clone(),
tag_id,
arguments: mono_args,
union,
layout,

View file

@ -4,12 +4,12 @@ use roc_region::all::{Located, Region};
use self::Pattern::*;
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Union {
pub alternatives: Vec<Ctor>,
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Ctor {
pub name: TagName,
pub arity: usize,
@ -27,7 +27,7 @@ pub enum Literal {
Int(i64),
Bit(bool),
Byte(u8),
Float(f64),
Float(u64),
Str(Box<str>),
}