support more complex nested patterns

This commit is contained in:
Folkert 2020-03-20 14:27:07 +01:00
parent fbadd9d620
commit e062404a63
3 changed files with 162 additions and 117 deletions

View file

@ -1313,6 +1313,22 @@ mod test_gen {
); );
} }
#[test]
fn match_on_two_values() {
// this will produce a Chain internally
assert_evals_to!(
indoc!(
r#"
when Pair 2 3 is
Pair 4 3 -> 9
Pair a b -> a + b
"#
),
5,
i64
);
}
#[test] #[test]
fn maybe_is_just() { fn maybe_is_just() {
assert_evals_to!( assert_evals_to!(

View file

@ -6,6 +6,7 @@ use roc_collections::all::{MutMap, MutSet};
use roc_module::ident::TagName; use roc_module::ident::TagName;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
use crate::expr::specialize_equality;
use crate::layout::Builtin; use crate::layout::Builtin;
use crate::layout::Layout; use crate::layout::Layout;
use crate::pattern::{Ctor, Union}; use crate::pattern::{Ctor, Union};
@ -61,7 +62,11 @@ pub enum Test<'a> {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Path { pub enum Path {
Index { index: u64, path: Box<Path> }, Index {
index: u64,
tag_id: u8,
path: Box<Path>,
},
Unbox(Box<Path>), Unbox(Box<Path>),
Empty, Empty,
} }
@ -141,24 +146,32 @@ fn flatten_patterns(branch: Branch) -> Branch {
fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) { fn flatten<'a>(path_pattern: (Path, Pattern<'a>), path_patterns: &mut Vec<(Path, Pattern<'a>)>) {
match &path_pattern.1 { match &path_pattern.1 {
Pattern::AppliedTag { union, .. } => { Pattern::AppliedTag {
union,
arguments,
tag_id,
..
} => {
if union.alternatives.len() == 1 { if union.alternatives.len() == 1 {
// case map dearg ctorArgs of let path = path_pattern.0;
// [arg] -> // Theory: unbox doesn't have any value for us, because one-element tag unions
// flatten (Unbox path, arg) otherPathPatterns // don't store the tag anyway.
// // if arguments.len() == 1 {
// args -> // path_patterns.push((Path::Unbox(Box::new(path)), path_pattern.1.clone()));
// foldr flatten otherPathPatterns (subPositions path args) // } else {
// subPositions :: Path -> [Can.Pattern] -> [(Path, Can.Pattern)] for (index, (arg_pattern, _)) in arguments.iter().enumerate() {
// subPositions path patterns = flatten(
// Index.indexedMap (\index pattern -> (Index index path, pattern)) patterns (
// Path::Index {
// index: index as u64,
// dearg :: Can.PatternCtorArg -> Can.Pattern tag_id: *tag_id,
// dearg (Can.PatternCtorArg _ _ pattern) = path: Box::new(path.clone()),
// pattern },
arg_pattern.clone(),
todo!("alternatives: {:?}", union.alternatives) ),
path_patterns,
);
}
} else { } else {
path_patterns.push(path_pattern); path_patterns.push(path_pattern);
} }
@ -331,6 +344,7 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
RecordDestructure(destructs, _) => match test { RecordDestructure(destructs, _) => match test {
IsCtor { IsCtor {
tag_name: test_name, tag_name: test_name,
tag_id,
.. ..
} => { } => {
debug_assert!(test_name == &TagName::Global("#Record".into())); debug_assert!(test_name == &TagName::Global("#Record".into()));
@ -346,6 +360,7 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
( (
Path::Index { Path::Index {
index: index as u64, index: index as u64,
tag_id: *tag_id,
path: Box::new(path.clone()), path: Box::new(path.clone()),
}, },
pattern, pattern,
@ -363,39 +378,40 @@ fn to_relevant_branch<'a>(test: &Test<'a>, path: &Path, branch: Branch<'a>) -> O
}, },
AppliedTag { AppliedTag {
union,
tag_name, tag_name,
arguments, arguments,
.. ..
} => { } => {
let mut arguments: Vec<_> = arguments.into_iter().map(|v| v.0).collect();
match test { match test {
IsCtor { IsCtor {
tag_name: test_name, tag_name: test_name,
tag_id,
.. ..
} if &tag_name == test_name => { } if &tag_name == test_name => {
// TODO can't we unbox whenever there is just one alternative, even if // Theory: Unbox doesn't have any value for us
// there are multiple arguments? // if arguments.len() == 1 && union.alternatives.len() == 1 {
if arguments.len() == 1 && union.alternatives.len() == 1 { // let arg = arguments.remove(0);
let arg = arguments.remove(0); // {
{ // start.push((Path::Unbox(Box::new(path.clone())), arg));
start.push((Path::Unbox(Box::new(path.clone())), arg)); // start.extend(end);
start.extend(end); // }
} // } else {
} else { let sub_positions =
let sub_positions = arguments
arguments.into_iter().enumerate().map(|(index, pattern)| { .into_iter()
.enumerate()
.map(|(index, (pattern, _))| {
( (
Path::Index { Path::Index {
index: index as u64, index: index as u64,
tag_id: *tag_id,
path: Box::new(path.clone()), path: Box::new(path.clone()),
}, },
pattern, pattern,
) )
}); });
start.extend(sub_positions); start.extend(sub_positions);
start.extend(end); start.extend(end);
}
Some(Branch { Some(Branch {
goal: branch.goal, goal: branch.goal,
@ -720,21 +736,53 @@ fn path_to_expr<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
symbol: Symbol, symbol: Symbol,
path: &Path, path: &Path,
is_unwrapped: bool, layout: &Layout<'a>,
field_layouts: &'a [Layout<'a>],
) -> Expr<'a> { ) -> Expr<'a> {
path_to_expr_help(env, symbol, path, layout.clone()).0
}
fn path_to_expr_help<'a>(
env: &mut Env<'a, '_>,
symbol: Symbol,
path: &Path,
layout: Layout<'a>,
) -> (Expr<'a>, Layout<'a>) {
match path { match path {
Path::Unbox(ref path) => path_to_expr(env, symbol, path, true, field_layouts), Path::Unbox(ref unboxed) => match **unboxed {
_ => todo!(),
Path::Empty => Expr::Load(symbol),
// TODO path contains a nested path. Traverse all the way
Path::Index { index, .. } => Expr::AccessAtIndex {
index: *index,
field_layouts,
expr: env.arena.alloc(Expr::Load(symbol)),
is_unwrapped,
}, },
Path::Empty => (Expr::Load(symbol), layout),
Path::Index {
index,
tag_id,
path: nested,
} => {
let (outer_expr, outer_layout) = path_to_expr_help(env, symbol, nested, layout);
let (is_unwrapped, field_layouts) = match outer_layout {
Layout::Union(layouts) => (layouts.is_empty(), layouts[*tag_id as usize].to_vec()),
Layout::Struct(layouts) => (
true,
layouts.iter().map(|v| v.1.clone()).collect::<Vec<_>>(),
),
other => (true, vec![other]),
};
debug_assert!(*index < field_layouts.len() as u64);
let inner_layout = field_layouts[*index as usize].clone();
let inner_expr = Expr::AccessAtIndex {
index: *index,
field_layouts: env.arena.alloc(field_layouts),
expr: env.arena.alloc(outer_expr),
is_unwrapped,
};
(inner_expr, inner_layout)
}
} }
} }
@ -794,85 +842,46 @@ fn decide_to_branching<'a>(
is_unwrapped: union.alternatives.len() == 1, is_unwrapped: union.alternatives.len() == 1,
}; };
let cond = Expr::CallByName( tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
Symbol::INT_EQ_I64,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Int64)),
(rhs, Layout::Builtin(Builtin::Int64)),
]),
);
tests.push(cond);
} }
Test::IsInt(test_int) => { Test::IsInt(test_int) => {
let lhs = Expr::Int(test_int); let lhs = Expr::Int(test_int);
let rhs = path_to_expr( let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
env,
cond_symbol,
&path,
false,
env.arena.alloc([Layout::Builtin(Builtin::Int64)]),
);
let cond = Expr::CallByName( tests.push((lhs, rhs, Layout::Builtin(Builtin::Int64)));
Symbol::INT_EQ_I64,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Int64)),
(rhs, Layout::Builtin(Builtin::Int64)),
]),
);
tests.push(cond);
} }
Test::IsFloat(test_int) => { Test::IsFloat(test_int) => {
// TODO maybe we can actually use i64 comparison here? // TODO maybe we can actually use i64 comparison here?
let test_float = f64::from_bits(test_int as u64); let test_float = f64::from_bits(test_int as u64);
let lhs = Expr::Float(test_float); let lhs = Expr::Float(test_float);
let rhs = path_to_expr( let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
env,
cond_symbol,
&path,
false,
env.arena.alloc([Layout::Builtin(Builtin::Float64)]),
);
let cond = Expr::CallByName( tests.push((lhs, rhs, Layout::Builtin(Builtin::Float64)));
Symbol::FLOAT_EQ,
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Float64)),
(rhs, Layout::Builtin(Builtin::Float64)),
]),
);
tests.push(cond);
} }
Test::IsByte { Test::IsByte {
tag_id: test_byte, tag_id: test_byte, ..
// num_alts: _,
..
} => { } => {
let lhs = Expr::Byte(test_byte); let lhs = Expr::Byte(test_byte);
let rhs = path_to_expr( let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
env,
cond_symbol,
&path,
false,
env.arena.alloc([Layout::Builtin(Builtin::Byte)]),
);
let cond = Expr::CallByName( tests.push((lhs, rhs, Layout::Builtin(Builtin::Byte)));
Symbol::INT_EQ_I8, }
env.arena.alloc([
(lhs, Layout::Builtin(Builtin::Byte)), Test::IsBit(test_bit) => {
(rhs, Layout::Builtin(Builtin::Byte)), let lhs = Expr::Bool(test_bit);
]), let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Bool)));
tests.push(cond); }
Test::IsStr(test_str) => {
let lhs = Expr::Str(env.arena.alloc(test_str));
let rhs = path_to_expr(env, cond_symbol, &path, &cond_layout);
tests.push((lhs, rhs, Layout::Builtin(Builtin::Str)));
} }
_ => todo!(),
} }
} }
@ -911,13 +920,9 @@ fn decide_to_branching<'a>(
tests, tests,
fallback, fallback,
} => { } => {
let cond = env.arena.alloc(path_to_expr( let cond = env
env, .arena
cond_symbol, .alloc(path_to_expr(env, cond_symbol, &path, &cond_layout));
&path,
false,
env.arena.alloc([cond_layout.clone()]),
));
let default_branch = env.arena.alloc(decide_to_branching( let default_branch = env.arena.alloc(decide_to_branching(
env, env,
@ -964,10 +969,11 @@ fn decide_to_branching<'a>(
} }
} }
fn boolean_all<'a>(arena: &'a Bump, tests: Vec<Expr<'a>>) -> Expr<'a> { fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>) -> Expr<'a> {
let mut expr = Expr::Bool(true); let mut expr = Expr::Bool(true);
for test in tests.into_iter().rev() { for (lhs, rhs, layout) in tests.into_iter().rev() {
let test = specialize_equality(arena, lhs, rhs, layout);
expr = Expr::CallByName( expr = Expr::CallByName(
Symbol::BOOL_AND, Symbol::BOOL_AND,
arena.alloc([ arena.alloc([

View file

@ -893,7 +893,7 @@ fn store_pattern<'a>(
_ => { _ => {
// store the field in a symbol, and continue matching on it // store the field in a symbol, and continue matching on it
let symbol = env.fresh_symbol(); let symbol = env.fresh_symbol();
stored.push((symbol, layout.clone(), load)); stored.push((symbol, arg_layout.clone(), load));
store_pattern(env, argument, symbol, arg_layout.clone(), stored)?; store_pattern(env, argument, symbol, arg_layout.clone(), stored)?;
} }
@ -1402,3 +1402,26 @@ fn from_can_record_destruct<'a>(
}, },
} }
} }
pub fn specialize_equality<'a>(
arena: &'a Bump,
lhs: Expr<'a>,
rhs: Expr<'a>,
layout: Layout<'a>,
) -> Expr<'a> {
let symbol = match &layout {
Layout::Builtin(builtin) => match builtin {
Builtin::Int64 => Symbol::INT_EQ_I64,
Builtin::Float64 => Symbol::FLOAT_EQ,
Builtin::Byte => Symbol::INT_EQ_I8,
Builtin::Bool => Symbol::INT_EQ_I1,
other => todo!("Cannot yet compare for equality {:?}", other),
},
other => todo!("Cannot yet compare for equality {:?}", other),
};
Expr::CallByName(
symbol,
arena.alloc([(lhs, layout.clone()), (rhs, layout.clone())]),
)
}