diff --git a/compiler/alias_analysis/src/lib.rs b/compiler/alias_analysis/src/lib.rs index 7b8c34d6eb..91e331f8d5 100644 --- a/compiler/alias_analysis/src/lib.rs +++ b/compiler/alias_analysis/src/lib.rs @@ -439,6 +439,7 @@ fn stmt_spec<'a>( builder.add_choice(block, &cases) } + Expect { remainder, .. } => stmt_spec(builder, env, block, layout, remainder), Ret(symbol) => Ok(env.symbols[symbol]), Refcounting(modify_rc, continuation) => match modify_rc { ModifyRc::Inc(symbol, _) => { diff --git a/compiler/gen_dev/src/lib.rs b/compiler/gen_dev/src/lib.rs index bdc9aaf6b2..6085df5b0e 100644 --- a/compiler/gen_dev/src/lib.rs +++ b/compiler/gen_dev/src/lib.rs @@ -979,6 +979,9 @@ trait Backend<'a> { self.set_last_seen(*sym, stmt); } } + + Stmt::Expect { .. } => todo!("expect is not implemented in the wasm backend"), + Stmt::RuntimeError(_) => {} } } diff --git a/compiler/gen_llvm/src/llvm/build.rs b/compiler/gen_llvm/src/llvm/build.rs index 50d0bbfe59..19e81be794 100644 --- a/compiler/gen_llvm/src/llvm/build.rs +++ b/compiler/gen_llvm/src/llvm/build.rs @@ -2737,6 +2737,80 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( } } + Expect { + condition: cond, + lookups: _, + layouts: _, + remainder, + } => { + // do stuff + + let bd = env.builder; + let context = env.context; + + let (cond, _cond_layout) = load_symbol_and_layout(scope, cond); + + let condition = bd.build_int_compare( + IntPredicate::EQ, + cond.into_int_value(), + context.bool_type().const_int(1, false), + "is_true", + ); + + let then_block = context.append_basic_block(parent, "then_block"); + let throw_block = context.append_basic_block(parent, "throw_block"); + + bd.build_conditional_branch(condition, then_block, throw_block); + + { + bd.position_at_end(throw_block); + + match env.target_info.ptr_width() { + roc_target::PtrWidth::Bytes8 => { + let func = env + .module + .get_function(bitcode::UTILS_EXPECT_FAILED) + .unwrap(); + // TODO get the actual line info instead of + // hardcoding as zero! + let callable = CallableValue::try_from(func).unwrap(); + let start_line = context.i32_type().const_int(0, false); + let end_line = context.i32_type().const_int(0, false); + let start_col = context.i16_type().const_int(0, false); + let end_col = context.i16_type().const_int(0, false); + + bd.build_call( + callable, + &[ + start_line.into(), + end_line.into(), + start_col.into(), + end_col.into(), + ], + "call_expect_failed", + ); + + bd.build_unconditional_branch(then_block); + } + roc_target::PtrWidth::Bytes4 => { + // temporary WASM implementation + throw_exception(env, "An expectation failed!"); + } + } + } + + bd.position_at_end(then_block); + + build_exp_stmt( + env, + layout_ids, + func_spec_solutions, + scope, + parent, + remainder, + ) + } + RuntimeError(error_msg) => { throw_exception(env, error_msg); @@ -6128,8 +6202,6 @@ fn run_low_level<'a, 'ctx, 'env>( set } ExpectTrue => { - debug_assert_eq!(args.len(), 1); - let context = env.context; let bd = env.builder; diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 2a1a199c9d..dde3745dcf 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -457,6 +457,8 @@ impl<'a> WasmBackend<'a> { Stmt::Refcounting(modify, following) => self.stmt_refcounting(modify, following), + Stmt::Expect { .. } => todo!("expect is not implemented in the wasm backend"), + Stmt::RuntimeError(msg) => self.stmt_runtime_error(msg), } } diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index d6ddd7aa1a..d9437cbc11 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -307,6 +307,9 @@ impl<'a> ParamMap<'a> { Let(_, _, _, cont) => { stack.push(cont); } + + Expect { remainder, .. } => stack.push(remainder), + Switch { branches, default_branch, @@ -835,6 +838,11 @@ impl<'a> BorrowInfState<'a> { } self.collect_stmt(param_map, default_branch.1); } + + Expect { remainder, .. } => { + self.collect_stmt(param_map, remainder); + } + Refcounting(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | RuntimeError(_) => { @@ -1027,6 +1035,9 @@ fn call_info_stmt<'a>(arena: &'a Bump, stmt: &Stmt<'a>, info: &mut CallInfo<'a>) stack.extend(branches.iter().map(|b| &b.2)); stack.push(default_branch.1); } + + Expect { remainder, .. } => stack.push(remainder), + Refcounting(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | Jump(_, _) | RuntimeError(_) => { diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 817a6b02a9..278d8a9a76 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -108,6 +108,15 @@ pub fn occurring_variables(stmt: &Stmt<'_>) -> (MutSet, MutSet) stack.push(cont); } + Expect { + condition, + remainder, + .. + } => { + result.insert(*condition); + stack.push(remainder); + } + Jump(_, arguments) => { result.extend(arguments.iter().copied()); } @@ -1196,6 +1205,8 @@ impl<'a> Context<'a> { (switch, case_live_vars) } + Expect { remainder, .. } => self.visit_stmt(codegen, remainder), + RuntimeError(_) | Refcounting(_, _) => (stmt, MutSet::default()), } } @@ -1299,6 +1310,15 @@ pub fn collect_stmt( collect_stmt(cont, jp_live_vars, vars) } + Expect { + condition, + remainder, + .. + } => { + vars.insert(*condition); + collect_stmt(remainder, jp_live_vars, vars) + } + Join { id: j, parameters, diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 2c85b48ce1..cc60e20fcf 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -1326,6 +1326,13 @@ pub enum Stmt<'a> { }, Ret(Symbol), Refcounting(ModifyRc, &'a Stmt<'a>), + Expect { + condition: Symbol, + lookups: &'a [Symbol], + layouts: &'a [Layout<'a>], + /// what happens after the expect + remainder: &'a Stmt<'a>, + }, /// a join point `join f = in remainder` Join { id: JoinPointId, @@ -1912,6 +1919,10 @@ impl<'a> Stmt<'a> { .append(alloc.hardline()) .append(cont.to_doc(alloc)), + Expect{condition, .. } => + alloc.text("expect ") + .append(symbol_to_doc(alloc, *condition)), + Ret(symbol) => alloc .text("ret ") .append(symbol_to_doc(alloc, *symbol)) @@ -5931,22 +5942,20 @@ pub fn from_can<'a>( Expect { loc_condition, loc_continuation, - lookups_in_cond: _ , + lookups_in_cond, } => { let rest = from_can(env, variable, loc_continuation.value, procs, layout_cache); let cond_symbol = env.unique_symbol(); - let mut stmt = Stmt::Let( - env.unique_symbol(), - Expr::Call(self::Call { - call_type: CallType::LowLevel { - op: LowLevel::ExpectTrue, - update_mode: env.next_update_mode_id(), - }, - arguments: env.arena.alloc([cond_symbol]), - }), - Layout::Builtin(Builtin::Bool), - env.arena.alloc(rest), - ); + + let lookups = Vec::from_iter_in(lookups_in_cond.iter().map(|t| t.0), env.arena); + // let layouts = Vec::from_iter_in(lookups_in_cond.iter().map(|t| t.1), env.arena); + + let mut stmt = Stmt::Expect { + condition: cond_symbol, + lookups: lookups.into_bump_slice(), + layouts: &[], + remainder: env.arena.alloc(rest), + }; stmt = with_hole( env, @@ -5958,7 +5967,6 @@ pub fn from_can<'a>( env.arena.alloc(stmt), ); - stmt } @@ -6300,6 +6308,14 @@ fn substitute_in_stmt_help<'a>( } } + Expect { condition, lookups, layouts, remainder } => { + // TODO should we substitute in the ModifyRc? + match substitute_in_stmt_help(arena, remainder, subs) { + Some(cont) => Some(arena.alloc(Expect { condition: *condition , lookups, layouts, remainder: cont} )), + None => None, + } + } + Jump(id, args) => { let mut did_change = false; let new_args = Vec::from_iter_in( diff --git a/compiler/mono/src/reset_reuse.rs b/compiler/mono/src/reset_reuse.rs index 4937eff16a..2dfcfb0af2 100644 --- a/compiler/mono/src/reset_reuse.rs +++ b/compiler/mono/src/reset_reuse.rs @@ -192,6 +192,30 @@ fn function_s<'a, 'i>( arena.alloc(new_refcounting) } } + + Expect { + condition, + lookups, + layouts, + remainder, + } => { + let continuation: &Stmt = *remainder; + let new_continuation = function_s(env, w, c, continuation); + + if std::ptr::eq(continuation, new_continuation) || continuation == new_continuation { + stmt + } else { + let new_refcounting = Expect { + condition: *condition, + lookups, + layouts, + remainder: new_continuation, + }; + + arena.alloc(new_refcounting) + } + } + Ret(_) | Jump(_, _) | RuntimeError(_) => stmt, } } @@ -388,6 +412,37 @@ fn function_d_main<'a, 'i>( (arena.alloc(refcounting), found) } } + + Expect { + condition, + lookups, + layouts, + remainder, + } => { + let (b, found) = function_d_main(env, x, c, remainder); + + if found || *condition != x { + let refcounting = Expect { + condition: *condition, + lookups, + layouts, + remainder: b, + }; + + (arena.alloc(refcounting), found) + } else { + let b = try_function_s(env, x, c, b); + + let refcounting = Expect { + condition: *condition, + lookups, + layouts, + remainder: b, + }; + + (arena.alloc(refcounting), found) + } + } Join { id, parameters, @@ -540,6 +595,24 @@ fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> arena.alloc(Refcounting(*modify_rc, b)) } + Expect { + condition, + lookups, + layouts, + remainder, + } => { + let b = function_r(env, remainder); + + let expect = Expect { + condition: *condition, + lookups, + layouts, + remainder: b, + }; + + arena.alloc(expect) + } + Ret(_) | Jump(_, _) | RuntimeError(_) => { // terminals stmt @@ -570,6 +643,11 @@ fn has_live_var<'a>(jp_live_vars: &JPLiveVarMap, stmt: &'a Stmt<'a>, needle: Sym Refcounting(modify_rc, cont) => { modify_rc.get_symbol() == needle || has_live_var(jp_live_vars, cont, needle) } + Expect { + condition, + remainder, + .. + } => *condition == needle || has_live_var(jp_live_vars, remainder, needle), Join { id, parameters, diff --git a/compiler/mono/src/tail_recursion.rs b/compiler/mono/src/tail_recursion.rs index e8d70c7828..431060f2ec 100644 --- a/compiler/mono/src/tail_recursion.rs +++ b/compiler/mono/src/tail_recursion.rs @@ -191,6 +191,21 @@ fn insert_jumps<'a>( None => None, }, + Expect { + condition, + lookups, + layouts, + remainder, + } => match insert_jumps(arena, remainder, goal_id, needle) { + Some(cont) => Some(arena.alloc(Expect { + condition: *condition, + lookups, + layouts, + remainder: cont, + })), + None => None, + }, + Ret(_) => None, Jump(_, _) => None, RuntimeError(_) => None, diff --git a/examples/hello-world/zig-platform/helloZig.roc b/examples/hello-world/zig-platform/helloZig.roc index 9ef3b5d311..e723cfcdf4 100644 --- a/examples/hello-world/zig-platform/helloZig.roc +++ b/examples/hello-world/zig-platform/helloZig.roc @@ -3,4 +3,8 @@ app "helloZig" imports [] provides [main] to pf -main = "Hello, World!\n" +x = 42 + +main = + expect x != x + "Hello, World!\n"