diff --git a/compiler/gen/tests/test_gen.rs b/compiler/gen/tests/test_gen.rs index aec9e2bc2e..ede203696b 100644 --- a/compiler/gen/tests/test_gen.rs +++ b/compiler/gen/tests/test_gen.rs @@ -77,7 +77,6 @@ mod test_gen { // Populate Procs and Subs, and get the low-level Expr from the canonical Expr let mono_expr = Expr::new(&arena, &mut subs, loc_expr.value, &mut procs, home, &mut ident_ids, POINTER_SIZE); - // Put this module's ident_ids back in the interns env.interns.all_ident_ids.insert(home, ident_ids); @@ -1345,6 +1344,44 @@ mod test_gen { 5, i64 ); + + assert_evals_to!( + indoc!( + r#" + when { x: 0x2, y: 3.14 } is + { x: var } -> var + 3 + "# + ), + 5, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + { x } = { x: 0x2, y: 3.14 } + + x + "# + ), + 2, + i64 + ); + } + + #[test] + fn record_guard_pattern() { + assert_evals_to!( + indoc!( + r#" + when { x: 0x2, y: 3.14 } is + { x: 0x4 } -> 5 + { x } -> x + 3 + "# + ), + 5, + i64 + ); } // #[test] diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 7e9cb052f8..664fd1fd3d 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -7,6 +7,7 @@ use roc_module::symbol::Symbol; use crate::layout::Builtin; use crate::layout::Layout; +use crate::pattern::{Ctor, Union}; /// COMPILE CASES @@ -245,11 +246,33 @@ fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>) -> Option { None => None, Some((_, pattern)) => match pattern { - Identifier(_) - | RecordDestructure(_, _) - | Underscore - | Shadowed(_, _) - | UnsupportedPattern(_) => None, + Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => None, + + RecordDestructure(destructs, _) => { + let union = Union { + alternatives: vec![Ctor { + name: TagName::Global("#Record".into()), + arity: destructs.len(), + }], + }; + + let mut arguments = std::vec::Vec::new(); + + for destruct in destructs { + if let Some(guard) = &destruct.guard { + arguments.push((guard.clone(), destruct.layout.clone())); + } else { + arguments.push((Pattern::Underscore, destruct.layout.clone())); + } + } + + Some(IsCtor { + tag_id: 0, + tag_name: TagName::Global("#Record".into()), + union, + arguments, + }) + } AppliedTag { tag_name, @@ -302,11 +325,42 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O found_pattern: pattern, end, } => match pattern { - RecordDestructure(_, _) - | Identifier(_) - | Underscore - | Shadowed(_, _) - | UnsupportedPattern(_) => Some(branch), + Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => Some(branch), + + RecordDestructure(destructs, _) => match test { + IsCtor { + tag_name: test_name, + .. + } => { + debug_assert!(test_name == &TagName::Global("#Record".into())); + + let sub_positions = + destructs.into_iter().enumerate().map(|(index, destruct)| { + let pattern = if let Some(guard) = destruct.guard { + guard.clone() + } else { + Pattern::Underscore + }; + + ( + Path::Index { + index: index as u64, + path: Box::new(path.clone()), + }, + pattern, + ) + }); + start.extend(sub_positions); + start.extend(end); + + Some(Branch { + goal: branch.goal, + patterns: start, + }) + } + _ => None, + }, + AppliedTag { union, tag_name, @@ -349,22 +403,6 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O } _ => None, } - /* - * - Can.PCtor _ _ (Can.Union _ _ numAlts _) name _ ctorArgs -> - case test of - IsCtor _ testName _ _ _ | name == testName -> - Just $ Branch goal $ - case map dearg ctorArgs of - [arg] | numAlts == 1 -> - start ++ [(Unbox path, arg)] ++ end - - args -> - start ++ subPositions path args ++ end - - _ -> - Nothing - */ } StrLiteral(string) => match test { IsStr(test_str) if string == *test_str => { @@ -682,24 +720,17 @@ fn path_to_expr<'a>( symbol: Symbol, path: &Path, is_unwrapped: bool, - field_layouts: Layout<'a>, + field_layouts: &'a [Layout<'a>], ) -> Expr<'a> { match path { Path::Unbox(ref path) => path_to_expr(env, symbol, path, true, field_layouts), - // TODO make this work with AccessAtIndex. - // that already works for structs, but not for basic types for some reason - // Expr::AccessAtIndex { - // index: 0, - // field_layouts: env.arena.alloc([field_layouts]), - // expr: env.arena.alloc(Expr::Load(symbol)), - // }, Path::Empty => Expr::Load(symbol), // TODO path contains a nested path. Traverse all the way Path::Index { index, .. } => Expr::AccessAtIndex { index: *index, - field_layouts: env.arena.alloc([Layout::Builtin(Builtin::Byte)]), + field_layouts, expr: env.arena.alloc(Expr::Load(symbol)), is_unwrapped, }, @@ -739,15 +770,17 @@ fn decide_to_branching<'a>( arguments, .. } => { + // the IsCtor check should never be generated for tag unions of size 1 + // (e.g. record pattern guard matches) + debug_assert!(union.alternatives.len() > 1); + let lhs = Expr::Int(tag_id as i64); let mut field_layouts = bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena); - if union.alternatives.len() > 1 { - // the tag discriminant - field_layouts.push(Layout::Builtin(Builtin::Int64)); - } + // add the tag discriminant + field_layouts.push(Layout::Builtin(Builtin::Int64)); for (_, layout) in arguments { field_layouts.push(layout); @@ -777,7 +810,7 @@ fn decide_to_branching<'a>( cond_symbol, &path, false, - Layout::Builtin(Builtin::Int64), + env.arena.alloc([Layout::Builtin(Builtin::Int64)]), ); let cond = env.arena.alloc(Expr::CallByName( @@ -800,7 +833,7 @@ fn decide_to_branching<'a>( cond_symbol, &path, false, - Layout::Builtin(Builtin::Float64), + env.arena.alloc([Layout::Builtin(Builtin::Float64)]), ); let cond = env.arena.alloc(Expr::CallByName( @@ -825,7 +858,7 @@ fn decide_to_branching<'a>( cond_symbol, &path, false, - Layout::Builtin(Builtin::Byte), + env.arena.alloc([Layout::Builtin(Builtin::Byte)]), ); let cond = env.arena.alloc(Expr::CallByName( @@ -842,8 +875,6 @@ fn decide_to_branching<'a>( } } - let cond = tests.remove(0); - let pass = env.arena.alloc(decide_to_branching( env, cond_symbol, @@ -862,6 +893,10 @@ fn decide_to_branching<'a>( jumps, )); + // TODO take the boolean and of all the tests + debug_assert!(tests.len() == 1); + let cond = tests.remove(0); + let cond_layout = Layout::Builtin(Builtin::Bool); Expr::Cond { @@ -882,7 +917,7 @@ fn decide_to_branching<'a>( cond_symbol, &path, false, - cond_layout.clone(), + env.arena.alloc([cond_layout.clone()]), )); let default_branch = env.arena.alloc(decide_to_branching( @@ -935,6 +970,15 @@ fn decide_to_branching<'a>( /// Decision trees may have some redundancies, so we convert them to a Decider /// which has special constructs to avoid code duplication when possible. +/// If a test always succeeds, we don't need to branch on it +/// this saves on work and jumps +fn test_always_succeeds(test: &Test) -> bool { + match test { + Test::IsCtor { union, .. } => union.alternatives.len() == 1, + _ => false, + } +} + fn tree_to_decider(tree: DecisionTree) -> Decider { use Decider::*; use DecisionTree::*; @@ -958,7 +1002,11 @@ fn tree_to_decider(tree: DecisionTree) -> Decider { let (_, failure_tree) = edges.remove(1); let (test, success_tree) = edges.remove(0); - to_chain(path, test, success_tree, failure_tree) + if test_always_succeeds(&test) { + tree_to_decider(success_tree) + } else { + to_chain(path, test, success_tree, failure_tree) + } } _ => { @@ -983,7 +1031,11 @@ fn tree_to_decider(tree: DecisionTree) -> Decider { let failure_tree = *last; let (test, success_tree) = edges.remove(0); - to_chain(path, test, success_tree, failure_tree) + if test_always_succeeds(&test) { + tree_to_decider(success_tree) + } else { + to_chain(path, test, success_tree, failure_tree) + } } _ => { diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index bf6990a44b..76b9942277 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -943,7 +943,7 @@ fn store_record_destruct<'a>( stored.push((*symbol, destruct.layout.clone(), load)); } Pattern::Underscore => { - // important that this is special-cased: mono record patterns will extract all the + // important that this is special-cased to do nothing: mono record patterns will extract all the // fields, but those not bound in the source code are guarded with the underscore // pattern. So given some record `{ x : a, y : b }`, a match // @@ -993,6 +993,15 @@ fn from_can_when<'a>( let mono_pattern = from_can_pattern(env, &loc_when_pattern.value); + // record pattern matches can have 1 branch and typecheck, but may still not be exhaustive + match crate::pattern::check( + Region::zero(), + &[Located::at(loc_when_pattern.region, mono_pattern.clone())], + ) { + Ok(_) => {} + Err(errors) => panic!("Errors in patterns: {:?}", errors), + } + let cond_layout = Layout::from_var(env.arena, cond_var, env.subs, env.pointer_size) .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); let cond_symbol = env.fresh_symbol(); diff --git a/compiler/mono/src/pattern.rs b/compiler/mono/src/pattern.rs index cdcb4175d2..664cf9aeab 100644 --- a/compiler/mono/src/pattern.rs +++ b/compiler/mono/src/pattern.rs @@ -72,9 +72,24 @@ fn simplify<'a>(pattern: &crate::expr::Pattern<'a>) -> Pattern { Underscore => Anything, Identifier(_) => Anything, - RecordDestructure { .. } => { - // TODO we must check the guard conditions! - Anything + RecordDestructure(destructures, _) => { + let union = Union { + alternatives: vec![Ctor { + name: TagName::Global("#Record".into()), + arity: destructures.len(), + }], + }; + + let mut patterns = std::vec::Vec::with_capacity(destructures.len()); + + for destruct in destructures { + match &destruct.guard { + None => patterns.push(Anything), + Some(guard) => patterns.push(simplify(guard)), + } + } + + Ctor(union, TagName::Global("#Record".into()), patterns) } Shadowed(_region, _ident) => {