diff --git a/compiler/gen/tests/test_gen.rs b/compiler/gen/tests/test_gen.rs index 0546180c82..697d51afec 100644 --- a/compiler/gen/tests/test_gen.rs +++ b/compiler/gen/tests/test_gen.rs @@ -1540,20 +1540,20 @@ mod test_gen { ); } - // #[test] - // fn if_guard_exhaustiveness() { - // assert_evals_to!( - // indoc!( - // r#" - // when 2 is - // _ if False -> 0 - // _ -> 42 - // "# - // ), - // 42, - // i64 - // ); - // } + #[test] + fn if_guard_exhaustiveness() { + assert_evals_to!( + indoc!( + r#" + when 2 is + _ if False -> 0 + _ -> 42 + "# + ), + 42, + i64 + ); + } // #[test] // fn linked_list_empty() { diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 739cd36f04..71336a3586 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -444,26 +444,25 @@ fn to_relevant_branch<'a>( start, found_pattern: (guard, pattern), end, - } => match test { - Test::Guarded(None, _guard_expr) => { - // theory: Some(branch) - todo!(); + } => { + let actual_test = match test { + Test::Guarded(Some(box_test), _guard_expr) => box_test, + _ => test, + }; + + if let Some(mut new_branch) = + to_relevant_branch_help(actual_test, path, start, end, branch, guard, pattern) + { + // guards can/should only occur at the top level. When we recurse on these + // branches, the guard is not relevant any more. Not setthing the guard to None + // leads to infinite recursion. + new_branch.patterns.iter_mut().for_each(|(_, guard, _)| { + *guard = None; + }); + + new_branches.push(new_branch); } - Test::Guarded(Some(box_test), _guard_expr) => { - if let Some(new_branch) = - to_relevant_branch_help(box_test, path, start, end, branch, guard, pattern) - { - new_branches.push(new_branch); - } - } - _ => { - if let Some(new_branch) = - to_relevant_branch_help(test, path, start, end, branch, guard, pattern) - { - new_branches.push(new_branch); - } - } - }, + } } } @@ -502,7 +501,7 @@ fn to_relevant_branch_help<'a>( tag_id: *tag_id, path: Box::new(path.clone()), }, - guard.clone(), + None, pattern, ) }); @@ -548,7 +547,7 @@ fn to_relevant_branch_help<'a>( tag_id: *tag_id, path: Box::new(path.clone()), }, - guard.clone(), + None, pattern, ) }); diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index 61c2de4b8d..58134757a9 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -1,5 +1,5 @@ use crate::layout::{Builtin, Layout}; -use crate::pattern::Ctor; +use crate::pattern::{Ctor, Guard}; use bumpalo::collections::Vec; use bumpalo::Bump; use roc_can; @@ -996,9 +996,18 @@ fn from_can_when<'a>( let mono_pattern = from_can_pattern(env, &loc_when_pattern.value); // record pattern matches can have 1 branch and typecheck, but may still not be exhaustive + let guard = if first.guard.is_some() { + Guard::HasGuard + } else { + Guard::NoGuard + }; + match crate::pattern::check( Region::zero(), - &[Located::at(loc_when_pattern.region, mono_pattern.clone())], + &[( + Located::at(loc_when_pattern.region, mono_pattern.clone()), + guard, + )], ) { Ok(_) => {} Err(errors) => panic!("Errors in patterns: {:?}", errors), @@ -1037,10 +1046,19 @@ fn from_can_when<'a>( None }; + let guard = if mono_guard.is_some() { + Guard::HasGuard + } else { + Guard::NoGuard + }; + for loc_pattern in when_branch.patterns { let mono_pattern = from_can_pattern(env, &loc_pattern.value); - loc_branches.push(Located::at(loc_pattern.region, mono_pattern.clone())); + loc_branches.push(( + Located::at(loc_pattern.region, mono_pattern.clone()), + guard.clone(), + )); let mut stores = Vec::with_capacity_in(1, env.arena); diff --git a/compiler/mono/src/pattern.rs b/compiler/mono/src/pattern.rs index 664cf9aeab..11ffe43358 100644 --- a/compiler/mono/src/pattern.rs +++ b/compiler/mono/src/pattern.rs @@ -131,11 +131,17 @@ pub enum Context { BadCase, } +#[derive(Clone, Debug, PartialEq)] +pub enum Guard { + HasGuard, + NoGuard, +} + /// Check pub fn check<'a>( region: Region, - patterns: &[Located>], + patterns: &[(Located>, Guard)], ) -> Result<(), Vec> { let mut errors = Vec::new(); check_patterns(region, Context::BadArg, patterns, &mut errors); @@ -150,7 +156,7 @@ pub fn check<'a>( pub fn check_patterns<'a>( region: Region, context: Context, - patterns: &[Located>], + patterns: &[(Located>, Guard)], errors: &mut Vec, ) { match to_nonredundant_rows(region, patterns) { @@ -283,14 +289,52 @@ fn recover_ctor( /// INVARIANT: Produces a list of rows where (forall row. length row == 1) fn to_nonredundant_rows<'a>( overall_region: Region, - patterns: &[Located>], + patterns: &[(Located>, Guard)], ) -> Result>, Error> { let mut checked_rows = Vec::with_capacity(patterns.len()); - for loc_pat in patterns { + // If any of the branches has a guard, e.g. + // + // when x is + // y if y < 10 -> "foo" + // _ -> "bar" + // + // then we treat it as a pattern match on the pattern and a boolean, wrapped in the #Guard + // constructor. We can use this special constructor name to generate better error messages. + // This transformation of the pattern match only works because we only report exhaustiveness + // errors: the Pattern created in this file is not used for code gen. + // + // when x is + // #Guard y True -> "foo" + // #Guard _ _ -> "bar" + let any_has_guard = patterns.iter().any(|(_, guard)| guard == &Guard::HasGuard); + + for (loc_pat, guard) in patterns { let region = loc_pat.region; - let next_row = vec![simplify(&loc_pat.value)]; + let next_row = if any_has_guard { + let guard_pattern = match guard { + Guard::HasGuard => Pattern::Literal(Literal::Bit(true)), + Guard::NoGuard => Pattern::Anything, + }; + + let union = Union { + alternatives: vec![Ctor { + name: TagName::Global("#Guard".into()), + arity: 2, + }], + }; + + let tag_name = TagName::Global("#Guard".into()); + + vec![Pattern::Ctor( + union, + tag_name, + vec![simplify(&loc_pat.value), guard_pattern], + )] + } else { + vec![simplify(&loc_pat.value)] + }; if is_useful(&checked_rows, &next_row) { checked_rows.push(next_row);