diff --git a/crates/wasm_interp/src/execute.rs b/crates/wasm_interp/src/execute.rs index 19e0a7bc3a..eeefeed80a 100644 --- a/crates/wasm_interp/src/execute.rs +++ b/crates/wasm_interp/src/execute.rs @@ -24,6 +24,7 @@ pub struct ExecutionState<'a> { pub globals: Vec<'a, Value>, pub program_counter: usize, block_depth: u32, + block_loop_addrs: Vec<'a, Option>, import_signatures: Vec<'a, u32>, debug_string: Option, } @@ -41,6 +42,7 @@ impl<'a> ExecutionState<'a> { globals: Vec::from_iter_in(globals, arena), program_counter, block_depth: 0, + block_loop_addrs: Vec::new_in(arena), import_signatures: Vec::new_in(arena), debug_string: Some(String::new()), } @@ -141,6 +143,7 @@ impl<'a> ExecutionState<'a> { globals, program_counter, block_depth: 0, + block_loop_addrs: Vec::new_in(arena), import_signatures, debug_string, }) @@ -198,26 +201,51 @@ impl<'a> ExecutionState<'a> { } } + fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>, op: OpCode) { + let maybe_loop = if matches!(op, OpCode::ELSE) { + None + } else { + let block_index = self.block_loop_addrs.len() - 1 - relative_blocks_outward as usize; + self.block_loop_addrs[block_index].map(|addr| (block_index, addr)) + }; + + match maybe_loop { + Some((block_index, addr)) => { + self.block_loop_addrs.truncate(block_index + 1); + self.block_depth = self.block_loop_addrs.len() as u32; + self.program_counter = addr as usize; + } + None => { + self.break_forward(relative_blocks_outward, module); + } + } + } + fn break_forward(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) { use OpCode::*; + let mut depth = self.block_depth; let target_block_depth = self.block_depth - relative_blocks_outward - 1; loop { let skipped_op = OpCode::from(module.code.bytes[self.program_counter]); OpCode::skip_bytes(&module.code.bytes, &mut self.program_counter).unwrap(); match skipped_op { BLOCK | LOOP | IF => { - self.block_depth += 1; + depth += 1; } END => { - self.block_depth -= 1; - if self.block_depth == target_block_depth { + depth -= 1; + if depth == target_block_depth { break; } } _ => {} } } + while self.block_depth > depth { + self.block_depth -= 1; + self.block_loop_addrs.pop().unwrap(); + } } pub fn execute_next_instruction(&mut self, module: &WasmModule<'a>) -> Action { @@ -245,15 +273,19 @@ impl<'a> ExecutionState<'a> { BLOCK => { self.fetch_immediate_u32(module); // blocktype (ignored) self.block_depth += 1; + self.block_loop_addrs.push(None); } LOOP => { self.fetch_immediate_u32(module); // blocktype (ignored) self.block_depth += 1; + self.block_loop_addrs + .push(Some(self.program_counter as u32)); } IF => { self.fetch_immediate_u32(module); // blocktype (ignored) let condition = self.value_stack.pop_i32(); self.block_depth += 1; + self.block_loop_addrs.push(None); if condition == 0 { let mut depth = self.block_depth; loop { @@ -280,7 +312,7 @@ impl<'a> ExecutionState<'a> { // We only reach this point when we finish executing the "then" block of an IF statement // (For a false condition, we would have skipped past the ELSE when we saw the IF) // We don't want to execute the ELSE block, so we skip it, just like `br 0` would. - self.break_forward(0, module); + self.do_break(0, module, op_code); } END => { if self.block_depth == 0 { @@ -292,13 +324,13 @@ impl<'a> ExecutionState<'a> { } BR => { let relative_blocks_outward = self.fetch_immediate_u32(module); - self.break_forward(relative_blocks_outward, module); + self.do_break(relative_blocks_outward, module, op_code); } BRIF => { let relative_blocks_outward = self.fetch_immediate_u32(module); let condition = self.value_stack.pop_i32(); if condition != 0 { - self.break_forward(relative_blocks_outward, module); + self.do_break(relative_blocks_outward, module, op_code); } } BRTABLE => { @@ -313,7 +345,7 @@ impl<'a> ExecutionState<'a> { } let fallback = self.fetch_immediate_u32(module); let relative_blocks_outward = selected.unwrap_or(fallback); - self.break_forward(relative_blocks_outward, module); + self.do_break(relative_blocks_outward, module, op_code); } RETURN => { action = self.do_return(); diff --git a/crates/wasm_interp/tests/test_opcodes.rs b/crates/wasm_interp/tests/test_opcodes.rs index 5d7d077230..c3a6e643ca 100644 --- a/crates/wasm_interp/tests/test_opcodes.rs +++ b/crates/wasm_interp/tests/test_opcodes.rs @@ -16,8 +16,94 @@ fn default_state(arena: &Bump) -> ExecutionState { ExecutionState::new(arena, pages, program_counter, globals) } -// #[test] -// fn test_loop() {} +#[test] +fn test_loop() { + test_loop_help(10, 55); +} + +fn test_loop_help(end: i32, expected: i32) { + let arena = Bump::new(); + let mut module = WasmModule::new(&arena); + let buf = &mut module.code.bytes; + + // Loop from 0 to end, adding the loop variable to a total + let var_i = 0; + let var_total = 1; + + // (local i32 i32) + buf.push(1); // one group of the given type + buf.push(2); // two locals in the group + buf.push(ValueType::I32 as u8); + + // loop + buf.push(OpCode::LOOP as u8); + buf.push(ValueType::VOID as u8); + + // local.get $i + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_i); + + // i32.const 1 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(1); + + // i32.add + buf.push(OpCode::I32ADD as u8); + + // local.tee $i + buf.push(OpCode::TEELOCAL as u8); + buf.encode_u32(var_i); + + // local.get $total + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_total); + + // i32.add + buf.push(OpCode::I32ADD as u8); + + // local.set $total + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(var_total); + + // local.get $i + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_i); + + // i32.const $end + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(end); + + // i32.lt_s + buf.push(OpCode::I32LTS as u8); + + // br_if 0 + buf.push(OpCode::BRIF as u8); + buf.encode_u32(0); + + // end + buf.push(OpCode::END as u8); + + // local.get $total + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_total); + + // end function + buf.push(OpCode::END as u8); + + let mut state = default_state(&arena); + state.call_stack.push_frame( + 0, + 0, + &[], + &mut state.value_stack, + &module.code.bytes, + &mut state.program_counter, + ); + + while let Action::Continue = state.execute_next_instruction(&module) {} + + assert_eq!(state.value_stack.pop_i32(), expected); +} #[test] fn test_if_else() {