join points with arguments

This commit is contained in:
Folkert 2020-08-05 22:33:07 +02:00
parent b22fa7c9cd
commit c18bbe9a63
6 changed files with 394 additions and 332 deletions

View file

@ -41,7 +41,7 @@ pub enum OptLevel {
#[derive(Default, Debug, Clone, PartialEq)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct Scope<'a, 'ctx> { pub struct Scope<'a, 'ctx> {
symbols: ImMap<Symbol, (Layout<'a>, PointerValue<'ctx>)>, symbols: ImMap<Symbol, (Layout<'a>, PointerValue<'ctx>)>,
join_points: ImMap<JoinPointId, BasicBlock<'ctx>>, join_points: ImMap<JoinPointId, (BasicBlock<'ctx>, &'a [PointerValue<'ctx>])>,
} }
impl<'a, 'ctx> Scope<'a, 'ctx> { impl<'a, 'ctx> Scope<'a, 'ctx> {
@ -604,11 +604,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
result result
} }
Ret(symbol) => { Ret(symbol) => load_symbol(env, scope, symbol),
dbg!(symbol, &scope);
load_symbol(env, scope, symbol)
}
Cond { Cond {
branching_symbol, branching_symbol,
@ -714,11 +710,24 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
let builder = env.builder; let builder = env.builder;
let context = env.context; let context = env.context;
let mut joinpoint_args = Vec::with_capacity_in(arguments.len(), env.arena);
for (_, layout) in arguments.iter() {
let btype = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes);
joinpoint_args.push(create_entry_block_alloca(
env,
parent,
btype,
"joinpointarg",
));
}
// create new block // create new block
let cont_block = context.append_basic_block(parent, "joinpointcont"); let cont_block = context.append_basic_block(parent, "joinpointcont");
// store this join point // store this join point
scope.join_points.insert(*id, cont_block); let joinpoint_args = joinpoint_args.into_bump_slice();
scope.join_points.insert(*id, (cont_block, joinpoint_args));
// construct the blocks that may jump to this join point // construct the blocks that may jump to this join point
build_exp_stmt(env, layout_ids, scope, parent, remainder); build_exp_stmt(env, layout_ids, scope, parent, remainder);
@ -726,14 +735,11 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// remove this join point again // remove this join point again
scope.join_points.remove(&id); scope.join_points.remove(&id);
// Assumptions for (ptr, (argument, layout)) in joinpoint_args.iter().zip(arguments.iter()) {
// scope.insert(*argument, (layout.clone(), *ptr));
// - `remainder` is either a Cond or Switch where }
// - all branches jump to this join point
//
// we should improve this in the future!
let phi_block = builder.get_insert_block().unwrap(); let phi_block = builder.get_insert_block().unwrap();
//builder.build_unconditional_branch(cont_block);
// put the cont block at the back // put the cont block at the back
builder.position_at_end(cont_block); builder.position_at_end(cont_block);
@ -745,15 +751,20 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
result result
} }
Jump(join_point, _arguments) => { Jump(join_point, arguments) => {
let builder = env.builder; let builder = env.builder;
let context = env.context; let context = env.context;
let cont_block = scope.join_points.get(join_point).unwrap(); let (cont_block, argument_pointers) = scope.join_points.get(join_point).unwrap();
let jmp = builder.build_unconditional_branch(*cont_block);
// builder.insert_instruction(&jmp, None); for (pointer, argument) in argument_pointers.iter().zip(arguments.iter()) {
let value = load_symbol(env, scope, argument);
builder.build_store(*pointer, value);
}
builder.build_unconditional_branch(*cont_block);
// This doesn't currently do anything // This doesn't currently do anything
context.i64_type().const_int(0, false).into() context.i64_type().const_zero().into()
} }
_ => todo!("unsupported expr {:?}", stmt), _ => todo!("unsupported expr {:?}", stmt),
} }
@ -1724,21 +1735,30 @@ fn build_switch_ir<'a, 'ctx, 'env>(
let branch_val = build_exp_stmt(env, layout_ids, scope, parent, branch_expr); let branch_val = build_exp_stmt(env, layout_ids, scope, parent, branch_expr);
if block.get_terminator().is_none() {
builder.build_unconditional_branch(cont_block); builder.build_unconditional_branch(cont_block);
incoming.push((branch_val, block)); incoming.push((branch_val, block));
} }
}
// The block for the conditional's default branch. // The block for the conditional's default branch.
builder.position_at_end(default_block); builder.position_at_end(default_block);
let default_val = build_exp_stmt(env, layout_ids, scope, parent, default_branch); let default_val = build_exp_stmt(env, layout_ids, scope, parent, default_branch);
if default_block.get_terminator().is_none() {
builder.build_unconditional_branch(cont_block); builder.build_unconditional_branch(cont_block);
incoming.push((default_val, default_block)); incoming.push((default_val, default_block));
}
// emit merge block // emit merge block
if incoming.is_empty() {
unsafe {
cont_block.delete().unwrap();
}
// produce unused garbage value
context.i64_type().const_zero().into()
} else {
builder.position_at_end(cont_block); builder.position_at_end(cont_block);
let phi = builder.build_phi(ret_type, "branch"); let phi = builder.build_phi(ret_type, "branch");
@ -1749,6 +1769,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
phi.as_basic_value() phi.as_basic_value()
} }
}
fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>( fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,

View file

@ -663,17 +663,74 @@ mod gen_tags {
} }
#[test] #[test]
fn join_points() { fn join_point_if() {
assert_evals_to_ir!( assert_evals_to_ir!(
indoc!( indoc!(
r#" r#"
x = x =
if True then 1 else 2 if True then 1 else 2
5 x
"# "#
), ),
5, 1,
i64
);
}
#[test]
fn join_point_when() {
assert_evals_to_ir!(
indoc!(
r#"
x : [ Red, White, Blue ]
x = Blue
y =
when x is
Red -> 1
White -> 2
Blue -> 3.1
y
"#
),
3.1,
f64
);
}
#[test]
fn join_point_with_cond_expr() {
assert_evals_to_ir!(
indoc!(
r#"
y =
when 1 + 2 is
3 -> 3
1 -> 1
_ -> 0
y
"#
),
3,
i64
);
assert_evals_to_ir!(
indoc!(
r#"
y =
if 1 + 2 > 0 then
3
else
0
y
"#
),
3,
i64 i64
); );
} }

View file

@ -873,7 +873,7 @@ enum Decider<'a, T> {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
enum Choice<'a> { enum Choice<'a> {
Inline(Stores<'a>, Stmt<'a>), Inline(Stmt<'a>),
Jump(Label), Jump(Label),
} }
@ -885,20 +885,17 @@ pub fn optimize_when<'a>(
cond_symbol: Symbol, cond_symbol: Symbol,
cond_layout: Layout<'a>, cond_layout: Layout<'a>,
ret_layout: Layout<'a>, ret_layout: Layout<'a>,
opt_branches: Vec<(Pattern<'a>, Guard<'a>, Stores<'a>, Stmt<'a>)>, opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>,
) -> Stmt<'a> { ) -> Stmt<'a> {
let (patterns, _indexed_branches) = opt_branches let (patterns, _indexed_branches) = opt_branches
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(index, (pattern, guard, stores, branch))| { .map(|(index, (pattern, guard, branch))| {
( ((guard, pattern, index as u64), (index as u64, branch))
(guard, pattern, index as u64),
(index as u64, stores, branch),
)
}) })
.unzip(); .unzip();
let indexed_branches: Vec<(u64, Stores<'a>, Stmt<'a>)> = _indexed_branches; let indexed_branches: Vec<(u64, Stmt<'a>)> = _indexed_branches;
let decision_tree = compile(patterns); let decision_tree = compile(patterns);
let decider = tree_to_decider(decision_tree); let decider = tree_to_decider(decision_tree);
@ -907,9 +904,8 @@ pub fn optimize_when<'a>(
let mut choices = MutMap::default(); let mut choices = MutMap::default();
let mut jumps = Vec::new(); let mut jumps = Vec::new();
for (index, stores, branch) in indexed_branches.into_iter() { for (index, branch) in indexed_branches.into_iter() {
let ((branch_index, choice), opt_jump) = let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch);
create_choices(&target_counts, index, stores, branch);
if let Some(jump) = opt_jump { if let Some(jump) = opt_jump {
jumps.push(jump); jumps.push(jump);
@ -920,7 +916,7 @@ pub fn optimize_when<'a>(
let choice_decider = insert_choices(&choices, decider); let choice_decider = insert_choices(&choices, decider);
let (stores, expr) = decide_to_branching( let expr = decide_to_branching(
env, env,
cond_symbol, cond_symbol,
cond_layout, cond_layout,
@ -932,7 +928,6 @@ pub fn optimize_when<'a>(
// increase the jump counter by the number of jumps in this branching structure // increase the jump counter by the number of jumps in this branching structure
*env.jump_counter += jumps.len() as u64; *env.jump_counter += jumps.len() as u64;
// Expr::Store(stores, env.arena.alloc(expr))
expr expr
} }
@ -1100,6 +1095,7 @@ fn test_to_equality<'a>(
let lhs = Expr::Literal(Literal::Bool(test_bit)); let lhs = Expr::Literal(Literal::Bool(test_bit));
let lhs_symbol = env.unique_symbol(); let lhs_symbol = env.unique_symbol();
let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout);
stores.push((lhs_symbol, Layout::Builtin(Builtin::Int1), lhs));
( (
stores, stores,
@ -1134,21 +1130,21 @@ fn decide_to_branching<'a>(
cond_layout: Layout<'a>, cond_layout: Layout<'a>,
ret_layout: Layout<'a>, ret_layout: Layout<'a>,
decider: Decider<'a, Choice<'a>>, decider: Decider<'a, Choice<'a>>,
jumps: &Vec<(u64, Stores<'a>, Stmt<'a>)>, jumps: &Vec<(u64, Stmt<'a>)>,
) -> (Stores<'a>, Stmt<'a>) { ) -> Stmt<'a> {
use Choice::*; use Choice::*;
use Decider::*; use Decider::*;
match decider { match decider {
Leaf(Jump(label)) => { Leaf(Jump(label)) => {
// we currently inline the jumps: does fewer jumps but produces a larger artifact // we currently inline the jumps: does fewer jumps but produces a larger artifact
let (_, stores, expr) = jumps let (_, expr) = jumps
.iter() .iter()
.find(|(l, _, _)| l == &label) .find(|(l, _)| l == &label)
.expect("jump not in list of jumps"); .expect("jump not in list of jumps");
(stores, expr.clone()) expr.clone()
} }
Leaf(Inline(stores, expr)) => (stores, expr), Leaf(Inline(expr)) => expr,
Chain { Chain {
test_chain, test_chain,
success, success,
@ -1156,7 +1152,7 @@ fn decide_to_branching<'a>(
} => { } => {
// generate a switch based on the test chain // generate a switch based on the test chain
let (pass_stores, mut pass_expr) = decide_to_branching( let pass_expr = decide_to_branching(
env, env,
cond_symbol, cond_symbol,
cond_layout.clone(), cond_layout.clone(),
@ -1165,15 +1161,7 @@ fn decide_to_branching<'a>(
jumps, jumps,
); );
dbg!(&pass_expr); let fail_expr = decide_to_branching(
// TODO remove clone
for (symbol, layout, expr) in pass_stores.iter().cloned().rev() {
println!("{} {:?}", symbol, expr);
pass_expr = Stmt::Let(symbol, expr, layout, env.arena.alloc(pass_expr));
}
let (fail_stores, mut fail_expr) = decide_to_branching(
env, env,
cond_symbol, cond_symbol,
cond_layout.clone(), cond_layout.clone(),
@ -1182,11 +1170,6 @@ fn decide_to_branching<'a>(
jumps, jumps,
); );
// TODO remove clone
for (symbol, layout, expr) in fail_stores.iter().cloned().rev() {
fail_expr = Stmt::Let(symbol, expr, layout, env.arena.alloc(fail_expr));
}
let fail = &*env.arena.alloc(fail_expr); let fail = &*env.arena.alloc(fail_expr);
let pass = &*env.arena.alloc(pass_expr); let pass = &*env.arena.alloc(pass_expr);
@ -1331,7 +1314,7 @@ fn decide_to_branching<'a>(
); );
// (env.arena.alloc(stores), cond) // (env.arena.alloc(stores), cond)
(&[], cond) cond
} }
FanOut { FanOut {
path, path,
@ -1344,7 +1327,7 @@ fn decide_to_branching<'a>(
let (cond, cond_stores_vec, cond_layout) = let (cond, cond_stores_vec, cond_layout) =
path_to_expr_help2(env, cond_symbol, &path, cond_layout); path_to_expr_help2(env, cond_symbol, &path, cond_layout);
let (default_stores, mut default_branch) = decide_to_branching( let default_branch = decide_to_branching(
env, env,
cond_symbol, cond_symbol,
cond_layout.clone(), cond_layout.clone(),
@ -1353,15 +1336,10 @@ fn decide_to_branching<'a>(
jumps, jumps,
); );
// TODO remove clone
for (symbol, layout, expr) in default_stores.iter().cloned() {
default_branch = Stmt::Let(symbol, expr, layout, env.arena.alloc(default_branch));
}
let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena); let mut branches = bumpalo::collections::Vec::with_capacity_in(tests.len(), env.arena);
for (test, decider) in tests { for (test, decider) in tests {
let (stores, mut branch) = decide_to_branching( let branch = decide_to_branching(
env, env,
cond_symbol, cond_symbol,
cond_layout.clone(), cond_layout.clone(),
@ -1370,11 +1348,6 @@ fn decide_to_branching<'a>(
jumps, jumps,
); );
// TODO remove clone
for (symbol, layout, expr) in stores.iter().cloned() {
branch = Stmt::Let(symbol, expr, layout, env.arena.alloc(branch));
}
let tag = match test { let tag = match test {
Test::IsInt(v) => v as u64, Test::IsInt(v) => v as u64,
Test::IsFloat(v) => v as u64, Test::IsFloat(v) => v as u64,
@ -1386,7 +1359,6 @@ fn decide_to_branching<'a>(
branches.push((tag, branch)); branches.push((tag, branch));
} }
dbg!(&branches, &default_branch);
let mut switch = Stmt::Switch { let mut switch = Stmt::Switch {
cond_layout, cond_layout,
@ -1401,7 +1373,7 @@ fn decide_to_branching<'a>(
} }
// make a jump table based on the tests // make a jump table based on the tests
(&[], switch) switch
} }
} }
} }
@ -1601,19 +1573,15 @@ fn count_targets_help(decision_tree: &Decider<u64>, targets: &mut MutMap<u64, u6
fn create_choices<'a>( fn create_choices<'a>(
target_counts: &MutMap<u64, u64>, target_counts: &MutMap<u64, u64>,
target: u64, target: u64,
stores: Stores<'a>,
branch: Stmt<'a>, branch: Stmt<'a>,
) -> ((u64, Choice<'a>), Option<(u64, Stores<'a>, Stmt<'a>)>) { ) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) {
match target_counts.get(&target) { match target_counts.get(&target) {
None => unreachable!( None => unreachable!(
"this should never happen: {:?} not in {:?}", "this should never happen: {:?} not in {:?}",
target, target_counts target, target_counts
), ),
Some(1) => ((target, Choice::Inline(stores, branch)), None), Some(1) => ((target, Choice::Inline(branch)), None),
Some(_) => ( Some(_) => ((target, Choice::Jump(target)), Some((target, branch))),
(target, Choice::Jump(target)),
Some((target, stores, branch)),
),
} }
} }

View file

@ -349,7 +349,7 @@ pub enum Stmt<'a> {
Dec(Symbol, &'a Stmt<'a>), Dec(Symbol, &'a Stmt<'a>),
Join { Join {
id: JoinPointId, id: JoinPointId,
arguments: &'a [Symbol], arguments: &'a [(Symbol, Layout<'a>)],
/// does not contain jumps to this id /// does not contain jumps to this id
continuation: &'a Stmt<'a>, continuation: &'a Stmt<'a>,
/// contains the jumps to this id /// contains the jumps to this id
@ -548,7 +548,7 @@ impl<'a> Stmt<'a> {
) -> Self { ) -> Self {
let mut layout_cache = LayoutCache::default(); let mut layout_cache = LayoutCache::default();
from_can(env, can_expr, procs, &mut layout_cache) dbg!(from_can(env, can_expr, procs, &mut layout_cache))
} }
pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A> pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A>
where where
@ -625,17 +625,23 @@ impl<'a> Stmt<'a> {
arguments, arguments,
continuation, continuation,
remainder, remainder,
} => alloc.intersperse( } => {
let it = arguments.iter().map(|(s, _)| symbol_to_doc(alloc, *s));
alloc.intersperse(
vec![ vec![
remainder.to_doc(alloc, false), remainder.to_doc(alloc, false),
alloc alloc
.text("joinpoint ") .text("joinpoint ")
.append(join_point_to_doc(alloc, *id)) .append(join_point_to_doc(alloc, *id))
.append(" ".repeat(arguments.len().min(1)))
.append(alloc.intersperse(it, alloc.space()))
.append(":"), .append(":"),
continuation.to_doc(alloc, false).indent(4), continuation.to_doc(alloc, false).indent(4),
], ],
alloc.hardline(), alloc.hardline(),
), )
}
Jump(id, arguments) => { Jump(id, arguments) => {
let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s)); let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s));
@ -708,7 +714,7 @@ fn patterns_to_when<'a>(
// are only stores anyway, no branches. // are only stores anyway, no branches.
for (pattern_var, pattern) in patterns.into_iter() { for (pattern_var, pattern) in patterns.into_iter() {
let context = crate::pattern2::Context::BadArg; let context = crate::pattern2::Context::BadArg;
let mono_pattern = from_can_pattern(env, procs, layout_cache, &pattern.value); let mono_pattern = from_can_pattern(env, layout_cache, &pattern.value);
match crate::pattern2::check( match crate::pattern2::check(
pattern.region, pattern.region,
@ -1241,7 +1247,7 @@ pub fn with_hole<'a>(
.expect("invalid cond_layout"); .expect("invalid cond_layout");
let id = JoinPointId(env.unique_symbol()); let id = JoinPointId(env.unique_symbol());
let jump = env.arena.alloc(Stmt::Jump(id, &[])); let jump = env.arena.alloc(Stmt::Jump(id, env.arena.alloc([assigned])));
let mut stmt = with_hole(env, final_else.value, procs, layout_cache, assigned, jump); let mut stmt = with_hole(env, final_else.value, procs, layout_cache, assigned, jump);
@ -1270,18 +1276,66 @@ pub fn with_hole<'a>(
); );
} }
let join = Stmt::Join { let layout = layout_cache
.from_var(env.arena, branch_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err));
Stmt::Join {
id, id,
arguments: &[], arguments: env.arena.alloc([(assigned, layout)]),
remainder: env.arena.alloc(stmt), remainder: env.arena.alloc(stmt),
continuation: hole, continuation: hole,
}; }
// expr
join
} }
When { .. } | If { .. } => todo!("when or if in expression requires join points"), When {
cond_var,
expr_var,
region,
loc_cond,
branches,
} => {
let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value {
symbol
} else {
env.unique_symbol()
};
let id = JoinPointId(env.unique_symbol());
let mut stmt = from_can_when(
env,
cond_var,
expr_var,
region,
cond_symbol,
branches,
layout_cache,
procs,
Some((id, assigned)),
);
// TODO define condition
// define the `when` condition
if let roc_can::expr::Expr::Var(_) = loc_cond.value {
// do nothing
} else {
let hole = env.arena.alloc(stmt);
stmt = with_hole(env, loc_cond.value, procs, layout_cache, cond_symbol, hole);
};
let layout = layout_cache
.from_var(env.arena, expr_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err));
Stmt::Join {
id,
arguments: bumpalo::vec![in env.arena; (assigned, layout)].into_bump_slice(),
remainder: env.arena.alloc(stmt),
continuation: hole,
}
}
List { loc_elems, .. } if loc_elems.is_empty() => { List { loc_elems, .. } if loc_elems.is_empty() => {
// because an empty list has an unknown element type, it is handled differently // because an empty list has an unknown element type, it is handled differently
@ -1687,6 +1741,7 @@ pub fn from_can<'a>(
branches, branches,
layout_cache, layout_cache,
procs, procs,
None,
); );
if let roc_can::expr::Expr::Var(_) = loc_cond.value { if let roc_can::expr::Expr::Var(_) = loc_cond.value {
@ -1704,7 +1759,7 @@ pub fn from_can<'a>(
} }
} }
fn from_can_when<'a>( fn to_opt_branches<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
cond_var: Variable, cond_var: Variable,
expr_var: Variable, expr_var: Variable,
@ -1712,75 +1767,13 @@ fn from_can_when<'a>(
cond_symbol: Symbol, cond_symbol: Symbol,
mut branches: std::vec::Vec<roc_can::expr::WhenBranch>, mut branches: std::vec::Vec<roc_can::expr::WhenBranch>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
procs: &mut Procs<'a>, ) -> std::vec::Vec<(
) -> Stmt<'a> { Pattern<'a>,
if branches.is_empty() { crate::decision_tree2::Guard<'a>,
// A when-expression with no branches is a runtime error. roc_can::expr::Expr,
// We can't know what to return! )> {
Stmt::RuntimeError("Hit a 0-branch when expression") debug_assert!(!branches.is_empty());
} else if branches.len() == 1 && branches[0].patterns.len() == 1 && branches[0].guard.is_none()
{
let first = branches.remove(0);
// A when-expression with exactly 1 branch is essentially a LetNonRec.
// As such, we can compile it direcly to a Store.
let arena = env.arena;
let mut stored = Vec::with_capacity_in(1, arena);
let bound_symbols = first
.patterns
.iter()
.map(|pat| roc_can::pattern::symbols_from_pattern(&pat.value))
.flatten()
.collect::<std::vec::Vec<_>>();
let loc_when_pattern = &first.patterns[0];
let mono_pattern = from_can_pattern(env, procs, layout_cache, &loc_when_pattern.value);
// record pattern matches can have 1 branch and typecheck, but may still not be exhaustive
let guard = if first.guard.is_some() {
Guard::HasGuard
} else {
Guard::NoGuard
};
let context = crate::pattern2::Context::BadCase;
match crate::pattern2::check(
region,
&[(
Located::at(loc_when_pattern.region, mono_pattern.clone()),
guard,
)],
context,
) {
Ok(_) => {}
Err(errors) => {
for error in errors {
env.problems.push(MonoProblem::PatternProblem(error))
}
// panic!("generate runtime error, should probably also optimize this");
}
}
let cond_layout = layout_cache
.from_var(env.arena, cond_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err));
// NOTE this will still store shadowed names.
// that's fine: the branch throws a runtime error anyway
let mut ret = match store_pattern(env, &mono_pattern, cond_symbol, cond_layout, &mut stored)
{
Ok(_) => from_can(env, first.value.value, procs, layout_cache),
Err(message) => Stmt::RuntimeError(env.arena.alloc(message)),
};
for (symbol, layout, expr) in stored.iter().rev().cloned() {
ret = Stmt::Let(symbol, expr, layout, env.arena.alloc(ret));
}
ret
} else {
let cond_layout = layout_cache let cond_layout = layout_cache
.from_var(env.arena, cond_var, env.subs, env.pointer_size) .from_var(env.arena, cond_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err));
@ -1789,8 +1782,6 @@ fn from_can_when<'a>(
let mut opt_branches = std::vec::Vec::new(); let mut opt_branches = std::vec::Vec::new();
for when_branch in branches { for when_branch in branches {
let mono_expr = from_can(env, when_branch.value.value, procs, layout_cache);
let exhaustive_guard = if when_branch.guard.is_some() { let exhaustive_guard = if when_branch.guard.is_some() {
Guard::HasGuard Guard::HasGuard
} else { } else {
@ -1798,92 +1789,23 @@ fn from_can_when<'a>(
}; };
for loc_pattern in when_branch.patterns { for loc_pattern in when_branch.patterns {
let mono_pattern = from_can_pattern(env, procs, layout_cache, &loc_pattern.value); let mono_pattern = from_can_pattern(env, layout_cache, &loc_pattern.value);
loc_branches.push(( loc_branches.push((
Located::at(loc_pattern.region, mono_pattern.clone()), Located::at(loc_pattern.region, mono_pattern.clone()),
exhaustive_guard.clone(), exhaustive_guard.clone(),
)); ));
let mut stores = Vec::with_capacity_in(1, env.arena); // TODO implement guard again
let mono_guard = crate::decision_tree2::Guard::NoGuard;
let (mono_guard, stores, expr) = match store_pattern( opt_branches.push((mono_pattern, mono_guard, when_branch.value.value.clone()));
env,
&mono_pattern,
cond_symbol,
cond_layout.clone(),
&mut stores,
) {
Ok(_) => {
// if the branch is guarded, the guard can use variables bound in the
// pattern. They must be available, so we give the stores to the
// decision_tree. A branch with guard can only be entered with the guard
// evaluated, so variables will also be loaded in the branch's body expr.
//
// otherwise, we modify the branch's expression to include the stores
if let Some(loc_guard) = when_branch.guard.clone() {
let guard_symbol = env.unique_symbol();
let id = JoinPointId(env.unique_symbol());
let hole = env.arena.alloc(Stmt::Jump(id, &[]));
let mut stmt = with_hole(
env,
loc_guard.value,
procs,
layout_cache,
guard_symbol,
hole,
);
for (symbol, expr, layout) in stores.into_iter().rev() {
stmt = Stmt::Let(symbol, layout, expr, env.arena.alloc(stmt));
}
(
crate::decision_tree2::Guard::Guard {
stmt,
id,
symbol: guard_symbol,
},
&[] as &[_],
mono_expr.clone(),
)
} else {
(
crate::decision_tree2::Guard::NoGuard,
stores.into_bump_slice(),
mono_expr.clone(),
)
}
}
Err(message) => {
// when the pattern is invalid, a guard must give a runtime error too
if when_branch.guard.is_some() {
/*
(
crate::decision_tree2::Guard::Guard {
stores: &[],
expr: Stmt::RuntimeError(env.arena.alloc(message)),
},
&[] as &[_],
// we can never hit this
Stmt::RuntimeError(&"invalid pattern with guard: unreachable"),
)
*/
todo!()
} else {
(
crate::decision_tree2::Guard::NoGuard,
&[] as &[_],
Stmt::RuntimeError(env.arena.alloc(message)),
)
}
}
};
opt_branches.push((mono_pattern, mono_guard, stores, expr));
} }
} }
// NOTE exhaustiveness is checked after the construction of all the branches
// In contrast to elm (currently), we still do codegen even if a pattern is non-exhaustive.
// So we not only report exhaustiveness errors, but also correct them
let context = crate::pattern2::Context::BadCase; let context = crate::pattern2::Context::BadCase;
match crate::pattern2::check(region, &loc_branches, context) { match crate::pattern2::check(region, &loc_branches, context) {
Ok(_) => {} Ok(_) => {}
@ -1914,41 +1836,105 @@ fn from_can_when<'a>(
opt_branches.push(( opt_branches.push((
Pattern::Underscore, Pattern::Underscore,
crate::decision_tree2::Guard::NoGuard, crate::decision_tree2::Guard::NoGuard,
&[], roc_can::expr::Expr::RuntimeError(
Stmt::RuntimeError("non-exhaustive pattern match"), roc_problem::can::RuntimeError::NonExhaustivePattern,
),
)); ));
} }
} }
} }
opt_branches
}
fn from_can_when<'a>(
env: &mut Env<'a, '_>,
cond_var: Variable,
expr_var: Variable,
region: Region,
cond_symbol: Symbol,
branches: std::vec::Vec<roc_can::expr::WhenBranch>,
layout_cache: &mut LayoutCache<'a>,
procs: &mut Procs<'a>,
join_point: Option<(JoinPointId, Symbol)>,
) -> Stmt<'a> {
if branches.is_empty() {
// A when-expression with no branches is a runtime error.
// We can't know what to return!
return Stmt::RuntimeError("Hit a 0-branch when expression");
}
let opt_branches = to_opt_branches(
env,
cond_var,
expr_var,
region,
cond_symbol,
branches,
layout_cache,
);
let cond_layout = layout_cache
.from_var(env.arena, cond_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err));
let ret_layout = layout_cache let ret_layout = layout_cache
.from_var(env.arena, expr_var, env.subs, env.pointer_size) .from_var(env.arena, expr_var, env.subs, env.pointer_size)
.unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err));
let arena = env.arena;
let it = opt_branches.into_iter().map(|(pattern, guard, can_expr)| {
let mut stores = Vec::with_capacity_in(1, env.arena);
let res_stores =
store_pattern(env, &pattern, cond_symbol, cond_layout.clone(), &mut stores);
let mut stmt = match join_point {
None => from_can(env, can_expr, procs, layout_cache),
Some((id, _symbol)) => {
let symbol = env.unique_symbol();
let arguments = bumpalo::vec![in env.arena; symbol].into_bump_slice();
let jump = env.arena.alloc(Stmt::Jump(id, arguments));
with_hole(env, can_expr, procs, layout_cache, symbol, jump)
}
};
match res_stores {
Ok(_) => {
for (symbol, layout, expr) in stores.into_iter().rev() {
stmt = Stmt::Let(symbol, expr, layout, env.arena.alloc(stmt));
}
(pattern, guard, stmt)
}
Err(msg) => (
Pattern::Underscore,
guard,
Stmt::RuntimeError(env.arena.alloc(msg)),
),
}
});
let mono_branches = Vec::from_iter_in(it, arena);
crate::decision_tree2::optimize_when( crate::decision_tree2::optimize_when(
env, env,
cond_symbol, cond_symbol,
cond_layout.clone(), cond_layout.clone(),
ret_layout, ret_layout,
opt_branches, mono_branches,
) )
} }
}
fn store_pattern<'a>( fn store_pattern<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
can_pat: &Pattern<'a>, can_pat: &Pattern<'a>,
outer_symbol: Symbol, outer_symbol: Symbol,
layout: Layout<'a>, _layout: Layout<'a>,
stored: &mut Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>, stored: &mut Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
) -> Result<(), String> { ) -> Result<(), String> {
use Pattern::*; use Pattern::*;
match can_pat { match can_pat {
Identifier(symbol) => { Identifier(symbol) => {
// let load = Expr::Load(outer_symbol); // TODO surely something should happen here?
// stored.push((*symbol, layout, load))
// todo!()
} }
Underscore => { Underscore => {
// Since _ is never read, it's safe to reassign it. // Since _ is never read, it's safe to reassign it.
@ -2343,7 +2329,7 @@ pub struct RecordDestruct<'a> {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum DestructType<'a> { pub enum DestructType<'a> {
Required, Required,
Optional(Stmt<'a>), Optional(roc_can::expr::Expr),
Guard(Pattern<'a>), Guard(Pattern<'a>),
} }
@ -2356,7 +2342,6 @@ pub struct WhenBranch<'a> {
pub fn from_can_pattern<'a>( pub fn from_can_pattern<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
can_pattern: &roc_can::pattern::Pattern, can_pattern: &roc_can::pattern::Pattern,
) -> Pattern<'a> { ) -> Pattern<'a> {
@ -2466,7 +2451,7 @@ pub fn from_can_pattern<'a>(
let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena);
for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) { for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) {
mono_args.push(( mono_args.push((
from_can_pattern(env, procs, layout_cache, &loc_pat.value), from_can_pattern(env, layout_cache, &loc_pat.value),
layout.clone(), layout.clone(),
)); ));
} }
@ -2508,7 +2493,7 @@ pub fn from_can_pattern<'a>(
let it = argument_layouts[1..].iter(); let it = argument_layouts[1..].iter();
for ((_, loc_pat), layout) in arguments.iter().zip(it) { for ((_, loc_pat), layout) in arguments.iter().zip(it) {
mono_args.push(( mono_args.push((
from_can_pattern(env, procs, layout_cache, &loc_pat.value), from_can_pattern(env, layout_cache, &loc_pat.value),
layout.clone(), layout.clone(),
)); ));
} }
@ -2561,7 +2546,6 @@ pub fn from_can_pattern<'a>(
mono_destructs.push(from_can_record_destruct( mono_destructs.push(from_can_record_destruct(
env, env,
procs,
layout_cache, layout_cache,
&destruct.value, &destruct.value,
field_layout.clone(), field_layout.clone(),
@ -2597,7 +2581,6 @@ pub fn from_can_pattern<'a>(
fn from_can_record_destruct<'a>( fn from_can_record_destruct<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
can_rd: &roc_can::pattern::RecordDestruct, can_rd: &roc_can::pattern::RecordDestruct,
field_layout: Layout<'a>, field_layout: Layout<'a>,
@ -2609,11 +2592,11 @@ fn from_can_record_destruct<'a>(
typ: match &can_rd.typ { typ: match &can_rd.typ {
roc_can::pattern::DestructType::Required => DestructType::Required, roc_can::pattern::DestructType::Required => DestructType::Required,
roc_can::pattern::DestructType::Optional(_, loc_expr) => { roc_can::pattern::DestructType::Optional(_, loc_expr) => {
DestructType::Optional(from_can(env, loc_expr.value.clone(), procs, layout_cache)) DestructType::Optional(loc_expr.value.clone())
}
roc_can::pattern::DestructType::Guard(_, loc_pattern) => {
DestructType::Guard(from_can_pattern(env, layout_cache, &loc_pattern.value))
} }
roc_can::pattern::DestructType::Guard(_, loc_pattern) => DestructType::Guard(
from_can_pattern(env, procs, layout_cache, &loc_pattern.value),
),
}, },
} }
} }

View file

@ -1555,4 +1555,35 @@ mod test_mono {
), ),
) )
} }
#[test]
fn when_joinpoint() {
compiles_to_ir(
r#"
x : [ Red, White, Blue ]
x = Blue
y =
when x is
Red -> 1
White -> 2
Blue -> 3
y
"#,
indoc!(
r#"
procedure List.5 (#Attr.2, #Attr.3):
let Test.3 = lowlevel ListPush #Attr.2 #Attr.3;
ret Test.3;
let Test.4 = 1i64;
let Test.1 = Array [Test.4];
let Test.2 = 2i64;
let Test.0 = CallByName List.5 Test.1 Test.2;
ret Test.0;
"#
),
)
}
} }

View file

@ -123,6 +123,8 @@ pub enum RuntimeError {
InvalidInt(IntErrorKind, Base, Region, Box<str>), InvalidInt(IntErrorKind, Base, Region, Box<str>),
CircularDef(Vec<Symbol>, Vec<(Region /* pattern */, Region /* expr */)>), CircularDef(Vec<Symbol>, Vec<(Region /* pattern */, Region /* expr */)>),
NonExhaustivePattern,
/// When the author specifies a type annotation but no implementation /// When the author specifies a type annotation but no implementation
NoImplementation, NoImplementation,
} }