diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 94877e4ad0..a1f595fcf0 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -3150,118 +3150,123 @@ pub fn with_hole<'a>( branches, final_else, } => { - let ret_layout = layout_cache - .from_var(env.arena, branch_var, env.subs) - .expect("invalid ret_layout"); - let cond_layout = layout_cache - .from_var(env.arena, cond_var, env.subs) - .expect("invalid cond_layout"); + match ( + layout_cache.from_var(env.arena, branch_var, env.subs), + layout_cache.from_var(env.arena, cond_var, env.subs), + ) { + (Ok(ret_layout), Ok(cond_layout)) => { + // if the hole is a return, then we don't need to merge the two + // branches together again, we can just immediately return + let is_terminated = matches!(hole, Stmt::Ret(_)); - // if the hole is a return, then we don't need to merge the two - // branches together again, we can just immediately return - let is_terminated = matches!(hole, Stmt::Ret(_)); + if is_terminated { + let terminator = hole; - if is_terminated { - let terminator = hole; + let mut stmt = with_hole( + env, + final_else.value, + branch_var, + procs, + layout_cache, + assigned, + terminator, + ); - let mut stmt = with_hole( - env, - final_else.value, - branch_var, - procs, - layout_cache, - assigned, - terminator, - ); + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); - for (loc_cond, loc_then) in branches.into_iter().rev() { - let branching_symbol = env.unique_symbol(); + let then = with_hole( + env, + loc_then.value, + branch_var, + procs, + layout_cache, + assigned, + terminator, + ); - let then = with_hole( - env, - loc_then.value, - branch_var, - procs, - layout_cache, - assigned, - terminator, - ); + stmt = cond(env, branching_symbol, cond_layout, then, stmt, ret_layout); - stmt = cond(env, branching_symbol, cond_layout, then, stmt, ret_layout); + // add condition + stmt = with_hole( + env, + loc_cond.value, + cond_var, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + stmt + } else { + let assigned_in_jump = env.unique_symbol(); + let id = JoinPointId(env.unique_symbol()); - // add condition - stmt = with_hole( - env, - loc_cond.value, - cond_var, - procs, - layout_cache, - branching_symbol, - env.arena.alloc(stmt), - ); - } - stmt - } else { - let assigned_in_jump = env.unique_symbol(); - let id = JoinPointId(env.unique_symbol()); - - let terminator = env - .arena - .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); - - let mut stmt = with_hole( - env, - final_else.value, - branch_var, - procs, - layout_cache, - assigned_in_jump, - terminator, - ); - - for (loc_cond, loc_then) in branches.into_iter().rev() { - let branching_symbol = possible_reuse_symbol(env, procs, &loc_cond.value); - - let then = with_hole( - env, - loc_then.value, - branch_var, - procs, - layout_cache, - assigned_in_jump, - terminator, - ); - - stmt = cond(env, branching_symbol, cond_layout, then, stmt, ret_layout); - - // add condition - stmt = assign_to_symbol( - env, - procs, - layout_cache, - cond_var, - loc_cond, - branching_symbol, - stmt, - ); - } - - let layout = layout_cache - .from_var(env.arena, branch_var, env.subs) - .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); - - let param = Param { - symbol: assigned, - layout, - borrow: false, - }; - - Stmt::Join { - id, - parameters: env.arena.alloc([param]), - remainder: env.arena.alloc(stmt), - body: hole, + let terminator = env + .arena + .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); + + let mut stmt = with_hole( + env, + final_else.value, + branch_var, + procs, + layout_cache, + assigned_in_jump, + terminator, + ); + + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = + possible_reuse_symbol(env, procs, &loc_cond.value); + + let then = with_hole( + env, + loc_then.value, + branch_var, + procs, + layout_cache, + assigned_in_jump, + terminator, + ); + + stmt = cond(env, branching_symbol, cond_layout, then, stmt, ret_layout); + + // add condition + stmt = assign_to_symbol( + env, + procs, + layout_cache, + cond_var, + loc_cond, + branching_symbol, + stmt, + ); + } + + let layout = layout_cache + .from_var(env.arena, branch_var, env.subs) + .unwrap_or_else(|err| { + panic!("TODO turn fn_var into a RuntimeError {:?}", err) + }); + + let param = Param { + symbol: assigned, + layout, + borrow: false, + }; + + Stmt::Join { + id, + parameters: env.arena.alloc([param]), + remainder: env.arena.alloc(stmt), + body: hole, + } + } } + (Err(_), _) => Stmt::RuntimeError("invalid ret_layout"), + (_, Err(_)) => Stmt::RuntimeError("invalid cond_layout"), } }