diff --git a/compiler/gen/tests/test_gen.rs b/compiler/gen/tests/test_gen.rs index 6c45eee900..d4671c63fd 100644 --- a/compiler/gen/tests/test_gen.rs +++ b/compiler/gen/tests/test_gen.rs @@ -1313,6 +1313,22 @@ mod test_gen { ); } + #[test] + fn match_on_two_values() { + // this will produce a Chain internally + assert_evals_to!( + indoc!( + r#" + when Pair 2 3 is + Pair 4 3 -> 9 + Pair a b -> a + b + "# + ), + 5, + i64 + ); + } + #[test] fn maybe_is_just() { assert_evals_to!( diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index 78236b70e1..7ccb53a59b 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -6,6 +6,7 @@ use roc_collections::all::{MutMap, MutSet}; use roc_module::ident::TagName; use roc_module::symbol::Symbol; +use crate::expr::specialize_equality; use crate::layout::Builtin; use crate::layout::Layout; use crate::pattern::{Ctor, Union}; @@ -61,7 +62,11 @@ pub enum Test<'a> { #[derive(Clone, Debug, PartialEq)] pub enum Path { - Index { index: u64, path: Box }, + Index { + index: u64, + tag_id: u8, + path: Box, + }, Unbox(Box), Empty, } @@ -141,24 +146,32 @@ fn flatten_patterns(branch: Branch) -> Branch { fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) { match &path_pattern.1 { - Pattern::AppliedTag { union, .. } => { + Pattern::AppliedTag { + union, + arguments, + tag_id, + .. + } => { if union.alternatives.len() == 1 { - // case map dearg ctorArgs of - // [arg] -> - // flatten (Unbox path, arg) otherPathPatterns - // - // args -> - // foldr flatten otherPathPatterns (subPositions path args) - // subPositions :: Path -> [Can.Pattern] -> [(Path, Can.Pattern)] - // subPositions path patterns = - // Index.indexedMap (\index pattern -> (Index index path, pattern)) patterns - // - // - // dearg :: Can.PatternCtorArg -> Can.Pattern - // dearg (Can.PatternCtorArg _ _ pattern) = - // pattern - - todo!("alternatives: {:?}", union.alternatives) + let path = path_pattern.0; + // Theory: unbox doesn't have any value for us, because one-element tag unions + // don't store the tag anyway. + // if arguments.len() == 1 { + // path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone())); + // } else { + for (index, (arg_pattern, _)) in arguments.iter().enumerate() { + flatten( + ( + Path::Index { + index: index as u64, + tag_id: *tag_id, + path: Box::new(path.clone()), + }, + arg_pattern.clone(), + ), + path_patterns, + ); + } } else { path_patterns.push(path_pattern); } @@ -331,6 +344,7 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O RecordDestructure(destructs, _) => match test { IsCtor { tag_name: test_name, + tag_id, .. } => { debug_assert!(test_name == &TagName::Global("#Record".into())); @@ -346,6 +360,7 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O ( Path::Index { index: index as u64, + tag_id: *tag_id, path: Box::new(path.clone()), }, pattern, @@ -363,39 +378,40 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O }, AppliedTag { - union, tag_name, arguments, .. } => { - let mut arguments: Vec<_> = arguments.into_iter().map(|v| v.0).collect(); match test { IsCtor { tag_name: test_name, + tag_id, .. } if &tag_name == test_name => { - // TODO can't we unbox whenever there is just one alternative, even if - // there are multiple arguments? - if arguments.len() == 1 && union.alternatives.len() == 1 { - let arg = arguments.remove(0); - { - start.push((Path::Unbox(Box::new(path.clone())), arg)); - start.extend(end); - } - } else { - let sub_positions = - arguments.into_iter().enumerate().map(|(index, pattern)| { + // Theory: Unbox doesn't have any value for us + // if arguments.len() == 1 && union.alternatives.len() == 1 { + // let arg = arguments.remove(0); + // { + // start.push((Path::Unbox(Box::new(path.clone())), arg)); + // start.extend(end); + // } + // } else { + let sub_positions = + arguments + .into_iter() + .enumerate() + .map(|(index, (pattern, _))| { ( Path::Index { index: index as u64, + tag_id: *tag_id, path: Box::new(path.clone()), }, pattern, ) }); - start.extend(sub_positions); - start.extend(end); - } + start.extend(sub_positions); + start.extend(end); Some(Branch { goal: branch.goal, @@ -720,21 +736,53 @@ fn path_to_expr<'a>( env: &mut Env<'a, '_>, symbol: Symbol, path: &Path, - is_unwrapped: bool, - field_layouts: &'a [Layout<'a>], + layout: &Layout<'a>, ) -> Expr<'a> { + path_to_expr_help(env, symbol, path, layout.clone()).0 +} + +fn path_to_expr_help<'a>( + env: &mut Env<'a, '_>, + symbol: Symbol, + path: &Path, + layout: Layout<'a>, +) -> (Expr<'a>, Layout<'a>) { match path { - Path::Unbox(ref path) => path_to_expr(env, symbol, path, true, field_layouts), - - Path::Empty => Expr::Load(symbol), - - // TODO path contains a nested path. Traverse all the way - Path::Index { index, .. } => Expr::AccessAtIndex { - index: *index, - field_layouts, - expr: env.arena.alloc(Expr::Load(symbol)), - is_unwrapped, + Path::Unbox(ref unboxed) => match **unboxed { + _ => todo!(), }, + + Path::Empty => (Expr::Load(symbol), layout), + + Path::Index { + index, + tag_id, + path: nested, + } => { + let (outer_expr, outer_layout) = path_to_expr_help(env, symbol, nested, layout); + + let (is_unwrapped, field_layouts) = match outer_layout { + Layout::Union(layouts) => (layouts.is_empty(), layouts[*tag_id as usize].to_vec()), + Layout::Struct(layouts) => ( + true, + layouts.iter().map(|v| v.1.clone()).collect::>(), + ), + other => (true, vec![other]), + }; + + debug_assert!(*index < field_layouts.len() as u64); + + let inner_layout = field_layouts[*index as usize].clone(); + + let inner_expr = Expr::AccessAtIndex { + index: *index, + field_layouts: env.arena.alloc(field_layouts), + expr: env.arena.alloc(outer_expr), + is_unwrapped, + }; + + (inner_expr, inner_layout) + } } } @@ -794,85 +842,46 @@ fn decide_to_branching<'a>( is_unwrapped: union.alternatives.len() == 1, }; - let cond = Expr::CallByName( - Symbol::INT_EQ_I64, - env.arena.alloc([ - (lhs, Layout::Builtin(Builtin::Int64)), - (rhs, Layout::Builtin(Builtin::Int64)), - ]), - ); - - tests.push(cond); + tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64))); } Test::IsInt(test_int) => { let lhs = Expr::Int(test_int); - let rhs = path_to_expr( - env, - cond_symbol, - &path, - false, - env.arena.alloc([Layout::Builtin(Builtin::Int64)]), - ); + let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); - let cond = Expr::CallByName( - Symbol::INT_EQ_I64, - env.arena.alloc([ - (lhs, Layout::Builtin(Builtin::Int64)), - (rhs, Layout::Builtin(Builtin::Int64)), - ]), - ); - - tests.push(cond); + tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64))); } Test::IsFloat(test_int) => { // TODO maybe we can actually use i64 comparison here? let test_float = f64::from_bits(test_int as u64); let lhs = Expr::Float(test_float); - let rhs = path_to_expr( - env, - cond_symbol, - &path, - false, - env.arena.alloc([Layout::Builtin(Builtin::Float64)]), - ); + let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); - let cond = Expr::CallByName( - Symbol::FLOAT_EQ, - env.arena.alloc([ - (lhs, Layout::Builtin(Builtin::Float64)), - (rhs, Layout::Builtin(Builtin::Float64)), - ]), - ); - - tests.push(cond); + tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64))); } Test::IsByte { - tag_id: test_byte, - // num_alts: _, - .. + tag_id: test_byte, .. } => { let lhs = Expr::Byte(test_byte); - let rhs = path_to_expr( - env, - cond_symbol, - &path, - false, - env.arena.alloc([Layout::Builtin(Builtin::Byte)]), - ); + let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); - let cond = Expr::CallByName( - Symbol::INT_EQ_I8, - env.arena.alloc([ - (lhs, Layout::Builtin(Builtin::Byte)), - (rhs, Layout::Builtin(Builtin::Byte)), - ]), - ); - - tests.push(cond); + tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte))); + } + + Test::IsBit(test_bit) => { + let lhs = Expr::Bool(test_bit); + let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + + tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool))); + } + + Test::IsStr(test_str) => { + let lhs = Expr::Str(env.arena.alloc(test_str)); + let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout); + + tests.push((lhs, rhs, Layout::Builtin(Builtin::Str))); } - _ => todo!(), } } @@ -911,13 +920,9 @@ fn decide_to_branching<'a>( tests, fallback, } => { - let cond = env.arena.alloc(path_to_expr( - env, - cond_symbol, - &path, - false, - env.arena.alloc([cond_layout.clone()]), - )); + let cond = env + .arena + .alloc(path_to_expr(env, cond_symbol, &path, &cond_layout)); let default_branch = env.arena.alloc(decide_to_branching( env, @@ -964,10 +969,11 @@ fn decide_to_branching<'a>( } } -fn boolean_all<'a>(arena: &'a Bump, tests: Vec>) -> Expr<'a> { +fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>) -> Expr<'a> { let mut expr = Expr::Bool(true); - for test in tests.into_iter().rev() { + for (lhs, rhs, layout) in tests.into_iter().rev() { + let test = specialize_equality(arena, lhs, rhs, layout); expr = Expr::CallByName( Symbol::BOOL_AND, arena.alloc([ diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index 76b9942277..5d05dab0f3 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -893,7 +893,7 @@ fn store_pattern<'a>( _ => { // store the field in a symbol, and continue matching on it let symbol = env.fresh_symbol(); - stored.push((symbol, layout.clone(), load)); + stored.push((symbol, arg_layout.clone(), load)); store_pattern(env, argument, symbol, arg_layout.clone(), stored)?; } @@ -1402,3 +1402,26 @@ fn from_can_record_destruct<'a>( }, } } + +pub fn specialize_equality<'a>( + arena: &'a Bump, + lhs: Expr<'a>, + rhs: Expr<'a>, + layout: Layout<'a>, +) -> Expr<'a> { + let symbol = match &layout { + Layout::Builtin(builtin) => match builtin { + Builtin::Int64 => Symbol::INT_EQ_I64, + Builtin::Float64 => Symbol::FLOAT_EQ, + Builtin::Byte => Symbol::INT_EQ_I8, + Builtin::Bool => Symbol::INT_EQ_I1, + other => todo!("Cannot yet compare for equality {:?}", other), + }, + other => todo!("Cannot yet compare for equality {:?}", other), + }; + + Expr::CallByName( + symbol, + arena.alloc([(lhs, layout.clone()), (rhs, layout.clone())]), + ) +}