mostly fix issues with patterns/guards/switch/cond

This commit is contained in:
Folkert 2020-08-06 01:09:42 +02:00
parent dab00f2e2d
commit db0bed2fe7
7 changed files with 445 additions and 320 deletions

View file

@ -301,7 +301,37 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
call_type: ByPointer(name), call_type: ByPointer(name),
layout, layout,
args, args,
} => todo!(), } => {
let sub_expr = load_symbol(env, scope, name);
let mut arg_vals: Vec<BasicValueEnum> = Vec::with_capacity_in(args.len(), env.arena);
for arg in args.iter() {
arg_vals.push(load_symbol(env, scope, arg));
}
let call = match sub_expr {
BasicValueEnum::PointerValue(ptr) => {
env.builder.build_call(ptr, arg_vals.as_slice(), "tmp")
}
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
}
};
// TODO FIXME this should not be hardcoded!
// Need to look up what calling convention is the right one for that function.
// If this is an external-facing function, it'll use the C calling convention.
// If it's an internal-only function, it should (someday) use the fast calling conention.
call.set_call_convention(C_CALL_CONV);
call.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
}
Struct(sorted_fields) => { Struct(sorted_fields) => {
let ctx = env.context; let ctx = env.context;
@ -566,7 +596,19 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
Array { elem_layout, elems } => { Array { elem_layout, elems } => {
list_literal2(env, layout_ids, scope, parent, elem_layout, elems) list_literal2(env, layout_ids, scope, parent, elem_layout, elems)
} }
FunctionPointer(_, _) => todo!(), FunctionPointer(symbol, layout) => {
let fn_name = layout_ids
.get(*symbol, layout)
.to_symbol_string(*symbol, &env.interns);
let ptr = env
.module
.get_function(fn_name.as_str())
.unwrap_or_else(|| panic!("Could not get pointer to unknown function {:?}", symbol))
.as_global_value()
.as_pointer_value();
BasicValueEnum::PointerValue(ptr)
}
RuntimeErrorFunction(_) => todo!(), RuntimeErrorFunction(_) => todo!(),
} }
} }

View file

@ -736,155 +736,155 @@ mod gen_list {
}) })
} }
#[test] // #[test]
fn foobar2() { // fn foobar2() {
with_larger_debug_stack(|| { // with_larger_debug_stack(|| {
assert_evals_to_ir!( // assert_evals_to_ir!(
indoc!( // indoc!(
r#" // r#"
quicksort : List (Num a) -> List (Num a) // quicksort : List (Num a) -> List (Num a)
quicksort = \list -> // quicksort = \list ->
quicksortHelp list 0 (List.len list - 1) // quicksortHelp list 0 (List.len list - 1)
//
//
// quicksortHelp : List (Num a), Int, Int -> List (Num a)
// quicksortHelp = \list, low, high ->
// if low < high then
// when partition low high list is
// Pair partitionIndex partitioned ->
// partitioned
// |> quicksortHelp low (partitionIndex - 1)
// |> quicksortHelp (partitionIndex + 1) high
// else
// list
//
//
// swap : Int, Int, List a -> List a
// swap = \i, j, list ->
// when Pair (List.get list i) (List.get list j) is
// Pair (Ok atI) (Ok atJ) ->
// list
// |> List.set i atJ
// |> List.set j atI
//
// _ ->
// []
//
// partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]
// partition = \low, high, initialList ->
// when List.get initialList high is
// Ok pivot ->
// when partitionHelp (low - 1) low initialList high pivot is
// Pair newI newList ->
// Pair (newI + 1) (swap (newI + 1) high newList)
//
// Err _ ->
// Pair (low - 1) initialList
//
//
// partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ]
// partitionHelp = \i, j, list, high, pivot ->
// # if j < high then
// if False then
// when List.get list j is
// Ok value ->
// if value <= pivot then
// partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
// else
// partitionHelp i (j + 1) list high pivot
//
// Err _ ->
// Pair i list
// else
// Pair i list
//
//
//
// quicksort [ 7, 4, 21, 19 ]
// "#
// ),
// &[19, 7, 4, 21],
// &'static [i64],
// |x| x,
// true
// );
// })
// }
// #[test]
quicksortHelp : List (Num a), Int, Int -> List (Num a) // fn foobar() {
quicksortHelp = \list, low, high -> // with_larger_debug_stack(|| {
if low < high then // assert_evals_to_ir!(
when partition low high list is // indoc!(
Pair partitionIndex partitioned -> // r#"
partitioned // quicksort : List (Num a) -> List (Num a)
|> quicksortHelp low (partitionIndex - 1) // quicksort = \list ->
|> quicksortHelp (partitionIndex + 1) high // quicksortHelp list 0 (List.len list - 1)
else //
list //
// quicksortHelp : List (Num a), Int, Int -> List (Num a)
// quicksortHelp = \list, low, high ->
swap : Int, Int, List a -> List a // if low < high then
swap = \i, j, list -> // when partition low high list is
when Pair (List.get list i) (List.get list j) is // Pair partitionIndex partitioned ->
Pair (Ok atI) (Ok atJ) -> // partitioned
list // |> quicksortHelp low (partitionIndex - 1)
|> List.set i atJ // |> quicksortHelp (partitionIndex + 1) high
|> List.set j atI // else
// list
_ -> //
[] //
// swap : Int, Int, List a -> List a
partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ] // swap = \i, j, list ->
partition = \low, high, initialList -> // when Pair (List.get list i) (List.get list j) is
when List.get initialList high is // Pair (Ok atI) (Ok atJ) ->
Ok pivot -> // list
when partitionHelp (low - 1) low initialList high pivot is // |> List.set i atJ
Pair newI newList -> // |> List.set j atI
Pair (newI + 1) (swap (newI + 1) high newList) //
// _ ->
Err _ -> // []
Pair (low - 1) initialList //
// partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]
// partition = \low, high, initialList ->
partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ] // when List.get initialList high is
partitionHelp = \i, j, list, high, pivot -> // Ok pivot ->
# if j < high then // when partitionHelp (low - 1) low initialList high pivot is
if False then // Pair newI newList ->
when List.get list j is // Pair (newI + 1) (swap (newI + 1) high newList)
Ok value -> //
if value <= pivot then // Err _ ->
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot // Pair (low - 1) initialList
else //
partitionHelp i (j + 1) list high pivot //
// partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ]
Err _ -> // partitionHelp = \i, j, list, high, pivot ->
Pair i list // if j < high then
else // when List.get list j is
Pair i list // Ok value ->
// if value <= pivot then
// partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
// else
quicksort [ 7, 4, 21, 19 ] // partitionHelp i (j + 1) list high pivot
"# //
), // Err _ ->
&[19, 7, 4, 21], // Pair i list
&'static [i64], // else
|x| x, // Pair i list
true //
); //
}) //
} // when List.first (quicksort [0x1]) is
// _ -> 4
#[test] // "#
fn foobar() { // ),
with_larger_debug_stack(|| { // 4,
assert_evals_to_ir!( // i64,
indoc!( // |x| x,
r#" // false
quicksort : List (Num a) -> List (Num a) // );
quicksort = \list -> // })
quicksortHelp list 0 (List.len list - 1) // }
quicksortHelp : List (Num a), Int, Int -> List (Num a)
quicksortHelp = \list, low, high ->
if low < high then
when partition low high list is
Pair partitionIndex partitioned ->
partitioned
|> quicksortHelp low (partitionIndex - 1)
|> quicksortHelp (partitionIndex + 1) high
else
list
swap : Int, Int, List a -> List a
swap = \i, j, list ->
when Pair (List.get list i) (List.get list j) is
Pair (Ok atI) (Ok atJ) ->
list
|> List.set i atJ
|> List.set j atI
_ ->
[]
partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]
partition = \low, high, initialList ->
when List.get initialList high is
Ok pivot ->
when partitionHelp (low - 1) low initialList high pivot is
Pair newI newList ->
Pair (newI + 1) (swap (newI + 1) high newList)
Err _ ->
Pair (low - 1) initialList
partitionHelp : Int, Int, List (Num a), Int, Int -> [ Pair Int (List (Num a)) ]
partitionHelp = \i, j, list, high, pivot ->
if j < high then
when List.get list j is
Ok value ->
if value <= pivot then
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
else
partitionHelp i (j + 1) list high pivot
Err _ ->
Pair i list
else
Pair i list
when List.first (quicksort [0x1]) is
_ -> 4
"#
),
4,
i64,
|x| x,
false
);
})
}
#[test] #[test]
fn empty_list_increment_decrement() { fn empty_list_increment_decrement() {
@ -958,7 +958,7 @@ mod gen_list {
assert_evals_to_ir!( assert_evals_to_ir!(
indoc!( indoc!(
r#" r#"
id : List Int -> [ Pair (List Int) Int, Nil ] id : List Int -> [ Pair (List Int) Int ]
id = \y -> Pair y 4 id = \y -> Pair y 4
when id [1,2,3] is when id [1,2,3] is

View file

@ -68,6 +68,26 @@ mod gen_num {
#[test] #[test]
fn gen_if_fn() { fn gen_if_fn() {
assert_evals_to_ir!(
indoc!(
r#"
limitedNegate = \num ->
x =
if num == 1 then
-1
else if num == -1 then
1
else
num
x
limitedNegate 1
"#
),
-1,
i64
);
assert_evals_to_ir!( assert_evals_to_ir!(
indoc!( indoc!(
r#" r#"

View file

@ -391,4 +391,20 @@ mod gen_records {
bool bool
); );
} }
#[test]
fn return_record() {
assert_evals_to_ir!(
indoc!(
r#"
x = 4
y = 3
{ x, y }
"#
),
(4, 3),
(i64, i64)
);
}
} }

View file

@ -585,6 +585,8 @@ fn to_relevant_branch_help<'a>(
start.push((Path::Unbox(Box::new(path.clone())), guard, arg.0)); start.push((Path::Unbox(Box::new(path.clone())), guard, arg.0));
start.extend(end); start.extend(end);
} }
} else if union.alternatives.len() == 1 {
todo!("this should need a special index, right?")
} else { } else {
let sub_positions = let sub_positions =
arguments arguments
@ -946,7 +948,7 @@ fn path_to_expr_help2<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
mut symbol: Symbol, mut symbol: Symbol,
mut path: &Path, mut path: &Path,
layout: Layout<'a>, mut layout: Layout<'a>,
) -> (Symbol, StoresVec<'a>, Layout<'a>) { ) -> (Symbol, StoresVec<'a>, Layout<'a>) {
let mut stores = bumpalo::collections::Vec::new_in(env.arena); let mut stores = bumpalo::collections::Vec::new_in(env.arena);
@ -983,8 +985,9 @@ fn path_to_expr_help2<'a>(
}; };
symbol = env.unique_symbol(); symbol = env.unique_symbol();
stores.push((symbol, inner_layout, inner_expr)); stores.push((symbol, inner_layout.clone(), inner_expr));
layout = inner_layout;
path = nested; path = nested;
} }
} }
@ -1230,18 +1233,47 @@ fn decide_to_branching<'a>(
} }
} }
debug_assert!(!tests.is_empty());
let mut current_symbol = branching_symbol; let mut current_symbol = branching_symbol;
let mut condition_symbol = true_symbol;
// TODO There must be some way to remove this iterator/loop
let nr = (tests.len() as i64) - 1 + (guard.is_some() as i64);
let accum_symbols = std::iter::once(true_symbol) let accum_symbols = std::iter::once(true_symbol)
.chain((0..tests.len() - 1).map(|_| env.unique_symbol())) .chain((0..nr).map(|_| env.unique_symbol()))
.rev() .rev()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for ((new_stores, lhs, rhs, layout), accum) in let mut accum_it = accum_symbols.into_iter();
tests.into_iter().rev().zip(accum_symbols)
{ // the guard is the final thing that we check, so needs to be layered on first!
if let Some((_, id, stmt)) = guard {
let accum = accum_it.next().unwrap();
let test_symbol = env.unique_symbol();
let and_expr =
Expr::RunLowLevel(LowLevel::And, env.arena.alloc([test_symbol, accum]));
// write to the branching symbol
cond = Stmt::Let(
current_symbol,
and_expr,
Layout::Builtin(Builtin::Int1),
env.arena.alloc(cond),
);
// calculate the guard value
cond = Stmt::Join {
id,
arguments: env
.arena
.alloc([(test_symbol, Layout::Builtin(Builtin::Int1))]),
remainder: env.arena.alloc(stmt),
continuation: env.arena.alloc(cond),
};
current_symbol = accum;
}
for ((new_stores, lhs, rhs, layout), accum) in tests.into_iter().rev().zip(accum_it) {
let test_symbol = env.unique_symbol(); let test_symbol = env.unique_symbol();
let test = Expr::RunLowLevel( let test = Expr::RunLowLevel(
LowLevel::Eq, LowLevel::Eq,
@ -1271,41 +1303,9 @@ fn decide_to_branching<'a>(
cond = Stmt::Let(symbol, expr, layout, env.arena.alloc(cond)); cond = Stmt::Let(symbol, expr, layout, env.arena.alloc(cond));
} }
condition_symbol = current_symbol;
current_symbol = accum; current_symbol = accum;
} }
/*
// the guard is the final thing that we check, so needs to be layered on first!
if let Some((symbol, id, stmt)) = guard {
let test_symbol = symbol;
let and_expr = Expr::RunLowLevel(
LowLevel::And,
env.arena.alloc([test_symbol, condition_symbol]),
);
// write to the branching symbol
cond = Stmt::Let(
current_symbol,
and_expr,
Layout::Builtin(Builtin::Int1),
env.arena.alloc(cond),
);
// calculate the guard value
cond = Stmt::Join {
id,
arguments: &[],
remainder: env.arena.alloc(stmt),
continuation: env.arena.alloc(cond),
};
condition_symbol = current_symbol;
current_symbol = env.unique_symbol();
}
*/
cond = Stmt::Let( cond = Stmt::Let(
true_symbol, true_symbol,
Expr::Literal(Literal::Bool(true)), Expr::Literal(Literal::Bool(true)),

View file

@ -553,7 +553,7 @@ impl<'a> Stmt<'a> {
) -> Self { ) -> Self {
let mut layout_cache = LayoutCache::default(); let mut layout_cache = LayoutCache::default();
dbg!(from_can(env, can_expr, procs, &mut layout_cache)) from_can(env, can_expr, procs, &mut layout_cache)
} }
pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A> pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A>
where where
@ -926,7 +926,9 @@ fn specialize<'a>(
debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_))); debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_)));
let specialized_body = from_can(env, body, procs, layout_cache); let ret_symbol = env.unique_symbol();
let hole = env.arena.alloc(Stmt::Ret(ret_symbol));
let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole);
// reset subs, so we don't get type errors when specializing for a different signature // reset subs, so we don't get type errors when specializing for a different signature
env.subs.rollback_to(snapshot); env.subs.rollback_to(snapshot);
@ -1031,7 +1033,13 @@ pub fn with_hole<'a>(
todo!() todo!()
} }
} }
Var(symbol) => Stmt::Ret(symbol), Var(symbol) => {
// A bit ugly, but it does the job
match hole {
Stmt::Jump(id, _) => Stmt::Jump(*id, env.arena.alloc([symbol])),
_ => Stmt::Ret(symbol),
}
}
// Var(symbol) => panic!("reached Var {}", symbol), // Var(symbol) => panic!("reached Var {}", symbol),
Tag { Tag {
variant_var, variant_var,
@ -1206,10 +1214,16 @@ pub fn with_hole<'a>(
let mut can_fields = Vec::with_capacity_in(fields.len(), env.arena); let mut can_fields = Vec::with_capacity_in(fields.len(), env.arena);
for (label, layout) in sorted_fields.into_iter() { for (label, layout) in sorted_fields.into_iter() {
field_symbols.push(env.unique_symbol());
field_layouts.push(layout); field_layouts.push(layout);
let field = fields.remove(&label).unwrap(); let field = fields.remove(&label).unwrap();
can_fields.push(field); let field_symbol = if let roc_can::expr::Expr::Var(symbol) = field.loc_expr.value {
field_symbols.push(symbol);
can_fields.push(None);
} else {
field_symbols.push(env.unique_symbol());
can_fields.push(Some(field));
};
} }
// creating a record from the var will unpack it if it's just a single field. // creating a record from the var will unpack it if it's just a single field.
@ -1220,15 +1234,18 @@ pub fn with_hole<'a>(
let field_symbols = field_symbols.into_bump_slice(); let field_symbols = field_symbols.into_bump_slice();
let mut stmt = Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole); let mut stmt = Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole);
for (field, symbol) in can_fields.into_iter().rev().zip(field_symbols.iter().rev()) { for (opt_field, symbol) in can_fields.into_iter().rev().zip(field_symbols.iter().rev())
stmt = with_hole( {
env, if let Some(field) = opt_field {
field.loc_expr.value, stmt = with_hole(
procs, env,
layout_cache, field.loc_expr.value,
*symbol, procs,
env.arena.alloc(stmt), layout_cache,
); *symbol,
env.arena.alloc(stmt),
);
}
} }
stmt stmt
@ -1317,7 +1334,7 @@ pub fn with_hole<'a>(
branches, branches,
layout_cache, layout_cache,
procs, procs,
Some((id, assigned)), Some(id),
); );
// TODO define condition // TODO define condition
@ -1464,7 +1481,27 @@ pub fn with_hole<'a>(
Accessor { .. } | Update { .. } => todo!("record access/accessor/update"), Accessor { .. } | Update { .. } => todo!("record access/accessor/update"),
Closure(_, _, _, _, _) => todo!("call"), Closure(ann, name, _, loc_args, boxed_body) => {
let (loc_body, ret_var) = *boxed_body;
match procs.insert_anonymous(env, name, ann, loc_args, loc_body, ret_var, layout_cache)
{
Ok(layout) => {
// TODO should the let have layout Pointer?
Stmt::Let(
assigned,
Expr::FunctionPointer(name, layout.clone()),
layout,
hole,
)
}
Err(_error) => Stmt::RuntimeError(
"TODO convert anonymous function error to a RuntimeError string",
),
}
}
Call(boxed, loc_args, _) => { Call(boxed, loc_args, _) => {
let (fn_var, loc_expr, ret_var) = *boxed; let (fn_var, loc_expr, ret_var) = *boxed;
@ -1532,16 +1569,22 @@ pub fn with_hole<'a>(
panic!("TODO turn fn_var into a RuntimeError {:?}", err) panic!("TODO turn fn_var into a RuntimeError {:?}", err)
}); });
let ret_layout = layout_cache
.from_var(env.arena, ret_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| {
panic!("TODO turn fn_var into a RuntimeError {:?}", err)
});
let function_symbol = env.unique_symbol(); let function_symbol = env.unique_symbol();
let arg_symbols = arg_symbols.into_bump_slice(); let arg_symbols = arg_symbols.into_bump_slice();
let mut result = Stmt::Let( let mut result = Stmt::Let(
assigned, assigned,
Expr::FunctionCall { Expr::FunctionCall {
call_type: CallType::ByPointer(function_symbol), call_type: CallType::ByPointer(function_symbol),
layout: layout.clone(), layout,
args: arg_symbols, args: arg_symbols,
}, },
layout, ret_layout,
arena.alloc(hole), arena.alloc(hole),
); );
@ -1614,7 +1657,7 @@ pub fn with_hole<'a>(
result result
} }
RuntimeError(_) => todo!("runtime error"), RuntimeError(e) => todo!("runtime error {:?}", e),
} }
} }
@ -1680,82 +1723,6 @@ pub fn from_can<'a>(
) )
} }
If {
cond_var,
branch_var,
branches,
final_else,
} => {
let mut expr = from_can(env, final_else.value, procs, layout_cache);
let arena = env.arena;
let ret_layout = layout_cache
.from_var(env.arena, branch_var, env.subs, env.pointer_size)
.expect("invalid ret_layout");
let cond_layout = layout_cache
.from_var(env.arena, cond_var, env.subs, env.pointer_size)
.expect("invalid cond_layout");
for (loc_cond, loc_then) in branches.into_iter().rev() {
let branching_symbol = env.unique_symbol();
let then = from_can(env, loc_then.value, procs, layout_cache);
let cond_stmt = Stmt::Cond {
cond_symbol: branching_symbol,
branching_symbol,
cond_layout: cond_layout.clone(),
branching_layout: cond_layout.clone(),
pass: env.arena.alloc(then),
fail: env.arena.alloc(expr),
ret_layout: ret_layout.clone(),
};
// add condition
let hole = env.arena.alloc(cond_stmt);
expr = with_hole(
env,
loc_cond.value,
procs,
layout_cache,
branching_symbol,
hole,
);
}
expr
}
When {
cond_var,
expr_var,
region,
loc_cond,
branches,
} => {
let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value {
symbol
} else {
env.unique_symbol()
};
let mono_when = from_can_when(
env,
cond_var,
expr_var,
region,
cond_symbol,
branches,
layout_cache,
procs,
None,
);
if let roc_can::expr::Expr::Var(_) = loc_cond.value {
mono_when
} else {
let hole = env.arena.alloc(mono_when);
with_hole(env, loc_cond.value, procs, layout_cache, cond_symbol, hole)
}
}
_ => { _ => {
let symbol = env.unique_symbol(); let symbol = env.unique_symbol();
let hole = env.arena.alloc(Stmt::Ret(symbol)); let hole = env.arena.alloc(Stmt::Ret(symbol));
@ -1774,7 +1741,7 @@ fn to_opt_branches<'a>(
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
) -> std::vec::Vec<( ) -> std::vec::Vec<(
Pattern<'a>, Pattern<'a>,
crate::decision_tree2::Guard<'a>, Option<Located<roc_can::expr::Expr>>,
roc_can::expr::Expr, roc_can::expr::Expr,
)> { )> {
debug_assert!(!branches.is_empty()); debug_assert!(!branches.is_empty());
@ -1801,10 +1768,12 @@ fn to_opt_branches<'a>(
exhaustive_guard.clone(), exhaustive_guard.clone(),
)); ));
// TODO implement guard again // TODO remove clone?
let mono_guard = crate::decision_tree2::Guard::NoGuard; opt_branches.push((
mono_pattern,
opt_branches.push((mono_pattern, mono_guard, when_branch.value.value.clone())); when_branch.guard.clone(),
when_branch.value.value.clone(),
));
} }
} }
@ -1840,7 +1809,7 @@ fn to_opt_branches<'a>(
if is_not_exhaustive { if is_not_exhaustive {
opt_branches.push(( opt_branches.push((
Pattern::Underscore, Pattern::Underscore,
crate::decision_tree2::Guard::NoGuard, None,
roc_can::expr::Expr::RuntimeError( roc_can::expr::Expr::RuntimeError(
roc_problem::can::RuntimeError::NonExhaustivePattern, roc_problem::can::RuntimeError::NonExhaustivePattern,
), ),
@ -1861,7 +1830,7 @@ fn from_can_when<'a>(
branches: std::vec::Vec<roc_can::expr::WhenBranch>, branches: std::vec::Vec<roc_can::expr::WhenBranch>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
procs: &mut Procs<'a>, procs: &mut Procs<'a>,
join_point: Option<(JoinPointId, Symbol)>, join_point: Option<JoinPointId>,
) -> Stmt<'a> { ) -> Stmt<'a> {
if branches.is_empty() { if branches.is_empty() {
// A when-expression with no branches is a runtime error. // A when-expression with no branches is a runtime error.
@ -1887,36 +1856,58 @@ fn from_can_when<'a>(
.unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err));
let arena = env.arena; let arena = env.arena;
let it = opt_branches.into_iter().map(|(pattern, guard, can_expr)| { let it = opt_branches
let mut stores = Vec::with_capacity_in(1, env.arena); .into_iter()
let res_stores = .map(|(pattern, opt_guard, can_expr)| {
store_pattern(env, &pattern, cond_symbol, cond_layout.clone(), &mut stores); let mut stores = Vec::with_capacity_in(1, env.arena);
let mut stmt = match join_point { let res_stores =
None => from_can(env, can_expr, procs, layout_cache), store_pattern(env, &pattern, cond_symbol, cond_layout.clone(), &mut stores);
Some((id, _symbol)) => { let mut stmt = match join_point {
let symbol = env.unique_symbol(); None => from_can(env, can_expr, procs, layout_cache),
let arguments = bumpalo::vec![in env.arena; symbol].into_bump_slice(); Some(id) => {
let jump = env.arena.alloc(Stmt::Jump(id, arguments)); let symbol = env.unique_symbol();
let arguments = bumpalo::vec![in env.arena; symbol].into_bump_slice();
let jump = env.arena.alloc(Stmt::Jump(id, arguments));
with_hole(env, can_expr, procs, layout_cache, symbol, jump) with_hole(env, can_expr, procs, layout_cache, symbol, jump)
}
};
match res_stores {
Ok(_) => {
for (symbol, layout, expr) in stores.into_iter().rev() {
stmt = Stmt::Let(symbol, expr, layout, env.arena.alloc(stmt));
} }
};
(pattern, guard, stmt) use crate::decision_tree2::Guard;
match res_stores {
Ok(_) => {
for (symbol, layout, expr) in stores.iter().rev() {
stmt =
Stmt::Let(*symbol, expr.clone(), layout.clone(), env.arena.alloc(stmt));
}
let guard = if let Some(loc_expr) = opt_guard {
let id = JoinPointId(env.unique_symbol());
let symbol = env.unique_symbol();
let jump = env.arena.alloc(Stmt::Jump(id, env.arena.alloc([symbol])));
let mut stmt =
with_hole(env, loc_expr.value, procs, layout_cache, symbol, jump);
// guard must have access to bound values
for (symbol, layout, expr) in stores.into_iter().rev() {
stmt = Stmt::Let(symbol, expr, layout, env.arena.alloc(stmt));
}
Guard::Guard { id, symbol, stmt }
} else {
Guard::NoGuard
};
(pattern, guard, stmt)
}
Err(msg) => (
Pattern::Underscore,
Guard::NoGuard,
Stmt::RuntimeError(env.arena.alloc(msg)),
),
} }
Err(msg) => ( });
Pattern::Underscore,
guard,
Stmt::RuntimeError(env.arena.alloc(msg)),
),
}
});
let mono_branches = Vec::from_iter_in(it, arena); let mono_branches = Vec::from_iter_in(it, arena);
crate::decision_tree2::optimize_when( crate::decision_tree2::optimize_when(

View file

@ -1385,7 +1385,7 @@ mod test_mono {
compiles_to_ir( compiles_to_ir(
r#" r#"
when 2 is when 2 is
2 if True -> 42 2 if False -> 42
_ -> 0 _ -> 0
"#, "#,
indoc!( indoc!(
@ -1586,4 +1586,60 @@ mod test_mono {
), ),
) )
} }
#[test]
fn if_multi_branch() {
compiles_to_ir(
r#"
if True then
1
else if False then
2
else
3
"#,
indoc!(
r#"
procedure List.5 (#Attr.2, #Attr.3):
let Test.3 = lowlevel ListPush #Attr.2 #Attr.3;
ret Test.3;
let Test.4 = 1i64;
let Test.1 = Array [Test.4];
let Test.2 = 2i64;
let Test.0 = CallByName List.5 Test.1 Test.2;
ret Test.0;
"#
),
)
}
#[test]
fn when_on_result() {
compiles_to_ir(
r#"
x : Result Int Int
x = Ok 2
y = when x is
Ok 3 -> 1
Ok _ -> 2
Err _ -> 3
y
"#,
indoc!(
r#"
procedure List.5 (#Attr.2, #Attr.3):
let Test.3 = lowlevel ListPush #Attr.2 #Attr.3;
ret Test.3;
let Test.4 = 1i64;
let Test.1 = Array [Test.4];
let Test.2 = 2i64;
let Test.0 = CallByName List.5 Test.1 Test.2;
ret Test.0;
"#
),
)
}
} }