diff --git a/crates/wasm_interp/src/frame.rs b/crates/wasm_interp/src/frame.rs index 71db1427bb..3275c1abd5 100644 --- a/crates/wasm_interp/src/frame.rs +++ b/crates/wasm_interp/src/frame.rs @@ -10,12 +10,14 @@ pub struct Frame { pub fn_index: usize, /// Address in the code section where this frame returns to pub return_addr: usize, - /// Depth of the "function block" for this frame - pub function_block_depth: usize, + /// Depth of the "function body block" for this frame + pub body_block_index: usize, /// Offset in the ValueStack where the args & locals begin pub locals_start: usize, /// Number of args & locals in the frame pub locals_count: usize, + /// Expected return type, if any + pub return_type: Option, } impl Frame { @@ -23,22 +25,23 @@ impl Frame { Frame { fn_index: 0, return_addr: 0, - function_block_depth: 0, + body_block_index: 0, locals_start: 0, locals_count: 0, + return_type: None, } } pub fn enter( fn_index: usize, return_addr: usize, - function_block_depth: usize, - arg_type_bytes: &[u8], + body_block_index: usize, + n_args: usize, + return_type: Option, code_bytes: &[u8], value_stack: &mut ValueStack<'_>, pc: &mut usize, ) -> Self { - let n_args = arg_type_bytes.len(); let locals_start = value_stack.depth() - n_args; // Parse local variable declarations in the function header. They're grouped by type. @@ -60,9 +63,10 @@ impl Frame { Frame { fn_index, return_addr, - function_block_depth, + body_block_index, locals_start, locals_count, + return_type, } } diff --git a/crates/wasm_interp/src/instance.rs b/crates/wasm_interp/src/instance.rs index 4a96ab04a8..5e496dc6de 100644 --- a/crates/wasm_interp/src/instance.rs +++ b/crates/wasm_interp/src/instance.rs @@ -4,7 +4,7 @@ use std::iter; use roc_wasm_module::opcodes::OpCode; use roc_wasm_module::parse::{Parse, SkipBytes}; -use roc_wasm_module::sections::{ImportDesc, MemorySection}; +use roc_wasm_module::sections::{ImportDesc, MemorySection, SignatureParamsIter}; use roc_wasm_module::{ExportType, WasmModule}; use roc_wasm_module::{Value, ValueType}; @@ -18,10 +18,18 @@ pub enum Action { Break, } -#[derive(Debug)] -enum Block { - Loop { vstack: usize, start_addr: usize }, - Normal { vstack: usize }, +#[derive(Debug, Clone, Copy)] +enum BlockType { + Loop(usize), // Loop block, with start address to loop back to + Normal, // Block created by `block` instruction + Locals(usize), // Special "block" for locals. Holds function index for debug + FunctionBody(usize), // Special block surrounding the function body. Holds function index for debug +} + +#[derive(Debug, Clone, Copy)] +struct Block { + ty: BlockType, + vstack: usize, } #[derive(Debug, Clone)] @@ -162,14 +170,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { where A: IntoIterator, { - let (fn_index, arg_type_bytes) = + let (fn_index, param_type_iter, ret_type) = self.call_export_help_before_arg_load(self.module, fn_name)?; + let n_args = param_type_iter.len(); - for (i, (value, type_byte)) in arg_values - .into_iter() - .zip(arg_type_bytes.iter().copied()) - .enumerate() - { + for (i, (value, type_byte)) in arg_values.into_iter().zip(param_type_iter).enumerate() { let expected_type = ValueType::from(type_byte); let actual_type = ValueType::from(value); if actual_type != expected_type { @@ -181,7 +186,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.value_stack.push(value); } - self.call_export_help_after_arg_load(self.module, fn_index, arg_type_bytes) + self.call_export_help_after_arg_load(self.module, fn_index, n_args, ret_type) } pub fn call_export_from_cli( @@ -203,11 +208,13 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { // Implement the "basic numbers" CLI // Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI. - let (fn_index, arg_type_bytes) = self.call_export_help_before_arg_load(module, fn_name)?; + let (fn_index, param_type_iter, ret_type) = + self.call_export_help_before_arg_load(module, fn_name)?; + let n_args = param_type_iter.len(); for (value_bytes, type_byte) in arg_strings .iter() .skip(1) // first string is the .wasm filename - .zip(arg_type_bytes.iter().copied()) + .zip(param_type_iter) { use ValueType::*; let value_str = String::from_utf8_lossy(value_bytes); @@ -220,14 +227,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.value_stack.push(value); } - self.call_export_help_after_arg_load(module, fn_index, arg_type_bytes) + self.call_export_help_after_arg_load(module, fn_index, n_args, ret_type) } fn call_export_help_before_arg_load<'m>( &mut self, module: &'m WasmModule<'a>, fn_name: &str, - ) -> Result<(usize, &'m [u8]), String> { + ) -> Result<(usize, SignatureParamsIter<'m>, Option), String> { let fn_index = { let mut export_iter = module.export.exports.iter(); export_iter @@ -270,9 +277,9 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { cursor }; - let arg_type_bytes = { + let (param_type_iter, return_type) = { let signature_index = module.function.signatures[internal_fn_index]; - module.types.look_up_arg_type_bytes(signature_index) + module.types.look_up(signature_index) }; if self.debug_string.is_some() { @@ -284,29 +291,36 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { ); } - Ok((fn_index, arg_type_bytes)) + Ok((fn_index, param_type_iter, return_type)) } fn call_export_help_after_arg_load( &mut self, module: &WasmModule<'a>, fn_index: usize, - arg_type_bytes: &[u8], + n_args: usize, + return_type: Option, ) -> Result, String> { self.previous_frames.clear(); self.blocks.clear(); - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Locals(fn_index), vstack: self.value_stack.depth(), }); self.current_frame = Frame::enter( fn_index, 0, // return_addr self.blocks.len(), - arg_type_bytes, + n_args, + return_type, &module.code.bytes, &mut self.value_stack, &mut self.program_counter, ); + self.blocks.push(Block { + ty: BlockType::FunctionBody(fn_index), + vstack: self.value_stack.depth(), + }); loop { match self.execute_next_instruction(module) { @@ -351,14 +365,16 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { fn do_return(&mut self) -> Action { let Frame { return_addr, - function_block_depth, - locals_start, + body_block_index, .. } = self.current_frame; // Check where in the value stack the current block started let current_block_base = match self.blocks.last() { - Some(Block::Loop { vstack, .. } | Block::Normal { vstack }) => *vstack, + Some(Block { ty, vstack }) => { + debug_assert!(!matches!(ty, BlockType::Locals(_))); + *vstack + } _ => 0, }; @@ -369,14 +385,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { None }; - self.value_stack.truncate(locals_start); + // Throw away all values from arg[0] upward + let locals_block_index = body_block_index - 1; + let locals_block = &self.blocks[locals_block_index]; + self.value_stack.truncate(locals_block.vstack); if let Some(val) = return_value { self.value_stack.push(val); } - self.blocks.truncate(function_block_depth - 1); + // Resume executing at the next instruction in the caller function + let new_block_len = locals_block_index; // don't need a -1 because one is a length and the other is an index! + self.blocks.truncate(new_block_len); self.program_counter = return_addr; - if let Some(caller_frame) = self.previous_frames.pop() { self.current_frame = caller_frame; Action::Continue @@ -416,16 +436,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) { let block_index = self.blocks.len() - 1 - relative_blocks_outward as usize; - match self.blocks[block_index] { - Block::Loop { start_addr, vstack } => { + let Block { ty, vstack } = self.blocks[block_index]; + match ty { + BlockType::Loop(start_addr) => { self.blocks.truncate(block_index + 1); self.value_stack.truncate(vstack); self.program_counter = start_addr; } - Block::Normal { vstack } => { + BlockType::FunctionBody(_) | BlockType::Normal => { self.break_forward(relative_blocks_outward, module); self.value_stack.truncate(vstack); } + BlockType::Locals(_) => unreachable!(), } } @@ -497,16 +519,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { ); } - let arg_type_bytes = module.types.look_up_arg_type_bytes(signature_index); + let (arg_type_iter, ret_type) = module.types.look_up(signature_index); if let Some(import) = opt_import { self.import_arguments.clear(); self.import_arguments - .extend(std::iter::repeat(Value::I64(0)).take(arg_type_bytes.len())); - for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() { + .extend(std::iter::repeat(Value::I64(0)).take(arg_type_iter.len())); + for (i, expected) in arg_type_iter.enumerate().rev() { let arg = self.value_stack.pop(); - - let expected = ValueType::from(type_byte); let actual = ValueType::from(arg); if actual != expected { return Err(Error::ValueStackType(expected, actual)); @@ -535,23 +555,31 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { // advance PC to the start of the local variable declarations u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap(); - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Locals(fn_index), vstack: self.value_stack.depth(), }); - let function_block_depth = self.blocks.len(); + let body_block_index = self.blocks.len(); let mut swap_frame = Frame::enter( fn_index, return_addr, - function_block_depth, - arg_type_bytes, + body_block_index, + arg_type_iter.len(), + ret_type, &module.code.bytes, &mut self.value_stack, &mut self.program_counter, ); std::mem::swap(&mut swap_frame, &mut self.current_frame); self.previous_frames.push(swap_frame); + + self.blocks.push(Block { + ty: BlockType::FunctionBody(fn_index), + vstack: self.value_stack.depth(), + }); } + Ok(()) } @@ -580,21 +608,23 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { NOP => {} BLOCK => { self.fetch_immediate_u32(module); // blocktype (ignored) - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Normal, vstack: self.value_stack.depth(), }); } LOOP => { self.fetch_immediate_u32(module); // blocktype (ignored) - self.blocks.push(Block::Loop { + self.blocks.push(Block { + ty: BlockType::Loop(self.program_counter), vstack: self.value_stack.depth(), - start_addr: self.program_counter, }); } IF => { self.fetch_immediate_u32(module); // blocktype (ignored) let condition = self.value_stack.pop_i32()?; - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Normal, vstack: self.value_stack.depth(), }); if condition == 0 { @@ -647,7 +677,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.do_break(0, module); } END => { - if self.blocks.len() == self.current_frame.function_block_depth { + if self.blocks.len() == (self.current_frame.body_block_index + 1) { // implicit RETURN at end of function action = self.do_return(); implicit_return = true; @@ -1661,9 +1691,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { let is_program_end = self.program_counter == 0; if is_return && !is_program_end { eprintln!( - "returning to function {} at {:06x}\n", + "returning to function {} at {:06x}\nwith stack+locals {:?}\nand block indices {:?}\n", self.current_frame.fn_index, - self.program_counter + self.module.code.section_offset as usize + self.program_counter + self.module.code.section_offset as usize, + &self.value_stack, + &self.blocks ); } else if op_code == CALL || op_code == CALLINDIRECT { eprintln!(); diff --git a/crates/wasm_interp/src/tests/test_basics.rs b/crates/wasm_interp/src/tests/test_basics.rs index d2d3363cb2..00d5665de9 100644 --- a/crates/wasm_interp/src/tests/test_basics.rs +++ b/crates/wasm_interp/src/tests/test_basics.rs @@ -838,12 +838,14 @@ fn test_set_get_local() { let fn_index = 0; let return_addr = 0x1234; let return_block_depth = 0; - let arg_type_bytes = &[]; + let n_args = 0; + let ret_type = Some(ValueType::I32); inst.current_frame = Frame::enter( fn_index, return_addr, return_block_depth, - arg_type_bytes, + n_args, + ret_type, &buffer, &mut inst.value_stack, &mut cursor, @@ -883,12 +885,14 @@ fn test_tee_get_local() { let fn_index = 0; let return_addr = 0x1234; let return_block_depth = 0; - let arg_type_bytes = &[]; + let n_args = 0; + let ret_type = Some(ValueType::I32); inst.current_frame = Frame::enter( fn_index, return_addr, return_block_depth, - arg_type_bytes, + n_args, + ret_type, &buffer, &mut inst.value_stack, &mut cursor, diff --git a/crates/wasm_module/src/sections.rs b/crates/wasm_module/src/sections.rs index be3d22a329..6b6723f836 100644 --- a/crates/wasm_module/src/sections.rs +++ b/crates/wasm_module/src/sections.rs @@ -212,6 +212,46 @@ impl<'a> Serialize for Signature<'a> { } } +#[derive(Debug)] +pub struct SignatureParamsIter<'a> { + bytes: &'a [u8], + index: usize, + end: usize, +} + +impl<'a> Iterator for SignatureParamsIter<'a> { + type Item = ValueType; + + fn next(&mut self) -> Option { + if self.index >= self.end { + None + } else { + self.bytes.get(self.index).map(|b| { + self.index += 1; + ValueType::from(*b) + }) + } + } + + fn size_hint(&self) -> (usize, Option) { + let size = self.end - self.index; + (size, Some(size)) + } +} + +impl<'a> ExactSizeIterator for SignatureParamsIter<'a> {} + +impl<'a> DoubleEndedIterator for SignatureParamsIter<'a> { + fn next_back(&mut self) -> Option { + if self.end == 0 { + None + } else { + self.end -= 1; + self.bytes.get(self.end).map(|b| ValueType::from(*b)) + } + } +} + #[derive(Debug)] pub struct TypeSection<'a> { /// Private. See WasmModule::add_function_signature @@ -258,11 +298,23 @@ impl<'a> TypeSection<'a> { self.bytes.is_empty() } - pub fn look_up_arg_type_bytes(&self, sig_index: u32) -> &[u8] { + pub fn look_up(&'a self, sig_index: u32) -> (SignatureParamsIter<'a>, Option) { let mut offset = self.offsets[sig_index as usize]; offset += 1; // separator - let count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize; - &self.bytes[offset..][..count] + let param_count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize; + let params_iter = SignatureParamsIter { + bytes: &self.bytes[offset..][..param_count], + index: 0, + end: param_count, + }; + offset += param_count; + + let return_type = if self.bytes[offset] == 0 { + None + } else { + Some(ValueType::from(self.bytes[offset + 1])) + }; + (params_iter, return_type) } }