mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
Decision tree compilation of list patterns
This commit is contained in:
parent
da1d937277
commit
ae71c7efe2
7 changed files with 479 additions and 38 deletions
|
@ -1,11 +1,12 @@
|
|||
use crate::ir::{
|
||||
BranchInfo, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt,
|
||||
BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern,
|
||||
Procs, Stmt,
|
||||
};
|
||||
use crate::layout::{Builtin, Layout, LayoutCache, TagIdIntType, UnionLayout};
|
||||
use roc_builtins::bitcode::{FloatWidth, IntWidth};
|
||||
use roc_collections::all::{MutMap, MutSet};
|
||||
use roc_error_macros::internal_error;
|
||||
use roc_exhaustive::{Ctor, CtorName, RenderAs, TagId, Union};
|
||||
use roc_exhaustive::{Ctor, CtorName, ListArity, RenderAs, TagId, Union};
|
||||
use roc_module::ident::TagName;
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::Symbol;
|
||||
|
@ -77,6 +78,12 @@ enum GuardedTest<'a> {
|
|||
Placeholder,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
|
||||
enum ListLenBound {
|
||||
Exact,
|
||||
AtLeast,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum Test<'a> {
|
||||
|
@ -95,6 +102,10 @@ enum Test<'a> {
|
|||
tag_id: TagIdIntType,
|
||||
num_alts: usize,
|
||||
},
|
||||
IsListLen {
|
||||
bound: ListLenBound,
|
||||
len: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> Test<'a> {
|
||||
|
@ -110,6 +121,10 @@ impl<'a> Test<'a> {
|
|||
Test::IsStr(_) => false,
|
||||
Test::IsBit(_) => true,
|
||||
Test::IsByte { .. } => true,
|
||||
Test::IsListLen { bound, .. } => match bound {
|
||||
ListLenBound::Exact => true,
|
||||
ListLenBound::AtLeast => false,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -153,6 +168,10 @@ impl<'a> Hash for Test<'a> {
|
|||
state.write_u8(6);
|
||||
v.hash(state);
|
||||
}
|
||||
IsListLen { len, bound } => {
|
||||
state.write_u8(7);
|
||||
(len, bound).hash(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -331,6 +350,11 @@ fn tests_are_complete_help(last_test: &Test, number_of_tests: usize) -> bool {
|
|||
Test::IsFloat(_, _) => false,
|
||||
Test::IsDecimal(_) => false,
|
||||
Test::IsStr(_) => false,
|
||||
Test::IsListLen {
|
||||
bound: ListLenBound::AtLeast,
|
||||
len: 0,
|
||||
} => true, // [..] test
|
||||
Test::IsListLen { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -578,6 +602,18 @@ fn test_at_path<'a>(
|
|||
arguments: arguments.to_vec(),
|
||||
},
|
||||
|
||||
List {
|
||||
arity,
|
||||
element_layout: _,
|
||||
elements: _,
|
||||
} => IsListLen {
|
||||
bound: match arity {
|
||||
ListArity::Exact(_) => ListLenBound::Exact,
|
||||
ListArity::Slice(_, _) => ListLenBound::AtLeast,
|
||||
},
|
||||
len: arity.min_len() as _,
|
||||
},
|
||||
|
||||
Voided { .. } => internal_error!("unreachable"),
|
||||
|
||||
OpaqueUnwrap { opaque, argument } => {
|
||||
|
@ -755,6 +791,50 @@ fn to_relevant_branch_help<'a>(
|
|||
_ => None,
|
||||
},
|
||||
|
||||
List {
|
||||
arity: my_arity,
|
||||
elements,
|
||||
element_layout: _,
|
||||
} => match test {
|
||||
IsListLen { bound, len }
|
||||
if (match (bound, my_arity) {
|
||||
(ListLenBound::Exact, ListArity::Exact(my_len)) => my_len == *len as _,
|
||||
(ListLenBound::Exact, ListArity::Slice(my_head, my_tail)) => {
|
||||
my_head + my_tail <= *len as _
|
||||
}
|
||||
(ListLenBound::AtLeast, ListArity::Exact(my_len)) => my_len == *len as _,
|
||||
(ListLenBound::AtLeast, ListArity::Slice(my_head, my_tail)) => {
|
||||
my_head + my_tail <= *len as _
|
||||
}
|
||||
}) =>
|
||||
{
|
||||
if matches!(my_arity, ListArity::Slice(_, n) if n > 0) {
|
||||
todo!();
|
||||
}
|
||||
|
||||
let sub_positions = elements.into_iter().enumerate().map(|(index, elem_pat)| {
|
||||
let mut new_path = path.to_vec();
|
||||
let next_instr = PathInstruction::ListIndex {
|
||||
// TODO index into back as well
|
||||
index: index as _,
|
||||
};
|
||||
new_path.push(next_instr);
|
||||
|
||||
(new_path, elem_pat)
|
||||
});
|
||||
start.extend(sub_positions);
|
||||
start.extend(end);
|
||||
|
||||
Some(Branch {
|
||||
goal: branch.goal,
|
||||
guard: branch.guard.clone(),
|
||||
patterns: start,
|
||||
})
|
||||
}
|
||||
|
||||
_ => None,
|
||||
},
|
||||
|
||||
NewtypeDestructure {
|
||||
tag_name,
|
||||
arguments,
|
||||
|
@ -1021,7 +1101,8 @@ fn needs_tests(pattern: &Pattern) -> bool {
|
|||
| IntLiteral(_, _)
|
||||
| FloatLiteral(_, _)
|
||||
| DecimalLiteral(_)
|
||||
| StrLiteral(_) => true,
|
||||
| StrLiteral(_)
|
||||
| List { .. } => true,
|
||||
|
||||
Voided { .. } => internal_error!("unreachable"),
|
||||
}
|
||||
|
@ -1267,7 +1348,15 @@ pub fn optimize_when<'a>(
|
|||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum PathInstruction {
|
||||
NewType,
|
||||
TagIndex { index: u64, tag_id: TagIdIntType },
|
||||
TagIndex {
|
||||
index: u64,
|
||||
tag_id: TagIdIntType,
|
||||
},
|
||||
ListIndex {
|
||||
// Positive if it should be indexed from the front, negative otherwise
|
||||
// (-1 means the last index)
|
||||
index: i64,
|
||||
},
|
||||
}
|
||||
|
||||
fn path_to_expr_help<'a>(
|
||||
|
@ -1337,19 +1426,51 @@ fn path_to_expr_help<'a>(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
PathInstruction::ListIndex { index } => {
|
||||
let list_sym = symbol;
|
||||
let usize_layout = Layout::usize(env.target_info);
|
||||
|
||||
if index < &0 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
match layout {
|
||||
Layout::Builtin(Builtin::List(elem_layout)) => {
|
||||
let index_sym = env.unique_symbol();
|
||||
let index_expr =
|
||||
Expr::Literal(Literal::Int((*index as i128).to_ne_bytes()));
|
||||
|
||||
let load_sym = env.unique_symbol();
|
||||
let load_expr = Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::ListGetUnsafe,
|
||||
update_mode: env.next_update_mode_id(),
|
||||
},
|
||||
arguments: env.arena.alloc([list_sym, index_sym]),
|
||||
});
|
||||
|
||||
stores.push((index_sym, usize_layout, index_expr));
|
||||
stores.push((load_sym, *elem_layout, load_expr));
|
||||
|
||||
layout = *elem_layout;
|
||||
}
|
||||
_ => internal_error!("not a list"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(symbol, stores, layout)
|
||||
}
|
||||
|
||||
fn test_to_equality<'a>(
|
||||
fn test_to_comparison<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
cond_symbol: Symbol,
|
||||
cond_layout: &Layout<'a>,
|
||||
path: &[PathInstruction],
|
||||
test: Test<'a>,
|
||||
) -> (StoresVec<'a>, Symbol, Symbol, Option<ConstructorKnown<'a>>) {
|
||||
) -> (StoresVec<'a>, Comparison, Option<ConstructorKnown<'a>>) {
|
||||
let (rhs_symbol, mut stores, test_layout) =
|
||||
path_to_expr_help(env, cond_symbol, path, *cond_layout);
|
||||
|
||||
|
@ -1379,8 +1500,7 @@ fn test_to_equality<'a>(
|
|||
|
||||
(
|
||||
stores,
|
||||
lhs_symbol,
|
||||
rhs_symbol,
|
||||
(lhs_symbol, Comparator::Eq, rhs_symbol),
|
||||
Some(ConstructorKnown::OnlyPass {
|
||||
scrutinee: path_symbol,
|
||||
layout: *cond_layout,
|
||||
|
@ -1397,7 +1517,7 @@ fn test_to_equality<'a>(
|
|||
let lhs_symbol = env.unique_symbol();
|
||||
stores.push((lhs_symbol, Layout::int_width(precision), lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsFloat(test_int, precision) => {
|
||||
|
@ -1407,7 +1527,7 @@ fn test_to_equality<'a>(
|
|||
let lhs_symbol = env.unique_symbol();
|
||||
stores.push((lhs_symbol, Layout::float_width(precision), lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsDecimal(test_dec) => {
|
||||
|
@ -1415,7 +1535,7 @@ fn test_to_equality<'a>(
|
|||
let lhs_symbol = env.unique_symbol();
|
||||
stores.push((lhs_symbol, *cond_layout, lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsByte {
|
||||
|
@ -1427,7 +1547,7 @@ fn test_to_equality<'a>(
|
|||
let lhs_symbol = env.unique_symbol();
|
||||
stores.push((lhs_symbol, Layout::u8(), lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsBit(test_bit) => {
|
||||
|
@ -1435,7 +1555,7 @@ fn test_to_equality<'a>(
|
|||
let lhs_symbol = env.unique_symbol();
|
||||
stores.push((lhs_symbol, Layout::Builtin(Builtin::Bool), lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsStr(test_str) => {
|
||||
|
@ -1444,15 +1564,58 @@ fn test_to_equality<'a>(
|
|||
|
||||
stores.push((lhs_symbol, Layout::Builtin(Builtin::Str), lhs));
|
||||
|
||||
(stores, lhs_symbol, rhs_symbol, None)
|
||||
(stores, (lhs_symbol, Comparator::Eq, rhs_symbol), None)
|
||||
}
|
||||
|
||||
Test::IsListLen { bound, len } => {
|
||||
let list_layout = test_layout;
|
||||
let list_sym = rhs_symbol;
|
||||
|
||||
match list_layout {
|
||||
Layout::Builtin(Builtin::List(_elem_layout)) => {
|
||||
let real_len_expr = Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::ListLen,
|
||||
update_mode: env.next_update_mode_id(),
|
||||
},
|
||||
arguments: env.arena.alloc([list_sym]),
|
||||
});
|
||||
let test_len_expr = Expr::Literal(Literal::Int((len as i128).to_ne_bytes()));
|
||||
|
||||
let real_len = env.unique_symbol();
|
||||
let test_len = env.unique_symbol();
|
||||
|
||||
let usize_layout = Layout::usize(env.target_info);
|
||||
|
||||
stores.push((real_len, usize_layout, real_len_expr));
|
||||
stores.push((test_len, usize_layout, test_len_expr));
|
||||
|
||||
let comparison = match bound {
|
||||
ListLenBound::Exact => (real_len, Comparator::Eq, test_len),
|
||||
ListLenBound::AtLeast => (real_len, Comparator::Geq, test_len),
|
||||
};
|
||||
|
||||
(stores, comparison, None)
|
||||
}
|
||||
_ => internal_error!(
|
||||
"test path is not a list: {:#?}",
|
||||
(cond_layout, test_layout, path)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Comparator {
|
||||
Eq,
|
||||
Geq,
|
||||
}
|
||||
|
||||
type Comparison = (Symbol, Comparator, Symbol);
|
||||
|
||||
type Tests<'a> = std::vec::Vec<(
|
||||
bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
|
||||
Symbol,
|
||||
Symbol,
|
||||
Comparison,
|
||||
Option<ConstructorKnown<'a>>,
|
||||
)>;
|
||||
|
||||
|
@ -1466,7 +1629,13 @@ fn stores_and_condition<'a>(
|
|||
|
||||
// Assumption: there is at most 1 guard, and it is the outer layer.
|
||||
for (path, test) in test_chain {
|
||||
tests.push(test_to_equality(env, cond_symbol, cond_layout, &path, test))
|
||||
tests.push(test_to_comparison(
|
||||
env,
|
||||
cond_symbol,
|
||||
cond_layout,
|
||||
&path,
|
||||
test,
|
||||
))
|
||||
}
|
||||
|
||||
tests
|
||||
|
@ -1477,6 +1646,7 @@ fn compile_test<'a>(
|
|||
ret_layout: Layout<'a>,
|
||||
stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
|
||||
lhs: Symbol,
|
||||
cmp: Comparator,
|
||||
rhs: Symbol,
|
||||
fail: &'a Stmt<'a>,
|
||||
cond: Stmt<'a>,
|
||||
|
@ -1487,6 +1657,7 @@ fn compile_test<'a>(
|
|||
ret_layout,
|
||||
stores,
|
||||
lhs,
|
||||
cmp,
|
||||
rhs,
|
||||
fail,
|
||||
cond,
|
||||
|
@ -1500,6 +1671,7 @@ fn compile_test_help<'a>(
|
|||
ret_layout: Layout<'a>,
|
||||
stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
|
||||
lhs: Symbol,
|
||||
cmp: Comparator,
|
||||
rhs: Symbol,
|
||||
fail: &'a Stmt<'a>,
|
||||
mut cond: Stmt<'a>,
|
||||
|
@ -1560,7 +1732,10 @@ fn compile_test_help<'a>(
|
|||
default_branch,
|
||||
};
|
||||
|
||||
let op = LowLevel::Eq;
|
||||
let op = match cmp {
|
||||
Comparator::Eq => LowLevel::Eq,
|
||||
Comparator::Geq => LowLevel::NumGte,
|
||||
};
|
||||
let test = Expr::Call(crate::ir::Call {
|
||||
call_type: crate::ir::CallType::LowLevel {
|
||||
op,
|
||||
|
@ -1592,13 +1767,15 @@ fn compile_tests<'a>(
|
|||
fail: &'a Stmt<'a>,
|
||||
mut cond: Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
for (new_stores, lhs, rhs, opt_constructor_info) in tests.into_iter() {
|
||||
for (new_stores, (lhs, cmp, rhs), opt_constructor_info) in tests.into_iter() {
|
||||
match opt_constructor_info {
|
||||
None => {
|
||||
cond = compile_test(env, ret_layout, new_stores, lhs, rhs, fail, cond);
|
||||
cond = compile_test(env, ret_layout, new_stores, lhs, cmp, rhs, fail, cond);
|
||||
}
|
||||
Some(cinfo) => {
|
||||
cond = compile_test_help(env, cinfo, ret_layout, new_stores, lhs, rhs, fail, cond);
|
||||
cond = compile_test_help(
|
||||
env, cinfo, ret_layout, new_stores, lhs, cmp, rhs, fail, cond,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1781,7 +1958,7 @@ fn decide_to_branching<'a>(
|
|||
if number_of_tests == 1 {
|
||||
// if there is just one test, compile to a simple if-then-else
|
||||
|
||||
let (new_stores, lhs, rhs, _cinfo) = tests.into_iter().next().unwrap();
|
||||
let (new_stores, (lhs, cmp, rhs), _cinfo) = tests.into_iter().next().unwrap();
|
||||
|
||||
compile_test_help(
|
||||
env,
|
||||
|
@ -1789,6 +1966,7 @@ fn decide_to_branching<'a>(
|
|||
ret_layout,
|
||||
new_stores,
|
||||
lhs,
|
||||
cmp,
|
||||
rhs,
|
||||
fail,
|
||||
pass_expr,
|
||||
|
@ -1854,6 +2032,12 @@ fn decide_to_branching<'a>(
|
|||
Test::IsBit(v) => v as u64,
|
||||
Test::IsByte { tag_id, .. } => tag_id as u64,
|
||||
Test::IsCtor { tag_id, .. } => tag_id as u64,
|
||||
Test::IsListLen { len, bound } => match bound {
|
||||
ListLenBound::Exact => len as _,
|
||||
ListLenBound::AtLeast => {
|
||||
unreachable!("at-least bounds cannot be switched on")
|
||||
}
|
||||
},
|
||||
Test::IsDecimal(_) => unreachable!("decimals cannot be switched on"),
|
||||
Test::IsStr(_) => unreachable!("strings cannot be switched on"),
|
||||
};
|
||||
|
@ -1911,6 +2095,31 @@ fn decide_to_branching<'a>(
|
|||
union_layout.tag_id_layout(),
|
||||
env.arena.alloc(temp),
|
||||
)
|
||||
} else if let Layout::Builtin(Builtin::List(_)) = inner_cond_layout {
|
||||
let len_symbol = env.unique_symbol();
|
||||
|
||||
let switch = Stmt::Switch {
|
||||
cond_layout: Layout::usize(env.target_info),
|
||||
cond_symbol: len_symbol,
|
||||
branches: branches.into_bump_slice(),
|
||||
default_branch: (default_branch_info, env.arena.alloc(default_branch)),
|
||||
ret_layout,
|
||||
};
|
||||
|
||||
let len_expr = Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::ListLen,
|
||||
update_mode: env.next_update_mode_id(),
|
||||
},
|
||||
arguments: env.arena.alloc([inner_cond_symbol]),
|
||||
});
|
||||
|
||||
Stmt::Let(
|
||||
len_symbol,
|
||||
len_expr,
|
||||
Layout::usize(env.target_info),
|
||||
env.arena.alloc(switch),
|
||||
)
|
||||
} else {
|
||||
Stmt::Switch {
|
||||
cond_layout: inner_cond_layout,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue