mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-02 16:21:11 +00:00
Merge pull request #272 from rtfeldman/more-patterns
More pattern features
This commit is contained in:
commit
4cec4d2c8a
10 changed files with 734 additions and 336 deletions
|
@ -179,13 +179,19 @@ pub fn desugar_expr<'a>(arena: &'a Bump, loc_expr: &'a Located<Expr<'a>>) -> &'a
|
|||
})
|
||||
}
|
||||
|
||||
let desugared_guard = if let Some(guard) = &branch.guard {
|
||||
Some(desugar_expr(arena, guard).clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
desugared_branches.push(&*arena.alloc(WhenBranch {
|
||||
patterns: alternatives,
|
||||
value: Located {
|
||||
region: desugared.region,
|
||||
value: Nested(&desugared.value),
|
||||
},
|
||||
guard: None,
|
||||
guard: desugared_guard,
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
|
@ -685,26 +685,43 @@ fn constrain_when_branch(
|
|||
}
|
||||
|
||||
if let Some(loc_guard) = &when_branch.guard {
|
||||
state.constraints.push(constrain_expr(
|
||||
let guard_constraint = constrain_expr(
|
||||
env,
|
||||
region,
|
||||
&when_branch.value.value,
|
||||
&loc_guard.value,
|
||||
Expected::ForReason(
|
||||
Reason::WhenGuard,
|
||||
Type::Variable(Variable::BOOL),
|
||||
loc_guard.region,
|
||||
),
|
||||
));
|
||||
};
|
||||
);
|
||||
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint,
|
||||
}))
|
||||
// must introduce the headers from the pattern before constraining the guard
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint: Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: Vec::new(),
|
||||
def_types: SendMap::default(),
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: guard_constraint,
|
||||
ret_constraint,
|
||||
})),
|
||||
}))
|
||||
} else {
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn constrain_field(env: &Env, field_var: Variable, loc_expr: &Located<Expr>) -> (Type, Constraint) {
|
||||
|
|
|
@ -1431,7 +1431,8 @@ fn constrain_by_usage_record(
|
|||
|
||||
// TODO trim down these arguments
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[inline(always)]
|
||||
// NOTE enabling the inline pragma can blow the stack in debug mode
|
||||
// #[inline(always)]
|
||||
fn constrain_when_branch(
|
||||
var_store: &VarStore,
|
||||
var_usage: &VarUsage,
|
||||
|
@ -1470,13 +1471,13 @@ fn constrain_when_branch(
|
|||
|
||||
if let Some(loc_guard) = &when_branch.guard {
|
||||
let guard_uniq_var = var_store.fresh();
|
||||
state.vars.push(guard_uniq_var);
|
||||
|
||||
let bool_type = attr_type(
|
||||
Bool::variable(guard_uniq_var),
|
||||
Type::Variable(Variable::BOOL),
|
||||
);
|
||||
state.constraints.push(constrain_expr(
|
||||
|
||||
let guard_constraint = constrain_expr(
|
||||
env,
|
||||
var_store,
|
||||
var_usage,
|
||||
|
@ -1484,17 +1485,33 @@ fn constrain_when_branch(
|
|||
loc_guard.region,
|
||||
&loc_guard.value,
|
||||
Expected::ForReason(Reason::WhenGuard, bool_type, loc_guard.region),
|
||||
));
|
||||
}
|
||||
);
|
||||
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint,
|
||||
}))
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint: Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: vec![guard_uniq_var],
|
||||
def_types: SendMap::default(),
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: guard_constraint,
|
||||
ret_constraint,
|
||||
})),
|
||||
}))
|
||||
} else {
|
||||
Constraint::Let(Box::new(LetConstraint {
|
||||
rigid_vars: Vec::new(),
|
||||
flex_vars: state.vars,
|
||||
def_types: state.headers,
|
||||
def_aliases: SendMap::default(),
|
||||
defs_constraint: Constraint::And(state.constraints),
|
||||
ret_constraint,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn constrain_def_pattern(
|
||||
|
|
|
@ -771,13 +771,24 @@ fn call_by_name<'a, B: Backend>(
|
|||
.ins()
|
||||
.load(env.ptr_sized_int(), MemFlags::new(), list_ptr, offset)
|
||||
}
|
||||
Symbol::INT_EQ_I64 | Symbol::INT_EQ_I8 | Symbol::INT_EQ_I1 => {
|
||||
Symbol::INT_EQ_I64 | Symbol::INT_EQ_I8 => {
|
||||
debug_assert!(args.len() == 2);
|
||||
let a = build_arg(&args[0], env, scope, module, builder, procs);
|
||||
let b = build_arg(&args[1], env, scope, module, builder, procs);
|
||||
|
||||
builder.ins().icmp(IntCC::Equal, a, b)
|
||||
}
|
||||
Symbol::INT_EQ_I1 => {
|
||||
debug_assert!(args.len() == 2);
|
||||
let a = build_arg(&args[0], env, scope, module, builder, procs);
|
||||
let b = build_arg(&args[1], env, scope, module, builder, procs);
|
||||
|
||||
// integer comparisons don't work for booleans, and a custom xand gives errors.
|
||||
let p = builder.ins().bint(types::I8, a);
|
||||
let q = builder.ins().bint(types::I8, b);
|
||||
|
||||
builder.ins().icmp(IntCC::Equal, p, q)
|
||||
}
|
||||
Symbol::FLOAT_EQ => {
|
||||
debug_assert!(args.len() == 2);
|
||||
let a = build_arg(&args[0], env, scope, module, builder, procs);
|
||||
|
|
|
@ -1495,6 +1495,93 @@ mod test_gen {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn or_pattern() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 2 is
|
||||
1 | 2 -> 42
|
||||
_ -> 1
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_guard_pattern_false() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 2 is
|
||||
2 if False -> 0
|
||||
_ -> 42
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_guard_pattern_true() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 2 is
|
||||
2 if True -> 42
|
||||
_ -> 0
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_guard_exhaustiveness() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 2 is
|
||||
_ if False -> 0
|
||||
_ -> 42
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_guard_bind_variable() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 10 is
|
||||
x if x == 5 -> 0
|
||||
_ -> 42
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when 10 is
|
||||
x if x == 10 -> 42
|
||||
_ -> 0
|
||||
"#
|
||||
),
|
||||
42,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn linked_list_empty() {
|
||||
// assert_evals_to!(
|
||||
|
|
|
@ -19,18 +19,33 @@ type Label = u64;
|
|||
/// some normal branches and gives out a decision tree that has "labels" at all
|
||||
/// the leafs and a dictionary that maps these "labels" to the code that should
|
||||
/// run.
|
||||
pub fn compile(raw_branches: Vec<(Pattern<'_>, u64)>) -> DecisionTree {
|
||||
pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> {
|
||||
let formatted = raw_branches
|
||||
.into_iter()
|
||||
.map(|(pattern, index)| Branch {
|
||||
.map(|(guard, pattern, index)| Branch {
|
||||
goal: index,
|
||||
patterns: vec![(Path::Empty, pattern)],
|
||||
patterns: vec![(Path::Empty, guard, pattern)],
|
||||
})
|
||||
.collect();
|
||||
|
||||
to_decision_tree(formatted)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Guard<'a> {
|
||||
NoGuard,
|
||||
Guard {
|
||||
stores: &'a [(Symbol, Layout<'a>, Expr<'a>)],
|
||||
expr: Expr<'a>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> Guard<'a> {
|
||||
fn is_none(&self) -> bool {
|
||||
self == &Guard::NoGuard
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum DecisionTree<'a> {
|
||||
Match(Label),
|
||||
|
@ -41,7 +56,7 @@ pub enum DecisionTree<'a> {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Test<'a> {
|
||||
IsCtor {
|
||||
tag_id: u8,
|
||||
|
@ -58,6 +73,12 @@ pub enum Test<'a> {
|
|||
tag_id: u8,
|
||||
num_alts: usize,
|
||||
},
|
||||
// A pattern that always succeeds (like `_`) can still have a guard
|
||||
Guarded {
|
||||
opt_test: Option<Box<Test<'a>>>,
|
||||
stores: &'a [(Symbol, Layout<'a>, Expr<'a>)],
|
||||
expr: Expr<'a>,
|
||||
},
|
||||
}
|
||||
use std::hash::{Hash, Hasher};
|
||||
impl<'a> Hash for Test<'a> {
|
||||
|
@ -89,7 +110,17 @@ impl<'a> Hash for Test<'a> {
|
|||
IsByte { tag_id, num_alts } => {
|
||||
state.write_u8(5);
|
||||
tag_id.hash(state);
|
||||
num_alts.hash(state)
|
||||
num_alts.hash(state);
|
||||
}
|
||||
Guarded { opt_test: None, .. } => {
|
||||
state.write_u8(6);
|
||||
}
|
||||
Guarded {
|
||||
opt_test: Some(nested),
|
||||
..
|
||||
} => {
|
||||
state.write_u8(7);
|
||||
nested.hash(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +142,7 @@ pub enum Path {
|
|||
#[derive(Clone, Debug, PartialEq)]
|
||||
struct Branch<'a> {
|
||||
goal: Label,
|
||||
patterns: Vec<(Path, Pattern<'a>)>,
|
||||
patterns: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
}
|
||||
|
||||
fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
|
||||
|
@ -163,6 +194,7 @@ fn is_complete(tests: &[Test]) -> bool {
|
|||
Test::IsInt(_) => false,
|
||||
Test::IsFloat(_) => false,
|
||||
Test::IsStr(_) => false,
|
||||
Test::Guarded { .. } => false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -179,20 +211,28 @@ fn flatten_patterns(branch: Branch) -> Branch {
|
|||
}
|
||||
}
|
||||
|
||||
fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) {
|
||||
match &path_pattern.1 {
|
||||
fn flatten<'a>(
|
||||
path_pattern: (Path, Guard<'a>, Pattern<'a>),
|
||||
path_patterns: &mut Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
) {
|
||||
match &path_pattern.2 {
|
||||
Pattern::AppliedTag {
|
||||
union,
|
||||
arguments,
|
||||
tag_id,
|
||||
..
|
||||
} => {
|
||||
// TODO do we need to check that guard.is_none() here?
|
||||
if union.alternatives.len() == 1 {
|
||||
let path = path_pattern.0;
|
||||
// Theory: unbox doesn't have any value for us, because one-element tag unions
|
||||
// don't store the tag anyway.
|
||||
if arguments.len() == 1 {
|
||||
path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone()));
|
||||
path_patterns.push((
|
||||
Path::Unbox(Box::new(path)),
|
||||
path_pattern.1.clone(),
|
||||
path_pattern.2.clone(),
|
||||
));
|
||||
} else {
|
||||
for (index, (arg_pattern, _)) in arguments.iter().enumerate() {
|
||||
flatten(
|
||||
|
@ -202,6 +242,8 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
|
|||
tag_id: *tag_id,
|
||||
path: Box::new(path.clone()),
|
||||
},
|
||||
// same guard here?
|
||||
path_pattern.1.clone(),
|
||||
arg_pattern.clone(),
|
||||
),
|
||||
path_patterns,
|
||||
|
@ -225,9 +267,13 @@ fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path,
|
|||
/// path. If that is the case we give the resulting label and a mapping from free
|
||||
/// variables to "how to get their value". So a pattern like (Just (x,_)) will give
|
||||
/// us something like ("x" => value.0.0)
|
||||
fn check_for_match(branches: &Vec<Branch>) -> Option<Label> {
|
||||
fn check_for_match<'a>(branches: &Vec<Branch<'a>>) -> Option<Label> {
|
||||
match branches.get(0) {
|
||||
Some(Branch { goal, patterns }) if patterns.iter().all(|(_, p)| !needs_tests(p)) => {
|
||||
Some(Branch { goal, patterns })
|
||||
if patterns
|
||||
.iter()
|
||||
.all(|(_, guard, pattern)| guard.is_none() && !needs_tests(pattern)) =>
|
||||
{
|
||||
Some(*goal)
|
||||
}
|
||||
_ => None,
|
||||
|
@ -268,12 +314,11 @@ fn gather_edges<'a>(
|
|||
fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Test<'a>> {
|
||||
// NOTE the ordering of the result is important!
|
||||
|
||||
let mut visited = MutSet::default();
|
||||
let mut unique = Vec::new();
|
||||
let mut all_tests = Vec::new();
|
||||
|
||||
let all_tests = branches
|
||||
.into_iter()
|
||||
.filter_map(|b| test_at_path(selected_path, b));
|
||||
for branch in branches.into_iter() {
|
||||
test_at_path(selected_path, branch, &mut all_tests);
|
||||
}
|
||||
|
||||
// The rust HashMap also uses equality, here we really want to use the custom hash function
|
||||
// defined on Test to determine whether a test is unique. So we have to do the hashing
|
||||
|
@ -281,6 +326,9 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Tes
|
|||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
|
||||
let mut visited = MutSet::default();
|
||||
let mut unique = Vec::new();
|
||||
|
||||
for test in all_tests {
|
||||
let hash = {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
@ -297,66 +345,101 @@ fn tests_at_path<'a>(selected_path: &Path, branches: Vec<Branch<'a>>) -> Vec<Tes
|
|||
unique
|
||||
}
|
||||
|
||||
fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>) -> Option<Test<'a>> {
|
||||
fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>, all_tests: &mut Vec<Test<'a>>) {
|
||||
use Pattern::*;
|
||||
use Test::*;
|
||||
|
||||
match branch
|
||||
.patterns
|
||||
.iter()
|
||||
.find(|(path, _)| path == selected_path)
|
||||
.find(|(path, _, _)| path == selected_path)
|
||||
{
|
||||
None => None,
|
||||
Some((_, pattern)) => match pattern {
|
||||
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => None,
|
||||
None => {}
|
||||
Some((_, guard, pattern)) => {
|
||||
let guarded = |test| {
|
||||
if let Guard::Guard { stores, expr } = guard {
|
||||
Guarded {
|
||||
opt_test: Some(Box::new(test)),
|
||||
stores,
|
||||
expr: expr.clone(),
|
||||
}
|
||||
} else {
|
||||
test
|
||||
}
|
||||
};
|
||||
|
||||
RecordDestructure(destructs, _) => {
|
||||
let union = Union {
|
||||
alternatives: vec![Ctor {
|
||||
name: TagName::Global("#Record".into()),
|
||||
arity: destructs.len(),
|
||||
}],
|
||||
};
|
||||
|
||||
let mut arguments = std::vec::Vec::new();
|
||||
|
||||
for destruct in destructs {
|
||||
if let Some(guard) = &destruct.guard {
|
||||
arguments.push((guard.clone(), destruct.layout.clone()));
|
||||
} else {
|
||||
arguments.push((Pattern::Underscore, destruct.layout.clone()));
|
||||
match pattern {
|
||||
// TODO use guard!
|
||||
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => {
|
||||
if let Guard::Guard { stores, expr } = guard {
|
||||
all_tests.push(Guarded {
|
||||
opt_test: None,
|
||||
stores,
|
||||
expr: expr.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Some(IsCtor {
|
||||
tag_id: 0,
|
||||
tag_name: TagName::Global("#Record".into()),
|
||||
union,
|
||||
arguments,
|
||||
})
|
||||
}
|
||||
RecordDestructure(destructs, _) => {
|
||||
let union = Union {
|
||||
alternatives: vec![Ctor {
|
||||
name: TagName::Global("#Record".into()),
|
||||
arity: destructs.len(),
|
||||
}],
|
||||
};
|
||||
|
||||
AppliedTag {
|
||||
tag_name,
|
||||
tag_id,
|
||||
arguments,
|
||||
union,
|
||||
..
|
||||
} => Some(IsCtor {
|
||||
tag_id: *tag_id,
|
||||
tag_name: tag_name.clone(),
|
||||
union: union.clone(),
|
||||
arguments: arguments.to_vec(),
|
||||
}),
|
||||
BitLiteral(v) => Some(IsBit(*v)),
|
||||
EnumLiteral { tag_id, enum_size } => Some(IsByte {
|
||||
tag_id: *tag_id,
|
||||
num_alts: *enum_size as usize,
|
||||
}),
|
||||
IntLiteral(v) => Some(IsInt(*v)),
|
||||
FloatLiteral(v) => Some(IsFloat(*v)),
|
||||
StrLiteral(v) => Some(IsStr(v.clone())),
|
||||
},
|
||||
let mut arguments = std::vec::Vec::new();
|
||||
|
||||
for destruct in destructs {
|
||||
if let Some(guard) = &destruct.guard {
|
||||
arguments.push((guard.clone(), destruct.layout.clone()));
|
||||
} else {
|
||||
arguments.push((Pattern::Underscore, destruct.layout.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
all_tests.push(IsCtor {
|
||||
tag_id: 0,
|
||||
tag_name: TagName::Global("#Record".into()),
|
||||
union,
|
||||
arguments,
|
||||
});
|
||||
}
|
||||
|
||||
AppliedTag {
|
||||
tag_name,
|
||||
tag_id,
|
||||
arguments,
|
||||
union,
|
||||
..
|
||||
} => {
|
||||
all_tests.push(IsCtor {
|
||||
tag_id: *tag_id,
|
||||
tag_name: tag_name.clone(),
|
||||
union: union.clone(),
|
||||
arguments: arguments.to_vec(),
|
||||
});
|
||||
}
|
||||
BitLiteral(v) => {
|
||||
all_tests.push(IsBit(*v));
|
||||
}
|
||||
EnumLiteral { tag_id, enum_size } => {
|
||||
all_tests.push(IsByte {
|
||||
tag_id: *tag_id,
|
||||
num_alts: *enum_size as usize,
|
||||
});
|
||||
}
|
||||
IntLiteral(v) => {
|
||||
all_tests.push(guarded(IsInt(*v)));
|
||||
}
|
||||
FloatLiteral(v) => {
|
||||
all_tests.push(IsFloat(*v));
|
||||
}
|
||||
StrLiteral(v) => {
|
||||
all_tests.push(IsStr(v.clone()));
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -367,180 +450,225 @@ fn edges_for<'a>(
|
|||
branches: Vec<Branch<'a>>,
|
||||
test: Test<'a>,
|
||||
) -> (Test<'a>, Vec<Branch<'a>>) {
|
||||
let new_branches = branches
|
||||
.into_iter()
|
||||
.filter_map(|b| to_relevant_branch(&test, path, b))
|
||||
.collect();
|
||||
let mut new_branches = Vec::new();
|
||||
|
||||
for branch in branches.into_iter() {
|
||||
to_relevant_branch(&test, path, branch, &mut new_branches);
|
||||
}
|
||||
|
||||
(test, new_branches)
|
||||
}
|
||||
|
||||
fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> Option<Branch<'a>> {
|
||||
fn to_relevant_branch<'a>(
|
||||
test: &Test<'a>,
|
||||
path: &Path,
|
||||
branch: Branch<'a>,
|
||||
new_branches: &mut Vec<Branch<'a>>,
|
||||
) {
|
||||
// TODO remove clone
|
||||
match extract(path, branch.patterns.clone()) {
|
||||
Extract::NotFound => {
|
||||
new_branches.push(branch);
|
||||
}
|
||||
Extract::Found {
|
||||
start,
|
||||
found_pattern: (guard, pattern),
|
||||
end,
|
||||
} => {
|
||||
let actual_test = match test {
|
||||
Test::Guarded {
|
||||
opt_test: Some(box_test),
|
||||
..
|
||||
} => box_test,
|
||||
_ => test,
|
||||
};
|
||||
|
||||
if let Some(mut new_branch) =
|
||||
to_relevant_branch_help(actual_test, path, start, end, branch, guard, pattern)
|
||||
{
|
||||
// guards can/should only occur at the top level. When we recurse on these
|
||||
// branches, the guard is not relevant any more. Not setthing the guard to None
|
||||
// leads to infinite recursion.
|
||||
new_branch.patterns.iter_mut().for_each(|(_, guard, _)| {
|
||||
*guard = Guard::NoGuard;
|
||||
});
|
||||
|
||||
new_branches.push(new_branch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_relevant_branch_help<'a>(
|
||||
test: &Test<'a>,
|
||||
path: &Path,
|
||||
mut start: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
end: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
branch: Branch<'a>,
|
||||
guard: Guard<'a>,
|
||||
pattern: Pattern<'a>,
|
||||
) -> Option<Branch<'a>> {
|
||||
use Pattern::*;
|
||||
use Test::*;
|
||||
|
||||
// TODO remove clone
|
||||
match extract(path, branch.patterns.clone()) {
|
||||
Extract::NotFound => Some(branch),
|
||||
Extract::Found {
|
||||
mut start,
|
||||
found_pattern: pattern,
|
||||
end,
|
||||
match pattern {
|
||||
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch),
|
||||
|
||||
RecordDestructure(destructs, _) => match test {
|
||||
IsCtor {
|
||||
tag_name: test_name,
|
||||
tag_id,
|
||||
..
|
||||
} => {
|
||||
debug_assert!(test_name == &TagName::Global("#Record".into()));
|
||||
let sub_positions = destructs.into_iter().enumerate().map(|(index, destruct)| {
|
||||
let pattern = if let Some(guard) = destruct.guard {
|
||||
guard.clone()
|
||||
} else {
|
||||
Pattern::Underscore
|
||||
};
|
||||
|
||||
(
|
||||
Path::Index {
|
||||
index: index as u64,
|
||||
tag_id: *tag_id,
|
||||
path: Box::new(path.clone()),
|
||||
},
|
||||
Guard::NoGuard,
|
||||
pattern,
|
||||
)
|
||||
});
|
||||
start.extend(sub_positions);
|
||||
start.extend(end);
|
||||
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
AppliedTag {
|
||||
tag_name,
|
||||
arguments,
|
||||
union,
|
||||
..
|
||||
} => {
|
||||
match pattern {
|
||||
Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch),
|
||||
|
||||
RecordDestructure(destructs, _) => match test {
|
||||
IsCtor {
|
||||
tag_name: test_name,
|
||||
tag_id,
|
||||
..
|
||||
} => {
|
||||
debug_assert!(test_name == &TagName::Global("#Record".into()));
|
||||
match test {
|
||||
IsCtor {
|
||||
tag_name: test_name,
|
||||
tag_id,
|
||||
..
|
||||
} if &tag_name == test_name => {
|
||||
// Theory: Unbox doesn't have any value for us
|
||||
if arguments.len() == 1 && union.alternatives.len() == 1 {
|
||||
let arg = arguments[0].clone();
|
||||
{
|
||||
start.push((Path::Unbox(Box::new(path.clone())), guard, arg.0));
|
||||
start.extend(end);
|
||||
}
|
||||
} else {
|
||||
let sub_positions =
|
||||
destructs.into_iter().enumerate().map(|(index, destruct)| {
|
||||
let pattern = if let Some(guard) = destruct.guard {
|
||||
guard.clone()
|
||||
} else {
|
||||
Pattern::Underscore
|
||||
};
|
||||
|
||||
(
|
||||
Path::Index {
|
||||
index: index as u64,
|
||||
tag_id: *tag_id,
|
||||
path: Box::new(path.clone()),
|
||||
},
|
||||
pattern,
|
||||
)
|
||||
});
|
||||
arguments
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (pattern, _))| {
|
||||
(
|
||||
Path::Index {
|
||||
index: index as u64,
|
||||
tag_id: *tag_id,
|
||||
path: Box::new(path.clone()),
|
||||
},
|
||||
Guard::NoGuard,
|
||||
pattern,
|
||||
)
|
||||
});
|
||||
start.extend(sub_positions);
|
||||
start.extend(end);
|
||||
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
AppliedTag {
|
||||
tag_name,
|
||||
arguments,
|
||||
union,
|
||||
..
|
||||
} => {
|
||||
match test {
|
||||
IsCtor {
|
||||
tag_name: test_name,
|
||||
tag_id,
|
||||
..
|
||||
} if &tag_name == test_name => {
|
||||
// Theory: Unbox doesn't have any value for us
|
||||
if arguments.len() == 1 && union.alternatives.len() == 1 {
|
||||
let arg = arguments[0].clone();
|
||||
{
|
||||
start.push((Path::Unbox(Box::new(path.clone())), arg.0));
|
||||
start.extend(end);
|
||||
}
|
||||
} else {
|
||||
let sub_positions = arguments.into_iter().enumerate().map(
|
||||
|(index, (pattern, _))| {
|
||||
(
|
||||
Path::Index {
|
||||
index: index as u64,
|
||||
tag_id: *tag_id,
|
||||
path: Box::new(path.clone()),
|
||||
},
|
||||
pattern,
|
||||
)
|
||||
},
|
||||
);
|
||||
start.extend(sub_positions);
|
||||
start.extend(end);
|
||||
}
|
||||
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
StrLiteral(string) => match test {
|
||||
IsStr(test_str) if string == *test_str => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
IntLiteral(int) => match test {
|
||||
IsInt(is_int) if int == *is_int => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
FloatLiteral(float) => match test {
|
||||
IsFloat(test_float) if float == *test_float => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
BitLiteral(bit) => match test {
|
||||
IsBit(test_bit) if bit == *test_bit => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
EnumLiteral { tag_id, .. } => match test {
|
||||
IsByte {
|
||||
tag_id: test_id, ..
|
||||
} if tag_id == *test_id => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
StrLiteral(string) => match test {
|
||||
IsStr(test_str) if string == *test_str => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
IntLiteral(int) => match test {
|
||||
IsInt(is_int) if int == *is_int => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
FloatLiteral(float) => match test {
|
||||
IsFloat(test_float) if float == *test_float => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
BitLiteral(bit) => match test {
|
||||
IsBit(test_bit) if bit == *test_bit => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
|
||||
EnumLiteral { tag_id, .. } => match test {
|
||||
IsByte {
|
||||
tag_id: test_id, ..
|
||||
} if tag_id == *test_id => {
|
||||
start.extend(end);
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
enum Extract<'a> {
|
||||
NotFound,
|
||||
Found {
|
||||
start: Vec<(Path, Pattern<'a>)>,
|
||||
found_pattern: Pattern<'a>,
|
||||
end: Vec<(Path, Pattern<'a>)>,
|
||||
start: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
found_pattern: (Guard<'a>, Pattern<'a>),
|
||||
end: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
},
|
||||
}
|
||||
|
||||
fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) -> Extract<'a> {
|
||||
fn extract<'a>(
|
||||
selected_path: &Path,
|
||||
path_patterns: Vec<(Path, Guard<'a>, Pattern<'a>)>,
|
||||
) -> Extract<'a> {
|
||||
let mut start = Vec::new();
|
||||
|
||||
// TODO remove this clone
|
||||
|
@ -551,7 +679,7 @@ fn extract<'a>(selected_path: &Path, path_patterns: Vec<(Path, Pattern<'a>)>) ->
|
|||
if ¤t.0 == selected_path {
|
||||
return Extract::Found {
|
||||
start,
|
||||
found_pattern: current.1,
|
||||
found_pattern: (current.1, current.2),
|
||||
end: {
|
||||
copy.drain(0..=index);
|
||||
copy
|
||||
|
@ -571,10 +699,10 @@ fn is_irrelevant_to<'a>(selected_path: &Path, branch: &Branch<'a>) -> bool {
|
|||
match branch
|
||||
.patterns
|
||||
.iter()
|
||||
.find(|(path, _)| path == selected_path)
|
||||
.find(|(path, _, _)| path == selected_path)
|
||||
{
|
||||
None => true,
|
||||
Some((_, pattern)) => !needs_tests(pattern),
|
||||
Some((_, guard, pattern)) => guard.is_none() && !needs_tests(pattern),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -620,10 +748,10 @@ fn pick_path(branches: Vec<Branch>) -> Path {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_choice_path(path_and_pattern: (Path, Pattern<'_>)) -> Option<Path> {
|
||||
let (path, pattern) = path_and_pattern;
|
||||
fn is_choice_path<'a>(path_and_pattern: (Path, Guard<'a>, Pattern<'a>)) -> Option<Path> {
|
||||
let (path, guard, pattern) = path_and_pattern;
|
||||
|
||||
if needs_tests(&pattern) {
|
||||
if !guard.is_none() || needs_tests(&pattern) {
|
||||
Some(path)
|
||||
} else {
|
||||
None
|
||||
|
@ -737,12 +865,14 @@ pub fn optimize_when<'a>(
|
|||
cond_symbol: Symbol,
|
||||
cond_layout: Layout<'a>,
|
||||
ret_layout: Layout<'a>,
|
||||
opt_branches: Vec<(Pattern<'a>, Expr<'a>)>,
|
||||
opt_branches: Vec<(Pattern<'a>, Guard<'a>, Expr<'a>)>,
|
||||
) -> Expr<'a> {
|
||||
let (patterns, _indexed_branches) = opt_branches
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (pattern, branch))| ((pattern, index as u64), (index as u64, branch)))
|
||||
.map(|(index, (pattern, guard, branch))| {
|
||||
((guard, pattern, index as u64), (index as u64, branch))
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let indexed_branches: Vec<(u64, Expr<'a>)> = _indexed_branches;
|
||||
|
@ -835,6 +965,102 @@ fn path_to_expr_help<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn test_to_equality<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
cond_symbol: Symbol,
|
||||
cond_layout: &Layout<'a>,
|
||||
path: &Path,
|
||||
test: Test<'a>,
|
||||
tests: &mut Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>,
|
||||
) {
|
||||
match test {
|
||||
Test::IsCtor {
|
||||
tag_id,
|
||||
union,
|
||||
arguments,
|
||||
..
|
||||
} => {
|
||||
// the IsCtor check should never be generated for tag unions of size 1
|
||||
// (e.g. record pattern guard matches)
|
||||
debug_assert!(union.alternatives.len() > 1);
|
||||
|
||||
let lhs = Expr::Int(tag_id as i64);
|
||||
|
||||
let mut field_layouts =
|
||||
bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena);
|
||||
|
||||
// add the tag discriminant
|
||||
field_layouts.push(Layout::Builtin(Builtin::Int64));
|
||||
|
||||
for (_, layout) in arguments {
|
||||
field_layouts.push(layout);
|
||||
}
|
||||
|
||||
let rhs = Expr::AccessAtIndex {
|
||||
index: 0,
|
||||
field_layouts: field_layouts.into_bump_slice(),
|
||||
expr: env.arena.alloc(Expr::Load(cond_symbol)),
|
||||
is_unwrapped: union.alternatives.len() == 1,
|
||||
};
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
|
||||
}
|
||||
Test::IsInt(test_int) => {
|
||||
let lhs = Expr::Int(test_int);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
|
||||
}
|
||||
|
||||
Test::IsFloat(test_int) => {
|
||||
// TODO maybe we can actually use i64 comparison here?
|
||||
let test_float = f64::from_bits(test_int as u64);
|
||||
let lhs = Expr::Float(test_float);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64)));
|
||||
}
|
||||
|
||||
Test::IsByte {
|
||||
tag_id: test_byte, ..
|
||||
} => {
|
||||
let lhs = Expr::Byte(test_byte);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte)));
|
||||
}
|
||||
|
||||
Test::IsBit(test_bit) => {
|
||||
let lhs = Expr::Bool(test_bit);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
|
||||
}
|
||||
|
||||
Test::IsStr(test_str) => {
|
||||
let lhs = Expr::Str(env.arena.alloc(test_str));
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Str)));
|
||||
}
|
||||
|
||||
Test::Guarded {
|
||||
opt_test,
|
||||
stores,
|
||||
expr,
|
||||
} => {
|
||||
if let Some(nested) = opt_test {
|
||||
test_to_equality(env, cond_symbol, cond_layout, path, *nested, tests);
|
||||
}
|
||||
|
||||
let lhs = Expr::Bool(true);
|
||||
let rhs = Expr::Store(stores, env.arena.alloc(expr));
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decide_to_branching<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
cond_symbol: Symbol,
|
||||
|
@ -861,77 +1087,7 @@ fn decide_to_branching<'a>(
|
|||
let mut tests = Vec::with_capacity(test_chain.len());
|
||||
|
||||
for (path, test) in test_chain {
|
||||
match test {
|
||||
Test::IsCtor {
|
||||
tag_id,
|
||||
union,
|
||||
arguments,
|
||||
..
|
||||
} => {
|
||||
// the IsCtor check should never be generated for tag unions of size 1
|
||||
// (e.g. record pattern guard matches)
|
||||
debug_assert!(union.alternatives.len() > 1);
|
||||
|
||||
let lhs = Expr::Int(tag_id as i64);
|
||||
|
||||
let mut field_layouts =
|
||||
bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena);
|
||||
|
||||
// add the tag discriminant
|
||||
field_layouts.push(Layout::Builtin(Builtin::Int64));
|
||||
|
||||
for (_, layout) in arguments {
|
||||
field_layouts.push(layout);
|
||||
}
|
||||
|
||||
let rhs = Expr::AccessAtIndex {
|
||||
index: 0,
|
||||
field_layouts: field_layouts.into_bump_slice(),
|
||||
expr: env.arena.alloc(Expr::Load(cond_symbol)),
|
||||
is_unwrapped: union.alternatives.len() == 1,
|
||||
};
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
|
||||
}
|
||||
Test::IsInt(test_int) => {
|
||||
let lhs = Expr::Int(test_int);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
|
||||
}
|
||||
|
||||
Test::IsFloat(test_int) => {
|
||||
// TODO maybe we can actually use i64 comparison here?
|
||||
let test_float = f64::from_bits(test_int as u64);
|
||||
let lhs = Expr::Float(test_float);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64)));
|
||||
}
|
||||
|
||||
Test::IsByte {
|
||||
tag_id: test_byte, ..
|
||||
} => {
|
||||
let lhs = Expr::Byte(test_byte);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte)));
|
||||
}
|
||||
|
||||
Test::IsBit(test_bit) => {
|
||||
let lhs = Expr::Bool(test_bit);
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
|
||||
}
|
||||
|
||||
Test::IsStr(test_str) => {
|
||||
let lhs = Expr::Str(env.arena.alloc(test_str));
|
||||
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
|
||||
|
||||
tests.push((lhs, rhs, Layout::Builtin(Builtin::Str)));
|
||||
}
|
||||
}
|
||||
test_to_equality(env, cond_symbol, &cond_layout, &path, test, &mut tests);
|
||||
}
|
||||
|
||||
let pass = env.arena.alloc(decide_to_branching(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::layout::{Builtin, Layout};
|
||||
use crate::pattern::Ctor;
|
||||
use crate::pattern::{Ctor, Guard};
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use roc_can;
|
||||
|
@ -996,9 +996,18 @@ fn from_can_when<'a>(
|
|||
let mono_pattern = from_can_pattern(env, &loc_when_pattern.value);
|
||||
|
||||
// record pattern matches can have 1 branch and typecheck, but may still not be exhaustive
|
||||
let guard = if first.guard.is_some() {
|
||||
Guard::HasGuard
|
||||
} else {
|
||||
Guard::NoGuard
|
||||
};
|
||||
|
||||
match crate::pattern::check(
|
||||
Region::zero(),
|
||||
&[Located::at(loc_when_pattern.region, mono_pattern.clone())],
|
||||
&[(
|
||||
Located::at(loc_when_pattern.region, mono_pattern.clone()),
|
||||
guard,
|
||||
)],
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(errors) => panic!("Errors in patterns: {:?}", errors),
|
||||
|
@ -1031,14 +1040,23 @@ fn from_can_when<'a>(
|
|||
for when_branch in branches {
|
||||
let mono_expr = from_can(env, when_branch.value.value, procs, None);
|
||||
|
||||
let exhaustive_guard = if when_branch.guard.is_some() {
|
||||
Guard::HasGuard
|
||||
} else {
|
||||
Guard::NoGuard
|
||||
};
|
||||
|
||||
for loc_pattern in when_branch.patterns {
|
||||
let mono_pattern = from_can_pattern(env, &loc_pattern.value);
|
||||
|
||||
loc_branches.push(Located::at(loc_pattern.region, mono_pattern.clone()));
|
||||
loc_branches.push((
|
||||
Located::at(loc_pattern.region, mono_pattern.clone()),
|
||||
exhaustive_guard.clone(),
|
||||
));
|
||||
|
||||
let mut stores = Vec::with_capacity_in(1, env.arena);
|
||||
|
||||
let mono_expr = match store_pattern(
|
||||
let (mono_guard, expr_with_stores) = match store_pattern(
|
||||
env,
|
||||
&mono_pattern,
|
||||
cond_symbol,
|
||||
|
@ -1046,12 +1064,52 @@ fn from_can_when<'a>(
|
|||
&mut stores,
|
||||
) {
|
||||
Ok(_) => {
|
||||
Expr::Store(stores.into_bump_slice(), env.arena.alloc(mono_expr.clone()))
|
||||
// if the branch is guarded, the guard can use variables bound in the
|
||||
// pattern. They must be available, so we give the stores to the
|
||||
// decision_tree. A branch with guard can only be entered with the guard
|
||||
// evaluated, so variables will also be loaded in the branch's body expr.
|
||||
//
|
||||
// otherwise, we modify the branch's expression to include the stores
|
||||
if let Some(loc_guard) = when_branch.guard.clone() {
|
||||
let expr = from_can(env, loc_guard.value, procs, None);
|
||||
(
|
||||
crate::decision_tree::Guard::Guard {
|
||||
stores: stores.into_bump_slice(),
|
||||
expr,
|
||||
},
|
||||
mono_expr.clone(),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
crate::decision_tree::Guard::NoGuard,
|
||||
Expr::Store(
|
||||
stores.into_bump_slice(),
|
||||
env.arena.alloc(mono_expr.clone()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(message) => {
|
||||
// when the pattern is invalid, a guard must give a runtime error too
|
||||
if when_branch.guard.is_some() {
|
||||
(
|
||||
crate::decision_tree::Guard::Guard {
|
||||
stores: &[],
|
||||
expr: Expr::RuntimeError(env.arena.alloc(message)),
|
||||
},
|
||||
// we can never hit this
|
||||
Expr::RuntimeError(&"invalid pattern with guard: unreachable"),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
crate::decision_tree::Guard::NoGuard,
|
||||
Expr::RuntimeError(env.arena.alloc(message)),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(message) => Expr::RuntimeError(env.arena.alloc(message)),
|
||||
};
|
||||
|
||||
opt_branches.push((mono_pattern, mono_expr));
|
||||
opt_branches.push((mono_pattern, mono_guard, expr_with_stores));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -131,11 +131,17 @@ pub enum Context {
|
|||
BadCase,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Guard {
|
||||
HasGuard,
|
||||
NoGuard,
|
||||
}
|
||||
|
||||
/// Check
|
||||
|
||||
pub fn check<'a>(
|
||||
region: Region,
|
||||
patterns: &[Located<crate::expr::Pattern<'a>>],
|
||||
patterns: &[(Located<crate::expr::Pattern<'a>>, Guard)],
|
||||
) -> Result<(), Vec<Error>> {
|
||||
let mut errors = Vec::new();
|
||||
check_patterns(region, Context::BadArg, patterns, &mut errors);
|
||||
|
@ -150,7 +156,7 @@ pub fn check<'a>(
|
|||
pub fn check_patterns<'a>(
|
||||
region: Region,
|
||||
context: Context,
|
||||
patterns: &[Located<crate::expr::Pattern<'a>>],
|
||||
patterns: &[(Located<crate::expr::Pattern<'a>>, Guard)],
|
||||
errors: &mut Vec<Error>,
|
||||
) {
|
||||
match to_nonredundant_rows(region, patterns) {
|
||||
|
@ -283,14 +289,52 @@ fn recover_ctor(
|
|||
/// INVARIANT: Produces a list of rows where (forall row. length row == 1)
|
||||
fn to_nonredundant_rows<'a>(
|
||||
overall_region: Region,
|
||||
patterns: &[Located<crate::expr::Pattern<'a>>],
|
||||
patterns: &[(Located<crate::expr::Pattern<'a>>, Guard)],
|
||||
) -> Result<Vec<Vec<Pattern>>, Error> {
|
||||
let mut checked_rows = Vec::with_capacity(patterns.len());
|
||||
|
||||
for loc_pat in patterns {
|
||||
// If any of the branches has a guard, e.g.
|
||||
//
|
||||
// when x is
|
||||
// y if y < 10 -> "foo"
|
||||
// _ -> "bar"
|
||||
//
|
||||
// then we treat it as a pattern match on the pattern and a boolean, wrapped in the #Guard
|
||||
// constructor. We can use this special constructor name to generate better error messages.
|
||||
// This transformation of the pattern match only works because we only report exhaustiveness
|
||||
// errors: the Pattern created in this file is not used for code gen.
|
||||
//
|
||||
// when x is
|
||||
// #Guard y True -> "foo"
|
||||
// #Guard _ _ -> "bar"
|
||||
let any_has_guard = patterns.iter().any(|(_, guard)| guard == &Guard::HasGuard);
|
||||
|
||||
for (loc_pat, guard) in patterns {
|
||||
let region = loc_pat.region;
|
||||
|
||||
let next_row = vec![simplify(&loc_pat.value)];
|
||||
let next_row = if any_has_guard {
|
||||
let guard_pattern = match guard {
|
||||
Guard::HasGuard => Pattern::Literal(Literal::Bit(true)),
|
||||
Guard::NoGuard => Pattern::Anything,
|
||||
};
|
||||
|
||||
let union = Union {
|
||||
alternatives: vec![Ctor {
|
||||
name: TagName::Global("#Guard".into()),
|
||||
arity: 2,
|
||||
}],
|
||||
};
|
||||
|
||||
let tag_name = TagName::Global("#Guard".into());
|
||||
|
||||
vec![Pattern::Ctor(
|
||||
union,
|
||||
tag_name,
|
||||
vec![simplify(&loc_pat.value), guard_pattern],
|
||||
)]
|
||||
} else {
|
||||
vec![simplify(&loc_pat.value)]
|
||||
};
|
||||
|
||||
if is_useful(&checked_rows, &next_row) {
|
||||
checked_rows.push(next_row);
|
||||
|
|
|
@ -2423,7 +2423,7 @@ mod test_solve {
|
|||
when x is
|
||||
2 | 3 -> 0
|
||||
a if a < 20 -> 1
|
||||
3 | 4 if -> 2
|
||||
3 | 4 if False -> 2
|
||||
_ -> 3
|
||||
"#
|
||||
),
|
||||
|
|
|
@ -2144,7 +2144,8 @@ mod test_uniq_solve {
|
|||
|
||||
#[test]
|
||||
fn cheapest_open() {
|
||||
infer_eq(
|
||||
with_larger_debug_stack(|| {
|
||||
infer_eq(
|
||||
indoc!(
|
||||
r#"
|
||||
Model position : { evaluated : Set position
|
||||
|
@ -2180,7 +2181,8 @@ mod test_uniq_solve {
|
|||
"#
|
||||
),
|
||||
"Attr * (Attr * (Attr Shared position -> Attr Shared Float), Attr * (Model (Attr Shared position)) -> Attr * (Result (Attr Shared position) (Attr * [ KeyNotFound ]*)))"
|
||||
);
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -2389,11 +2391,11 @@ mod test_uniq_solve {
|
|||
when x is
|
||||
2 | 3 -> 0
|
||||
a if a < 20 -> 1
|
||||
3 | 4 if -> 2
|
||||
3 | 4 if False -> 2
|
||||
_ -> 3
|
||||
"#
|
||||
),
|
||||
"Attr * (Attr * (Num (Attr * *)) -> Attr * (Num (Attr * *)))",
|
||||
"Attr * (Attr Shared (Num (Attr * *)) -> Attr * (Num (Attr * *)))",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue