code gen for simple guards

Guards cannot use variables bound in the pattern yet
This commit is contained in:
Folkert 2020-03-21 22:27:36 +01:00
parent bd7ad318cf
commit a16d48a6a9
4 changed files with 490 additions and 293 deletions

View file

@ -771,13 +771,24 @@ fn call_by_name<'a, B: Backend>(
.ins() .ins()
.load(env.ptr_sized_int(), MemFlags::new(), list_ptr, offset) .load(env.ptr_sized_int(), MemFlags::new(), list_ptr, offset)
} }
Symbol::INT_EQ_I64 | Symbol::INT_EQ_I8 | Symbol::INT_EQ_I1 => { Symbol::INT_EQ_I64 | Symbol::INT_EQ_I8 => {
debug_assert!(args.len() == 2); debug_assert!(args.len() == 2);
let a = build_arg(&args[0], env, scope, module, builder, procs); let a = build_arg(&args[0], env, scope, module, builder, procs);
let b = build_arg(&args[1], env, scope, module, builder, procs); let b = build_arg(&args[1], env, scope, module, builder, procs);
builder.ins().icmp(IntCC::Equal, a, b) builder.ins().icmp(IntCC::Equal, a, b)
} }
Symbol::INT_EQ_I1 => {
debug_assert!(args.len() == 2);
let a = build_arg(&args[0], env, scope, module, builder, procs);
let b = build_arg(&args[1], env, scope, module, builder, procs);
// integer comparisons don't work for booleans, and a custom xand gives errors.
let p = builder.ins().bint(types::I8, a);
let q = builder.ins().bint(types::I8, b);
builder.ins().icmp(IntCC::Equal, p, q)
}
Symbol::FLOAT_EQ => { Symbol::FLOAT_EQ => {
debug_assert!(args.len() == 2); debug_assert!(args.len() == 2);
let a = build_arg(&args[0], env, scope, module, builder, procs); let a = build_arg(&args[0], env, scope, module, builder, procs);

View file

@ -1495,6 +1495,66 @@ mod test_gen {
); );
} }
#[test]
fn or_pattern() {
assert_evals_to!(
indoc!(
r#"
when 2 is
1 | 2 -> 42
_ -> 1
"#
),
42,
i64
);
}
#[test]
fn if_guard_pattern_false() {
assert_evals_to!(
indoc!(
r#"
when 2 is
2 if False -> 0
_ -> 42
"#
),
42,
i64
);
}
#[test]
fn if_guard_pattern_true() {
assert_evals_to!(
indoc!(
r#"
when 2 is
2 if True -> 42
_ -> 0
"#
),
42,
i64
);
}
// #[test]
// fn if_guard_exhaustiveness() {
// assert_evals_to!(
// indoc!(
// r#"
// when 2 is
// _ if False -> 0
// _ -> 42
// "#
// ),
// 42,
// i64
// );
// }
// #[test] // #[test]
// fn linked_list_empty() { // fn linked_list_empty() {
// assert_evals_to!( // assert_evals_to!(

View file

@ -19,12 +19,12 @@ type Label = u64;
/// some normal branches and gives out a decision tree that has "labels" at all /// some normal branches and gives out a decision tree that has "labels" at all
/// the leafs and a dictionary that maps these "labels" to the code that should /// the leafs and a dictionary that maps these "labels" to the code that should
/// run. /// run.
pub fn compile(raw_branches: Vec<(Pattern<'_>, u64)>) -> DecisionTree { pub fn compile<'a>(raw_branches: Vec<(Option<Expr<'a>>, Pattern<'a>, u64)>) -> DecisionTree<'a> {
let formatted = raw_branches let formatted = raw_branches
.into_iter() .into_iter()
.map(|(pattern, index)| Branch { .map(|(guard, pattern, index)| Branch {
goal: index, goal: index,
patterns: vec![(Path::Empty, pattern)], patterns: vec![(Path::Empty, guard, pattern)],
}) })
.collect(); .collect();
@ -41,7 +41,7 @@ pub enum DecisionTree<'a> {
}, },
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq)]
pub enum Test<'a> { pub enum Test<'a> {
IsCtor { IsCtor {
tag_id: u8, tag_id: u8,
@ -58,6 +58,8 @@ pub enum Test<'a> {
tag_id: u8, tag_id: u8,
num_alts: usize, num_alts: usize,
}, },
// A pattern that always succeeds (like `_`) can still have a guard
Guarded(Option<Box<Test<'a>>>, Expr<'a>),
} }
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
impl<'a> Hash for Test<'a> { impl<'a> Hash for Test<'a> {
@ -89,7 +91,14 @@ impl<'a> Hash for Test<'a> {
IsByte { tag_id, num_alts } => { IsByte { tag_id, num_alts } => {
state.write_u8(5); state.write_u8(5);
tag_id.hash(state); tag_id.hash(state);
num_alts.hash(state) num_alts.hash(state);
}
Guarded(None, _) => {
state.write_u8(6);
}
Guarded(Some(nested), _) => {
state.write_u8(7);
nested.hash(state);
} }
} }
} }
@ -111,7 +120,7 @@ pub enum Path {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
struct Branch<'a> { struct Branch<'a> {
goal: Label, goal: Label,
patterns: Vec<(Path, Pattern<'a>)>, patterns: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
} }
fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree { fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
@ -163,6 +172,7 @@ fn is_complete(tests: &[Test]) -> bool {
Test::IsInt(_) => false, Test::IsInt(_) => false,
Test::IsFloat(_) => false, Test::IsFloat(_) => false,
Test::IsStr(_) => false, Test::IsStr(_) => false,
Test::Guarded(_, _) => false,
}, },
} }
} }
@ -179,20 +189,28 @@ fn flatten_patterns(branch: Branch) -> Branch {
} }
} }
fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) { fn flatten<'a>(
match &path_pattern.1 { path_pattern: (Path, Option<Expr<'a>>, Pattern<'a>),
path_patterns: &mut Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
) {
match &path_pattern.2 {
Pattern::AppliedTag { Pattern::AppliedTag {
union, union,
arguments, arguments,
tag_id, tag_id,
.. ..
} => { } => {
// TODO do we need to check that guard.is_none() here?
if union.alternatives.len() == 1 { if union.alternatives.len() == 1 {
let path = path_pattern.0; let path = path_pattern.0;
// Theory: unbox doesn't have any value for us, because one-element tag unions // Theory: unbox doesn't have any value for us, because one-element tag unions
// don't store the tag anyway. // don't store the tag anyway.
if arguments.len() == 1 { if arguments.len() == 1 {
path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone())); path_patterns.push((
Path::Unbox(Box::new(path)),
path_pattern.1.clone(),
path_pattern.2.clone(),
));
} else { } else {
for (index, (arg_pattern, _)) in arguments.iter().enumerate() { for (index, (arg_pattern, _)) in arguments.iter().enumerate() {
flatten( flatten(
@ -202,6 +220,8 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
tag_id: *tag_id, tag_id: *tag_id,
path: Box::new(path.clone()), path: Box::new(path.clone()),
}, },
// same guard here?
path_pattern.1.clone(),
arg_pattern.clone(), arg_pattern.clone(),
), ),
path_patterns, path_patterns,
@ -227,7 +247,11 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
/// us something like ("x" => value.0.0) /// us something like ("x" => value.0.0)
fn check_for_match(branches: &Vec<Branch>) -> Option<Label> { fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
match branches.get(0) { match branches.get(0) {
Some(Branch { goal, patterns }) if patterns.iter().all(|(_, p)| !needs_tests(p)) => { Some(Branch { goal, patterns })
if patterns
.iter()
.all(|(_, guard, pattern)| guard.is_none() && !needs_tests(pattern)) =>
{
Some(*goal) Some(*goal)
} }
_ => None, _ => None,
@ -268,12 +292,11 @@ fn gather_edges<'a>(
fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Test<'a>> { fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Test<'a>> {
// NOTE the ordering of the result is important! // NOTE the ordering of the result is important!
let mut visited = MutSet::default(); let mut all_tests = Vec::new();
let mut unique = Vec::new();
let all_tests = branches for branch in branches.into_iter() {
.into_iter() test_at_path(selected_path, branch, &mut all_tests);
.filter_map(|b| test_at_path(selected_path, b)); }
// The rust HashMap also uses equality, here we really want to use the custom hash function // The rust HashMap also uses equality, here we really want to use the custom hash function
// defined on Test to determine whether a test is unique. So we have to do the hashing // defined on Test to determine whether a test is unique. So we have to do the hashing
@ -281,6 +304,9 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Tes
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
let mut visited = MutSet::default();
let mut unique = Vec::new();
for test in all_tests { for test in all_tests {
let hash = { let hash = {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
@ -297,66 +323,93 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Tes
unique unique
} }
fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>) -> Option<Test<'a>> { fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>, all_tests: &mut Vec<Test<'a>>) {
use Pattern::*; use Pattern::*;
use Test::*; use Test::*;
match branch match branch
.patterns .patterns
.iter() .iter()
.find(|(path, _)| path == selected_path) .find(|(path, _, _)| path == selected_path)
{ {
None => None, None => {}
Some((_, pattern)) => match pattern { Some((_, guard, pattern)) => {
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => None, let guarded = |test| {
if let Some(guard) = guard {
Guarded(Some(Box::new(test)), guard.clone())
} else {
test
}
};
RecordDestructure(destructs, _) => { match pattern {
let union = Union { // TODO use guard!
alternatives: vec![Ctor { Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => {
name: TagName::Global("#Record".into()), if let Some(guard) = guard {
arity: destructs.len(), all_tests.push(Guarded(None, guard.clone()));
}],
};
let mut arguments = std::vec::Vec::new();
for destruct in destructs {
if let Some(guard) = &destruct.guard {
arguments.push((guard.clone(), destruct.layout.clone()));
} else {
arguments.push((Pattern::Underscore, destruct.layout.clone()));
} }
} }
Some(IsCtor { RecordDestructure(destructs, _) => {
tag_id: 0, let union = Union {
tag_name: TagName::Global("#Record".into()), alternatives: vec![Ctor {
union, name: TagName::Global("#Record".into()),
arguments, arity: destructs.len(),
}) }],
} };
AppliedTag { let mut arguments = std::vec::Vec::new();
tag_name,
tag_id, for destruct in destructs {
arguments, if let Some(guard) = &destruct.guard {
union, arguments.push((guard.clone(), destruct.layout.clone()));
.. } else {
} => Some(IsCtor { arguments.push((Pattern::Underscore, destruct.layout.clone()));
tag_id: *tag_id, }
tag_name: tag_name.clone(), }
union: union.clone(),
arguments: arguments.to_vec(), all_tests.push(IsCtor {
}), tag_id: 0,
BitLiteral(v) => Some(IsBit(*v)), tag_name: TagName::Global("#Record".into()),
EnumLiteral { tag_id, enum_size } => Some(IsByte { union,
tag_id: *tag_id, arguments,
num_alts: *enum_size as usize, });
}), }
IntLiteral(v) => Some(IsInt(*v)),
FloatLiteral(v) => Some(IsFloat(*v)), AppliedTag {
StrLiteral(v) => Some(IsStr(v.clone())), tag_name,
}, tag_id,
arguments,
union,
..
} => {
all_tests.push(IsCtor {
tag_id: *tag_id,
tag_name: tag_name.clone(),
union: union.clone(),
arguments: arguments.to_vec(),
});
}
BitLiteral(v) => {
all_tests.push(IsBit(*v));
}
EnumLiteral { tag_id, enum_size } => {
all_tests.push(IsByte {
tag_id: *tag_id,
num_alts: *enum_size as usize,
});
}
IntLiteral(v) => {
all_tests.push(guarded(IsInt(*v)));
}
FloatLiteral(v) => {
all_tests.push(IsFloat(*v));
}
StrLiteral(v) => {
all_tests.push(IsStr(v.clone()));
}
};
}
} }
} }
@ -367,180 +420,223 @@ fn edges_for<'a>(
branches: Vec<Branch<'a>>, branches: Vec<Branch<'a>>,
test: Test<'a>, test: Test<'a>,
) -> (Test<'a>, Vec<Branch<'a>>) { ) -> (Test<'a>, Vec<Branch<'a>>) {
let new_branches = branches let mut new_branches = Vec::new();
.into_iter()
.filter_map(|b| to_relevant_branch(&test, path, b)) for branch in branches.into_iter() {
.collect(); to_relevant_branch(&test, path, branch, &mut new_branches);
}
(test, new_branches) (test, new_branches)
} }
fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> Option<Branch<'a>> { fn to_relevant_branch<'a>(
test: &Test<'a>,
path: &Path,
branch: Branch<'a>,
new_branches: &mut Vec<Branch<'a>>,
) {
// TODO remove clone
match extract(path, branch.patterns.clone()) {
Extract::NotFound => {
new_branches.push(branch);
}
Extract::Found {
start,
found_pattern: (guard, pattern),
end,
} => match test {
Test::Guarded(None, _guard_expr) => {
// theory: Some(branch)
todo!();
}
Test::Guarded(Some(box_test), _guard_expr) => {
if let Some(new_branch) =
to_relevant_branch_help(box_test, path, start, end, branch, guard, pattern)
{
new_branches.push(new_branch);
}
}
_ => {
if let Some(new_branch) =
to_relevant_branch_help(test, path, start, end, branch, guard, pattern)
{
new_branches.push(new_branch);
}
}
},
}
}
fn to_relevant_branch_help<'a>(
test: &Test<'a>,
path: &Path,
mut start: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
end: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
branch: Branch<'a>,
guard: Option<Expr<'a>>,
pattern: Pattern<'a>,
) -> Option<Branch<'a>> {
use Pattern::*; use Pattern::*;
use Test::*; use Test::*;
// TODO remove clone match pattern {
match extract(path, branch.patterns.clone()) { Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch),
Extract::NotFound => Some(branch),
Extract::Found { RecordDestructure(destructs, _) => match test {
mut start, IsCtor {
found_pattern: pattern, tag_name: test_name,
end, tag_id,
..
} => {
debug_assert!(test_name == &TagName::Global("#Record".into()));
let sub_positions = destructs.into_iter().enumerate().map(|(index, destruct)| {
let pattern = if let Some(guard) = destruct.guard {
guard.clone()
} else {
Pattern::Underscore
};
(
Path::Index {
index: index as u64,
tag_id: *tag_id,
path: Box::new(path.clone()),
},
guard.clone(),
pattern,
)
});
start.extend(sub_positions);
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
AppliedTag {
tag_name,
arguments,
union,
..
} => { } => {
match pattern { match test {
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch), IsCtor {
tag_name: test_name,
RecordDestructure(destructs, _) => match test { tag_id,
IsCtor { ..
tag_name: test_name, } if &tag_name == test_name => {
tag_id, // Theory: Unbox doesn't have any value for us
.. if arguments.len() == 1 && union.alternatives.len() == 1 {
} => { let arg = arguments[0].clone();
debug_assert!(test_name == &TagName::Global("#Record".into())); {
start.push((Path::Unbox(Box::new(path.clone())), guard, arg.0));
start.extend(end);
}
} else {
let sub_positions = let sub_positions =
destructs.into_iter().enumerate().map(|(index, destruct)| { arguments
let pattern = if let Some(guard) = destruct.guard { .into_iter()
guard.clone() .enumerate()
} else { .map(|(index, (pattern, _))| {
Pattern::Underscore (
}; Path::Index {
index: index as u64,
( tag_id: *tag_id,
Path::Index { path: Box::new(path.clone()),
index: index as u64, },
tag_id: *tag_id, guard.clone(),
path: Box::new(path.clone()), pattern,
}, )
pattern, });
)
});
start.extend(sub_positions); start.extend(sub_positions);
start.extend(end); start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
} }
_ => None,
},
AppliedTag { Some(Branch {
tag_name, goal: branch.goal,
arguments, patterns: start,
union, })
..
} => {
match test {
IsCtor {
tag_name: test_name,
tag_id,
..
} if &tag_name == test_name => {
// Theory: Unbox doesn't have any value for us
if arguments.len() == 1 && union.alternatives.len() == 1 {
let arg = arguments[0].clone();
{
start.push((Path::Unbox(Box::new(path.clone())), arg.0));
start.extend(end);
}
} else {
let sub_positions = arguments.into_iter().enumerate().map(
|(index, (pattern, _))| {
(
Path::Index {
index: index as u64,
tag_id: *tag_id,
path: Box::new(path.clone()),
},
pattern,
)
},
);
start.extend(sub_positions);
start.extend(end);
}
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
}
} }
StrLiteral(string) => match test { _ => None,
IsStr(test_str) if string == *test_str => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
IntLiteral(int) => match test {
IsInt(is_int) if int == *is_int => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
FloatLiteral(float) => match test {
IsFloat(test_float) if float == *test_float => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
BitLiteral(bit) => match test {
IsBit(test_bit) if bit == *test_bit => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
EnumLiteral { tag_id, .. } => match test {
IsByte {
tag_id: test_id, ..
} if tag_id == *test_id => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
} }
} }
StrLiteral(string) => match test {
IsStr(test_str) if string == *test_str => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
IntLiteral(int) => match test {
IsInt(is_int) if int == *is_int => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
FloatLiteral(float) => match test {
IsFloat(test_float) if float == *test_float => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
BitLiteral(bit) => match test {
IsBit(test_bit) if bit == *test_bit => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
EnumLiteral { tag_id, .. } => match test {
IsByte {
tag_id: test_id, ..
} if tag_id == *test_id => {
start.extend(end);
Some(Branch {
goal: branch.goal,
patterns: start,
})
}
_ => None,
},
} }
} }
enum Extract<'a> { enum Extract<'a> {
NotFound, NotFound,
Found { Found {
start: Vec<(Path, Pattern<'a>)>, start: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
found_pattern: Pattern<'a>, found_pattern: (Option<Expr<'a>>, Pattern<'a>),
end: Vec<(Path, Pattern<'a>)>, end: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
}, },
} }
fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) -> Extract<'a> { fn extract<'a>(
selected_path: &Path,
path_patterns: Vec<(Path, Option<Expr<'a>>, Pattern<'a>)>,
) -> Extract<'a> {
let mut start = Vec::new(); let mut start = Vec::new();
// TODO remove this clone // TODO remove this clone
@ -551,7 +647,7 @@ fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) ->
if &current.0 == selected_path { if &current.0 == selected_path {
return Extract::Found { return Extract::Found {
start, start,
found_pattern: current.1, found_pattern: (current.1, current.2),
end: { end: {
copy.drain(0..=index); copy.drain(0..=index);
copy copy
@ -571,10 +667,10 @@ fn is_irrelevant_to<'a>(selected_path: &Path, branch: &Branch<'a>) -> bool {
match branch match branch
.patterns .patterns
.iter() .iter()
.find(|(path, _)| path == selected_path) .find(|(path, _, _)| path == selected_path)
{ {
None => true, None => true,
Some((_, pattern)) => !needs_tests(pattern), Some((_, guard, pattern)) => guard.is_none() && !needs_tests(pattern),
} }
} }
@ -620,10 +716,10 @@ fn pick_path(branches: Vec<Branch>) -> Path {
} }
} }
fn is_choice_path(path_and_pattern: (Path, Pattern<'_>)) -> Option<Path> { fn is_choice_path<'a>(path_and_pattern: (Path, Option<Expr<'a>>, Pattern<'a>)) -> Option<Path> {
let (path, pattern) = path_and_pattern; let (path, guard, pattern) = path_and_pattern;
if needs_tests(&pattern) { if guard.is_some() || needs_tests(&pattern) {
Some(path) Some(path)
} else { } else {
None None
@ -737,12 +833,14 @@ pub fn optimize_when<'a>(
cond_symbol: Symbol, cond_symbol: Symbol,
cond_layout: Layout<'a>, cond_layout: Layout<'a>,
ret_layout: Layout<'a>, ret_layout: Layout<'a>,
opt_branches: Vec<(Pattern<'a>, Expr<'a>)>, opt_branches: Vec<(Pattern<'a>, Option<Expr<'a>>, Expr<'a>)>,
) -> Expr<'a> { ) -> Expr<'a> {
let (patterns, _indexed_branches) = opt_branches let (patterns, _indexed_branches) = opt_branches
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(index, (pattern, branch))| ((pattern, index as u64), (index as u64, branch))) .map(|(index, (pattern, guard, branch))| {
((guard, pattern, index as u64), (index as u64, branch))
})
.unzip(); .unzip();
let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches; let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches;
@ -835,6 +933,98 @@ fn path_to_expr_help<'a>(
} }
} }
fn test_to_equality<'a>(
env: &mut Env<'a, '_>,
cond_symbol: Symbol,
cond_layout: &Layout<'a>,
path: &Path,
test: Test<'a>,
tests: &mut Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>,
) {
match test {
Test::IsCtor {
tag_id,
union,
arguments,
..
} => {
// the IsCtor check should never be generated for tag unions of size 1
// (e.g. record pattern guard matches)
debug_assert!(union.alternatives.len() > 1);
let lhs = Expr::Int(tag_id as i64);
let mut field_layouts =
bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena);
// add the tag discriminant
field_layouts.push(Layout::Builtin(Builtin::Int64));
for (_, layout) in arguments {
field_layouts.push(layout);
}
let rhs = Expr::AccessAtIndex {
index: 0,
field_layouts: field_layouts.into_bump_slice(),
expr: env.arena.alloc(Expr::Load(cond_symbol)),
is_unwrapped: union.alternatives.len() == 1,
};
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
}
Test::IsInt(test_int) => {
let lhs = Expr::Int(test_int);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
}
Test::IsFloat(test_int) => {
// TODO maybe we can actually use i64 comparison here?
let test_float = f64::from_bits(test_int as u64);
let lhs = Expr::Float(test_float);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64)));
}
Test::IsByte {
tag_id: test_byte, ..
} => {
let lhs = Expr::Byte(test_byte);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte)));
}
Test::IsBit(test_bit) => {
let lhs = Expr::Bool(test_bit);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
}
Test::IsStr(test_str) => {
let lhs = Expr::Str(env.arena.alloc(test_str));
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Str)));
}
Test::Guarded(test, expr) => {
if let Some(nested) = test {
test_to_equality(env, cond_symbol, cond_layout, path, *nested, tests);
}
let lhs = Expr::Bool(true);
let rhs = expr;
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
}
}
}
fn decide_to_branching<'a>( fn decide_to_branching<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
cond_symbol: Symbol, cond_symbol: Symbol,
@ -861,77 +1051,7 @@ fn decide_to_branching<'a>(
let mut tests = Vec::with_capacity(test_chain.len()); let mut tests = Vec::with_capacity(test_chain.len());
for (path, test) in test_chain { for (path, test) in test_chain {
match test { test_to_equality(env, cond_symbol, &cond_layout, &path, test, &mut tests);
Test::IsCtor {
tag_id,
union,
arguments,
..
} => {
// the IsCtor check should never be generated for tag unions of size 1
// (e.g. record pattern guard matches)
debug_assert!(union.alternatives.len() > 1);
let lhs = Expr::Int(tag_id as i64);
let mut field_layouts =
bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena);
// add the tag discriminant
field_layouts.push(Layout::Builtin(Builtin::Int64));
for (_, layout) in arguments {
field_layouts.push(layout);
}
let rhs = Expr::AccessAtIndex {
index: 0,
field_layouts: field_layouts.into_bump_slice(),
expr: env.arena.alloc(Expr::Load(cond_symbol)),
is_unwrapped: union.alternatives.len() == 1,
};
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
}
Test::IsInt(test_int) => {
let lhs = Expr::Int(test_int);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
}
Test::IsFloat(test_int) => {
// TODO maybe we can actually use i64 comparison here?
let test_float = f64::from_bits(test_int as u64);
let lhs = Expr::Float(test_float);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64)));
}
Test::IsByte {
tag_id: test_byte, ..
} => {
let lhs = Expr::Byte(test_byte);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte)));
}
Test::IsBit(test_bit) => {
let lhs = Expr::Bool(test_bit);
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
}
Test::IsStr(test_str) => {
let lhs = Expr::Str(env.arena.alloc(test_str));
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Str)));
}
}
} }
let pass = env.arena.alloc(decide_to_branching( let pass = env.arena.alloc(decide_to_branching(

View file

@ -1031,6 +1031,12 @@ fn from_can_when<'a>(
for when_branch in branches { for when_branch in branches {
let mono_expr = from_can(env, when_branch.value.value, procs, None); let mono_expr = from_can(env, when_branch.value.value, procs, None);
let mono_guard = if let Some(loc_guard) = when_branch.guard {
Some(from_can(env, loc_guard.value, procs, None))
} else {
None
};
for loc_pattern in when_branch.patterns { for loc_pattern in when_branch.patterns {
let mono_pattern = from_can_pattern(env, &loc_pattern.value); let mono_pattern = from_can_pattern(env, &loc_pattern.value);
@ -1051,7 +1057,7 @@ fn from_can_when<'a>(
Err(message) => Expr::RuntimeError(env.arena.alloc(message)), Err(message) => Expr::RuntimeError(env.arena.alloc(message)),
}; };
opt_branches.push((mono_pattern, mono_expr)); opt_branches.push((mono_pattern, mono_guard.clone(), mono_expr));
} }
} }