diff --git a/crates/compiler/can/src/exhaustive.rs b/crates/compiler/can/src/exhaustive.rs index 534e7936c5..476fef14d1 100644 --- a/crates/compiler/can/src/exhaustive.rs +++ b/crates/compiler/can/src/exhaustive.rs @@ -1,5 +1,5 @@ use crate::expr::{self, IntValue, WhenBranch}; -use crate::pattern::{DestructType, ListPatterns}; +use crate::pattern::DestructType; use roc_collections::all::HumanIndex; use roc_collections::VecMap; use roc_error_macros::internal_error; @@ -344,23 +344,19 @@ fn sketch_pattern(pattern: &crate::pattern::Pattern) -> SketchedPattern { } List { - patterns: ListPatterns { patterns, opt_rest }, + patterns, list_var: _, elem_var: _, } => { - let list_arity = match opt_rest { - Some(i) => { - let before = *i; - let after = patterns.len() - before; - ListArity::Slice(before, after) - } - None => ListArity::Exact(patterns.len()), - }; + let arity = patterns.arity(); - let sketched_elem_patterns = - patterns.iter().map(|p| sketch_pattern(&p.value)).collect(); + let sketched_elem_patterns = patterns + .patterns + .iter() + .map(|p| sketch_pattern(&p.value)) + .collect(); - SP::List(list_arity, sketched_elem_patterns) + SP::List(arity, sketched_elem_patterns) } AppliedTag { diff --git a/crates/compiler/can/src/pattern.rs b/crates/compiler/can/src/pattern.rs index 6a4917a065..e2e96ee52d 100644 --- a/crates/compiler/can/src/pattern.rs +++ b/crates/compiler/can/src/pattern.rs @@ -6,6 +6,7 @@ use crate::num::{ ParsedNumResult, }; use crate::scope::{PendingAbilitiesInScope, Scope}; +use roc_exhaustive::ListArity; use roc_module::ident::{Ident, Lowercase, TagName}; use roc_module::symbol::Symbol; use roc_parse::ast::{self, StrLiteral, StrSegment}; @@ -189,6 +190,17 @@ impl ListPatterns { fn surely_exhaustive(&self) -> bool { self.patterns.is_empty() && matches!(self.opt_rest, Some(0)) } + + pub fn arity(&self) -> ListArity { + match self.opt_rest { + Some(i) => { + let before = i; + let after = self.patterns.len() - before; + ListArity::Slice(before, after) + } + None => ListArity::Exact(self.patterns.len()), + } + } } #[derive(Clone, Debug)] diff --git a/crates/compiler/exhaustive/src/lib.rs b/crates/compiler/exhaustive/src/lib.rs index ae397abdd1..c26e596080 100644 --- a/crates/compiler/exhaustive/src/lib.rs +++ b/crates/compiler/exhaustive/src/lib.rs @@ -93,7 +93,7 @@ impl ListArity { /// The trivially-exhaustive list pattern `[..]` const ANY: ListArity = ListArity::Slice(0, 0); - fn min_len(&self) -> usize { + pub fn min_len(&self) -> usize { match self { ListArity::Exact(n) => *n, ListArity::Slice(l, r) => l + r, @@ -102,22 +102,19 @@ impl ListArity { /// Could this list pattern include list pattern arity `other`? fn covers_arities_of(&self, other: &Self) -> bool { - match (self, other) { - (ListArity::Exact(l), ListArity::Exact(r)) => l == r, - (ListArity::Exact(this_exact), ListArity::Slice(other_left, other_right)) => { - // [_, _, _] can only cover [_, _, .., _] - *this_exact == (other_left + other_right) + self.covers_length(other.min_len()) + } + + pub fn covers_length(&self, length: usize) -> bool { + match self { + ListArity::Exact(l) => { + // [_, _, _] can only cover [_, _, _] + *l == length } - (ListArity::Slice(this_left, this_right), ListArity::Exact(other_exact)) => { - // [_, _, .., _] can cover [_, _, _], [_, _, _, _], [_, _, _, _, _], and so on - (this_left + this_right) <= *other_exact - } - ( - ListArity::Slice(this_left, this_right), - ListArity::Slice(other_left, other_right), - ) => { - // [_, _, .., _] can cover [_, _, .., _], [_, .., _, _], [_, _, .., _, _], [_, _, _, .., _, _], and so on - (this_left + this_right) <= (other_left + other_right) + ListArity::Slice(head, tail) => { + // [_, _, .., _] can cover infinite arities >=3 , including + // [_, _, .., _], [_, .., _, _], [_, _, .., _, _], [_, _, _, .., _, _], and so on + head + tail <= length } } } diff --git a/crates/compiler/mono/src/decision_tree.rs b/crates/compiler/mono/src/decision_tree.rs index 17a0746c11..5c90e396df 100644 --- a/crates/compiler/mono/src/decision_tree.rs +++ b/crates/compiler/mono/src/decision_tree.rs @@ -1,11 +1,12 @@ use crate::ir::{ - BranchInfo, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt, + build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, + ListIndex, Literal, Param, Pattern, Procs, Stmt, }; use crate::layout::{Builtin, Layout, LayoutCache, TagIdIntType, UnionLayout}; use roc_builtins::bitcode::{FloatWidth, IntWidth}; use roc_collections::all::{MutMap, MutSet}; use roc_error_macros::internal_error; -use roc_exhaustive::{Ctor, CtorName, RenderAs, TagId, Union}; +use roc_exhaustive::{Ctor, CtorName, ListArity, RenderAs, TagId, Union}; use roc_module::ident::TagName; use roc_module::low_level::LowLevel; use roc_module::symbol::Symbol; @@ -77,6 +78,12 @@ enum GuardedTest<'a> { Placeholder, } +#[derive(Clone, Copy, Debug, PartialEq, Hash)] +enum ListLenBound { + Exact, + AtLeast, +} + #[derive(Clone, Debug, PartialEq)] #[allow(clippy::enum_variant_names)] enum Test<'a> { @@ -95,6 +102,10 @@ enum Test<'a> { tag_id: TagIdIntType, num_alts: usize, }, + IsListLen { + bound: ListLenBound, + len: u64, + }, } impl<'a> Test<'a> { @@ -110,6 +121,10 @@ impl<'a> Test<'a> { Test::IsStr(_) => false, Test::IsBit(_) => true, Test::IsByte { .. } => true, + Test::IsListLen { bound, .. } => match bound { + ListLenBound::Exact => true, + ListLenBound::AtLeast => false, + }, } } } @@ -153,6 +168,10 @@ impl<'a> Hash for Test<'a> { state.write_u8(6); v.hash(state); } + IsListLen { len, bound } => { + state.write_u8(7); + (len, bound).hash(state); + } } } } @@ -331,6 +350,11 @@ fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool { Test::IsFloat(_, _) => false, Test::IsDecimal(_) => false, Test::IsStr(_) => false, + Test::IsListLen { + bound: ListLenBound::AtLeast, + len: 0, + } => true, // [..] test + Test::IsListLen { .. } => false, } } @@ -578,6 +602,18 @@ fn test_at_path<'a>( arguments: arguments.to_vec(), }, + List { + arity, + element_layout: _, + elements: _, + } => IsListLen { + bound: match arity { + ListArity::Exact(_) => ListLenBound::Exact, + ListArity::Slice(_, _) => ListLenBound::AtLeast, + }, + len: arity.min_len() as _, + }, + Voided { .. } => internal_error!("unreachable"), OpaqueUnwrap { opaque, argument } => { @@ -755,6 +791,38 @@ fn to_relevant_branch_help<'a>( _ => None, }, + List { + arity: my_arity, + elements, + element_layout: _, + } => match test { + IsListLen { bound: _, len } if my_arity.covers_length(*len as _) => { + let sub_positions = elements.into_iter().enumerate().map(|(index, elem_pat)| { + let mut new_path = path.to_vec(); + + let probe_index = ListIndex::from_pattern_index(index, my_arity); + + let next_instr = PathInstruction::ListIndex { + // TODO index into back as well + index: probe_index as _, + }; + new_path.push(next_instr); + + (new_path, elem_pat) + }); + start.extend(sub_positions); + start.extend(end); + + Some(Branch { + goal: branch.goal, + guard: branch.guard.clone(), + patterns: start, + }) + } + + _ => None, + }, + NewtypeDestructure { tag_name, arguments, @@ -1021,7 +1089,8 @@ fn needs_tests(pattern: &Pattern) -> bool { | IntLiteral(_, _) | FloatLiteral(_, _) | DecimalLiteral(_) - | StrLiteral(_) => true, + | StrLiteral(_) + | List { .. } => true, Voided { .. } => internal_error!("unreachable"), } @@ -1268,6 +1337,7 @@ pub fn optimize_when<'a>( enum PathInstruction { NewType, TagIndex { index: u64, tag_id: TagIdIntType }, + ListIndex { index: ListIndex }, } fn path_to_expr_help<'a>( @@ -1337,19 +1407,46 @@ fn path_to_expr_help<'a>( } } } + + PathInstruction::ListIndex { index } => { + let list_sym = symbol; + + match layout { + Layout::Builtin(Builtin::List(elem_layout)) => { + let (index_sym, new_stores) = build_list_index_probe(env, list_sym, index); + + stores.extend(new_stores); + + let load_sym = env.unique_symbol(); + let load_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::ListGetUnsafe, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([list_sym, index_sym]), + }); + + stores.push((load_sym, *elem_layout, load_expr)); + + layout = *elem_layout; + symbol = load_sym; + } + _ => internal_error!("not a list"), + } + } } } (symbol, stores, layout) } -fn test_to_equality<'a>( +fn test_to_comparison<'a>( env: &mut Env<'a, '_>, cond_symbol: Symbol, cond_layout: &Layout<'a>, path: &[PathInstruction], test: Test<'a>, -) -> (StoresVec<'a>, Symbol, Symbol, Option>) { +) -> (StoresVec<'a>, Comparison, Option>) { let (rhs_symbol, mut stores, test_layout) = path_to_expr_help(env, cond_symbol, path, *cond_layout); @@ -1379,8 +1476,7 @@ fn test_to_equality<'a>( ( stores, - lhs_symbol, - rhs_symbol, + (lhs_symbol, Comparator::Eq, rhs_symbol), Some(ConstructorKnown::OnlyPass { scrutinee: path_symbol, layout: *cond_layout, @@ -1397,7 +1493,7 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); stores.push((lhs_symbol, Layout::int_width(precision), lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) } Test::IsFloat(test_int, precision) => { @@ -1407,7 +1503,7 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); stores.push((lhs_symbol, Layout::float_width(precision), lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) } Test::IsDecimal(test_dec) => { @@ -1415,7 +1511,7 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); stores.push((lhs_symbol, *cond_layout, lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) } Test::IsByte { @@ -1427,7 +1523,7 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); stores.push((lhs_symbol, Layout::u8(), lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) } Test::IsBit(test_bit) => { @@ -1435,7 +1531,7 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); stores.push((lhs_symbol, Layout::Builtin(Builtin::Bool), lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) } Test::IsStr(test_str) => { @@ -1444,15 +1540,58 @@ fn test_to_equality<'a>( stores.push((lhs_symbol, Layout::Builtin(Builtin::Str), lhs)); - (stores, lhs_symbol, rhs_symbol, None) + (stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None) + } + + Test::IsListLen { bound, len } => { + let list_layout = test_layout; + let list_sym = rhs_symbol; + + match list_layout { + Layout::Builtin(Builtin::List(_elem_layout)) => { + let real_len_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::ListLen, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([list_sym]), + }); + let test_len_expr = Expr::Literal(Literal::Int((len as i128).to_ne_bytes())); + + let real_len = env.unique_symbol(); + let test_len = env.unique_symbol(); + + let usize_layout = Layout::usize(env.target_info); + + stores.push((real_len, usize_layout, real_len_expr)); + stores.push((test_len, usize_layout, test_len_expr)); + + let comparison = match bound { + ListLenBound::Exact => (real_len, Comparator::Eq, test_len), + ListLenBound::AtLeast => (real_len, Comparator::Geq, test_len), + }; + + (stores, comparison, None) + } + _ => internal_error!( + "test path is not a list: {:#?}", + (cond_layout, test_layout, path) + ), + } } } } +enum Comparator { + Eq, + Geq, +} + +type Comparison = (Symbol, Comparator, Symbol); + type Tests<'a> = std::vec::Vec<( bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>, - Symbol, - Symbol, + Comparison, Option>, )>; @@ -1466,17 +1605,25 @@ fn stores_and_condition<'a>( // Assumption: there is at most 1 guard, and it is the outer layer. for (path, test) in test_chain { - tests.push(test_to_equality(env, cond_symbol, cond_layout, &path, test)) + tests.push(test_to_comparison( + env, + cond_symbol, + cond_layout, + &path, + test, + )) } tests } +#[allow(clippy::too_many_arguments)] fn compile_test<'a>( env: &mut Env<'a, '_>, ret_layout: Layout<'a>, stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>, lhs: Symbol, + cmp: Comparator, rhs: Symbol, fail: &'a Stmt<'a>, cond: Stmt<'a>, @@ -1487,6 +1634,7 @@ fn compile_test<'a>( ret_layout, stores, lhs, + cmp, rhs, fail, cond, @@ -1500,6 +1648,7 @@ fn compile_test_help<'a>( ret_layout: Layout<'a>, stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>, lhs: Symbol, + cmp: Comparator, rhs: Symbol, fail: &'a Stmt<'a>, mut cond: Stmt<'a>, @@ -1560,7 +1709,10 @@ fn compile_test_help<'a>( default_branch, }; - let op = LowLevel::Eq; + let op = match cmp { + Comparator::Eq => LowLevel::Eq, + Comparator::Geq => LowLevel::NumGte, + }; let test = Expr::Call(crate::ir::Call { call_type: crate::ir::CallType::LowLevel { op, @@ -1592,13 +1744,15 @@ fn compile_tests<'a>( fail: &'a Stmt<'a>, mut cond: Stmt<'a>, ) -> Stmt<'a> { - for (new_stores, lhs, rhs, opt_constructor_info) in tests.into_iter() { + for (new_stores, (lhs, cmp, rhs), opt_constructor_info) in tests.into_iter() { match opt_constructor_info { None => { - cond = compile_test(env, ret_layout, new_stores, lhs, rhs, fail, cond); + cond = compile_test(env, ret_layout, new_stores, lhs, cmp, rhs, fail, cond); } Some(cinfo) => { - cond = compile_test_help(env, cinfo, ret_layout, new_stores, lhs, rhs, fail, cond); + cond = compile_test_help( + env, cinfo, ret_layout, new_stores, lhs, cmp, rhs, fail, cond, + ); } } } @@ -1781,7 +1935,7 @@ fn decide_to_branching<'a>( if number_of_tests == 1 { // if there is just one test, compile to a simple if-then-else - let (new_stores, lhs, rhs, _cinfo) = tests.into_iter().next().unwrap(); + let (new_stores, (lhs, cmp, rhs), _cinfo) = tests.into_iter().next().unwrap(); compile_test_help( env, @@ -1789,6 +1943,7 @@ fn decide_to_branching<'a>( ret_layout, new_stores, lhs, + cmp, rhs, fail, pass_expr, @@ -1854,6 +2009,12 @@ fn decide_to_branching<'a>( Test::IsBit(v) => v as u64, Test::IsByte { tag_id, .. } => tag_id as u64, Test::IsCtor { tag_id, .. } => tag_id as u64, + Test::IsListLen { len, bound } => match bound { + ListLenBound::Exact => len as _, + ListLenBound::AtLeast => { + unreachable!("at-least bounds cannot be switched on") + } + }, Test::IsDecimal(_) => unreachable!("decimals cannot be switched on"), Test::IsStr(_) => unreachable!("strings cannot be switched on"), }; @@ -1911,6 +2072,31 @@ fn decide_to_branching<'a>( union_layout.tag_id_layout(), env.arena.alloc(temp), ) + } else if let Layout::Builtin(Builtin::List(_)) = inner_cond_layout { + let len_symbol = env.unique_symbol(); + + let switch = Stmt::Switch { + cond_layout: Layout::usize(env.target_info), + cond_symbol: len_symbol, + branches: branches.into_bump_slice(), + default_branch: (default_branch_info, env.arena.alloc(default_branch)), + ret_layout, + }; + + let len_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::ListLen, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([inner_cond_symbol]), + }); + + Stmt::Let( + len_symbol, + len_expr, + Layout::usize(env.target_info), + env.arena.alloc(switch), + ) } else { Stmt::Switch { cond_layout: inner_cond_layout, diff --git a/crates/compiler/mono/src/ir.rs b/crates/compiler/mono/src/ir.rs index 586722d435..01f4984e32 100644 --- a/crates/compiler/mono/src/ir.rs +++ b/crates/compiler/mono/src/ir.rs @@ -21,7 +21,7 @@ use roc_debug_flags::{ }; use roc_derive::SharedDerivedModule; use roc_error_macros::{internal_error, todo_abilities}; -use roc_exhaustive::{Ctor, CtorName, RenderAs, TagId}; +use roc_exhaustive::{Ctor, CtorName, ListArity, RenderAs, TagId}; use roc_intern::Interner; use roc_late_solve::storage::{ExternalModuleStorage, ExternalModuleStorageSnapshot}; use roc_late_solve::{resolve_ability_specialization, AbilitiesView, Resolved, UnificationFailed}; @@ -7289,6 +7289,24 @@ fn store_pattern_help<'a>( stmt, ); } + + List { + arity, + element_layout, + elements, + } => { + return store_list_pattern( + env, + procs, + layout_cache, + outer_symbol, + *arity, + *element_layout, + elements, + stmt, + ) + } + Voided { .. } => { return StorePattern::NotProductive(stmt); } @@ -7361,6 +7379,199 @@ fn store_pattern_help<'a>( StorePattern::Productive(stmt) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct ListIndex( + /// Positive if we should index from the head, negative if we should index from the tail + /// 0 is lst[0] + /// -1 is lst[List.len lst - 1] + i64, +); + +impl ListIndex { + pub fn from_pattern_index(index: usize, arity: ListArity) -> Self { + match arity { + ListArity::Exact(_) => ListIndex::nth_head(index as _), + ListArity::Slice(head, tail) => { + if index < head { + ListIndex::nth_head(index as _) + } else { + // Slice(2, 6) + // + // s t ... w y z x q + // 0 1 2 3 4 5 6 index + // 0 1 2 3 4 (index - head) + // 4 3 2 1 0 tail - (index - head) + ListIndex::nth_tail((tail - (index - head)) as _) + } + } + } + } + + fn nth_head(offset: u64) -> Self { + Self(offset as _) + } + + fn nth_tail(offset: u64) -> Self { + let offset = offset as i64; + Self(-1 - offset) + } +} + +pub(crate) type Store<'a> = (Symbol, Layout<'a>, Expr<'a>); + +/// Builds the list index we should index into +#[must_use] +pub(crate) fn build_list_index_probe<'a>( + env: &mut Env<'a, '_>, + list_sym: Symbol, + list_index: &ListIndex, +) -> (Symbol, impl DoubleEndedIterator>) { + let usize_layout = Layout::usize(env.target_info); + + let list_index = list_index.0; + let index_sym = env.unique_symbol(); + + let (opt_len_store, opt_offset_store, index_store) = if list_index >= 0 { + let index_expr = Expr::Literal(Literal::Int((list_index as i128).to_ne_bytes())); + + let index_store = (index_sym, usize_layout, index_expr); + + (None, None, index_store) + } else { + let len_sym = env.unique_symbol(); + let len_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::ListLen, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([list_sym]), + }); + + let offset = (list_index + 1).abs(); + let offset_sym = env.unique_symbol(); + let offset_expr = Expr::Literal(Literal::Int((offset as i128).to_ne_bytes())); + + let index_expr = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::NumSub, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([len_sym, offset_sym]), + }); + + let len_store = (len_sym, usize_layout, len_expr); + let offset_store = (offset_sym, usize_layout, offset_expr); + let index_store = (index_sym, usize_layout, index_expr); + + (Some(len_store), Some(offset_store), index_store) + }; + + let stores = (opt_len_store.into_iter()) + .chain(opt_offset_store) + .chain([index_store]); + + (index_sym, stores) +} + +#[allow(clippy::too_many_arguments)] +fn store_list_pattern<'a>( + env: &mut Env<'a, '_>, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + list_sym: Symbol, + list_arity: ListArity, + element_layout: Layout<'a>, + elements: &[Pattern<'a>], + mut stmt: Stmt<'a>, +) -> StorePattern<'a> { + use Pattern::*; + + let mut is_productive = false; + + for (index, element) in elements.iter().enumerate().rev() { + let compute_element_load = |env: &mut Env<'a, '_>| { + let list_index = ListIndex::from_pattern_index(index, list_arity); + + let (index_sym, needed_stores) = build_list_index_probe(env, list_sym, &list_index); + + let load = Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::ListGetUnsafe, + update_mode: env.next_update_mode_id(), + }, + arguments: env.arena.alloc([list_sym, index_sym]), + }); + + (load, needed_stores) + }; + + let (store_loaded, needed_stores) = match element { + Identifier(symbol) => { + let (load, needed_stores) = compute_element_load(env); + + // Pattern can define only one specialization + let symbol = procs + .symbol_specializations + .remove_single(*symbol) + .unwrap_or(*symbol); + + // store immediately in the given symbol + ( + Stmt::Let(symbol, load, element_layout, env.arena.alloc(stmt)), + needed_stores, + ) + } + Underscore + | IntLiteral(_, _) + | FloatLiteral(_, _) + | DecimalLiteral(_) + | EnumLiteral { .. } + | BitLiteral { .. } + | StrLiteral(_) => { + // ignore + continue; + } + _ => { + // store the field in a symbol, and continue matching on it + let symbol = env.unique_symbol(); + + // first recurse, continuing to unpack symbol + match store_pattern_help(env, procs, layout_cache, element, symbol, stmt) { + StorePattern::Productive(new) => { + stmt = new; + let (load, needed_stores) = compute_element_load(env); + + // only if we bind one of its (sub)fields to a used name should we + // extract the field + ( + Stmt::Let(symbol, load, element_layout, env.arena.alloc(stmt)), + needed_stores, + ) + } + StorePattern::NotProductive(new) => { + // do nothing + stmt = new; + continue; + } + } + } + }; + + is_productive = true; + + stmt = store_loaded; + for (sym, lay, expr) in needed_stores.rev() { + stmt = Stmt::Let(sym, expr, lay, env.arena.alloc(stmt)); + } + } + + if is_productive { + StorePattern::Productive(stmt) + } else { + StorePattern::NotProductive(stmt) + } +} + #[allow(clippy::too_many_arguments)] fn store_tag_pattern<'a>( env: &mut Env<'a, '_>, @@ -8922,6 +9133,11 @@ pub enum Pattern<'a> { opaque: Symbol, argument: Box<(Pattern<'a>, Layout<'a>)>, }, + List { + arity: ListArity, + element_layout: Layout<'a>, + elements: Vec<'a, Pattern<'a>>, + }, } impl<'a> Pattern<'a> { @@ -8958,6 +9174,7 @@ impl<'a> Pattern<'a> { stack.extend(arguments.iter().map(|(t, _)| t)) } Pattern::OpaqueUnwrap { argument, .. } => stack.push(&argument.0), + Pattern::List { elements, .. } => stack.extend(elements), } } @@ -9709,7 +9926,34 @@ fn from_can_pattern_help<'a>( )) } - List { .. } => todo!(), + List { + list_var: _, + elem_var, + patterns, + } => { + let element_layout = match layout_cache.from_var(env.arena, *elem_var, env.subs) { + Ok(lay) => lay, + Err(LayoutProblem::UnresolvedTypeVar(_)) => { + return Err(RuntimeError::UnresolvedTypeVar) + } + Err(LayoutProblem::Erroneous) => return Err(RuntimeError::ErroneousType), + }; + + let arity = patterns.arity(); + + let mut mono_patterns = Vec::with_capacity_in(patterns.patterns.len(), env.arena); + for loc_pat in patterns.patterns.iter() { + let mono_pat = + from_can_pattern_help(env, procs, layout_cache, &loc_pat.value, assignments)?; + mono_patterns.push(mono_pat); + } + + Ok(Pattern::List { + arity, + element_layout, + elements: mono_patterns, + }) + } } } diff --git a/crates/compiler/test_gen/src/gen_list.rs b/crates/compiler/test_gen/src/gen_list.rs index 5c619b0cae..55c7a4cf38 100644 --- a/crates/compiler/test_gen/src/gen_list.rs +++ b/crates/compiler/test_gen/src/gen_list.rs @@ -3578,3 +3578,135 @@ fn list_walk_from_even_prefix_sum() { i64 ); } + +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] +mod pattern_match { + #[cfg(feature = "gen-llvm")] + use crate::helpers::llvm::assert_evals_to; + + #[cfg(feature = "gen-wasm")] + use crate::helpers::wasm::assert_evals_to; + + use super::RocList; + + #[test] + fn unary_exact_size_match() { + assert_evals_to!( + r#" + helper = \l -> when l is + [] -> 1u8 + _ -> 2u8 + + [ helper [], helper [{}] ] + "#, + RocList::from_slice(&[1, 2]), + RocList + ) + } + + #[test] + fn many_exact_size_match() { + assert_evals_to!( + r#" + helper = \l -> when l is + [] -> 1u8 + [_] -> 2u8 + [_, _] -> 3u8 + [_, _, _] -> 4u8 + _ -> 5u8 + + [ helper [], helper [{}], helper [{}, {}], helper [{}, {}, {}], helper [{}, {}, {}, {}] ] + "#, + RocList::from_slice(&[1, 2, 3, 4, 5]), + RocList + ) + } + + #[test] + fn ranged_matches_head() { + assert_evals_to!( + r#" + helper = \l -> when l is + [] -> 1u8 + [A] -> 2u8 + [A, A, ..] -> 3u8 + [A, B, ..] -> 4u8 + [B, ..] -> 5u8 + + [ + helper [], + helper [A], + helper [A, A], helper [A, A, A], helper [A, A, B], helper [A, A, B, A], + helper [A, B], helper [A, B, A], helper [A, B, B], helper [A, B, A, B], + helper [B], helper [B, A], helper [B, B], helper [B, A, B, B], + ] + "#, + RocList::from_slice(&[ + 1, // + 2, // + 3, 3, 3, 3, // + 4, 4, 4, 4, // + 5, 5, 5, 5, // + ]), + RocList + ) + } + + #[test] + fn ranged_matches_tail() { + assert_evals_to!( + r#" + helper = \l -> when l is + [] -> 1u8 + [A] -> 2u8 + [.., A, A] -> 3u8 + [.., B, A] -> 4u8 + [.., B] -> 5u8 + + [ + helper [], + helper [A], + helper [A, A], helper [A, A, A], helper [B, A, A], helper [A, B, A, A], + helper [B, A], helper [A, B, A], helper [B, B, A], helper [B, A, B, A], + helper [B], helper [A, B], helper [B, B], helper [B, A, B, B], + ] + "#, + RocList::from_slice(&[ + 1, // + 2, // + 3, 3, 3, 3, // + 4, 4, 4, 4, // + 5, 5, 5, 5, // + ]), + RocList + ) + } + + #[test] + fn bind_variables() { + assert_evals_to!( + r#" + helper : List U16 -> U16 + helper = \l -> when l is + [] -> 1 + [x] -> x + [.., w, x, y, z] -> w * x * y * z + [x, y, ..] -> x * y + + [ + helper [], + helper [5], + helper [3, 5], helper [3, 5, 7], + helper [2, 3, 5, 7], helper [11, 2, 3, 5, 7], helper [13, 11, 2, 3, 5, 7], + ] + "#, + RocList::from_slice(&[ + 1, // + 5, // + 15, 15, // + 210, 210, 210, // + ]), + RocList + ) + } +} diff --git a/crates/compiler/test_mono/generated/match_list.txt b/crates/compiler/test_mono/generated/match_list.txt new file mode 100644 index 0000000000..3d5dd6dc6c --- /dev/null +++ b/crates/compiler/test_mono/generated/match_list.txt @@ -0,0 +1,67 @@ +procedure Test.0 (): + let Test.36 : Int1 = false; + let Test.37 : Int1 = true; + let Test.1 : List Int1 = Array [Test.36, Test.37]; + joinpoint Test.10: + let Test.8 : Str = "E"; + ret Test.8; + in + joinpoint Test.9: + let Test.5 : Str = "B"; + ret Test.5; + in + let Test.33 : U64 = lowlevel ListLen Test.1; + let Test.34 : U64 = 0i64; + let Test.35 : Int1 = lowlevel Eq Test.33 Test.34; + if Test.35 then + dec Test.1; + let Test.4 : Str = "A"; + ret Test.4; + else + let Test.30 : U64 = lowlevel ListLen Test.1; + let Test.31 : U64 = 1i64; + let Test.32 : Int1 = lowlevel Eq Test.30 Test.31; + if Test.32 then + let Test.11 : U64 = 0i64; + let Test.12 : Int1 = lowlevel ListGetUnsafe Test.1 Test.11; + dec Test.1; + let Test.13 : Int1 = false; + let Test.14 : Int1 = lowlevel Eq Test.13 Test.12; + if Test.14 then + jump Test.9; + else + jump Test.10; + else + let Test.27 : U64 = lowlevel ListLen Test.1; + let Test.28 : U64 = 2i64; + let Test.29 : Int1 = lowlevel NumGte Test.27 Test.28; + if Test.29 then + let Test.19 : U64 = 0i64; + let Test.20 : Int1 = lowlevel ListGetUnsafe Test.1 Test.19; + let Test.21 : Int1 = false; + let Test.22 : Int1 = lowlevel Eq Test.21 Test.20; + if Test.22 then + let Test.15 : U64 = 1i64; + let Test.16 : Int1 = lowlevel ListGetUnsafe Test.1 Test.15; + dec Test.1; + let Test.17 : Int1 = false; + let Test.18 : Int1 = lowlevel Eq Test.17 Test.16; + if Test.18 then + let Test.6 : Str = "C"; + ret Test.6; + else + let Test.7 : Str = "D"; + ret Test.7; + else + dec Test.1; + jump Test.10; + else + let Test.23 : U64 = 0i64; + let Test.24 : Int1 = lowlevel ListGetUnsafe Test.1 Test.23; + dec Test.1; + let Test.25 : Int1 = false; + let Test.26 : Int1 = lowlevel Eq Test.25 Test.24; + if Test.26 then + jump Test.9; + else + jump Test.10; diff --git a/crates/compiler/test_mono/src/tests.rs b/crates/compiler/test_mono/src/tests.rs index 0ecb7cd27e..4436c815f7 100644 --- a/crates/compiler/test_mono/src/tests.rs +++ b/crates/compiler/test_mono/src/tests.rs @@ -2000,3 +2000,19 @@ fn unreachable_branch_is_eliminated_but_produces_lambda_specializations() { "# ) } + +#[mono_test] +fn match_list() { + indoc!( + r#" + l = [A, B] + + when l is + [] -> "A" + [A] -> "B" + [A, A, ..] -> "C" + [A, B, ..] -> "D" + [B, ..] -> "E" + "# + ) +}