diff --git a/crates/wasm_interp/src/call_stack.rs b/crates/wasm_interp/src/call_stack.rs deleted file mode 100644 index f3cf4d23cc..0000000000 --- a/crates/wasm_interp/src/call_stack.rs +++ /dev/null @@ -1,344 +0,0 @@ -use bumpalo::{collections::Vec, Bump}; -use roc_wasm_module::opcodes::OpCode; -use roc_wasm_module::sections::ImportDesc; -use roc_wasm_module::{parse::Parse, Value, ValueType, WasmModule}; -use std::fmt::{self, Write}; -use std::iter::repeat; - -use crate::{pc_to_fn_index, Error, ValueStack}; - -/// Struct-of-Arrays storage for the call stack. -/// Type info is packed to avoid wasting space on padding. -/// However we store 64 bits for every local, even 32-bit values, for easy random access. -#[derive(Debug)] -pub struct CallStack<'a> { - /// return addresses and nested block depths (one entry per frame) - return_addrs_and_block_depths: Vec<'a, (u32, u32)>, - /// frame offsets into the `locals`, `is_float`, and `is_64` vectors (one entry per frame) - frame_offsets: Vec<'a, u32>, - /// base size of the value stack before executing (one entry per frame) - value_stack_bases: Vec<'a, u32>, - /// local variables (one entry per local) - locals: Vec<'a, Value>, -} - -impl<'a> CallStack<'a> { - pub fn new(arena: &'a Bump) -> Self { - CallStack { - return_addrs_and_block_depths: Vec::with_capacity_in(256, arena), - frame_offsets: Vec::with_capacity_in(256, arena), - value_stack_bases: Vec::with_capacity_in(256, arena), - locals: Vec::with_capacity_in(16 * 256, arena), - } - } - - /// On entering a Wasm call, save the return address, and make space for locals - pub(crate) fn push_frame( - &mut self, - return_addr: u32, - return_block_depth: u32, - arg_type_bytes: &[u8], - value_stack: &mut ValueStack<'a>, - code_bytes: &[u8], - pc: &mut usize, - ) -> Result<(), crate::Error> { - self.return_addrs_and_block_depths - .push((return_addr, return_block_depth)); - let frame_offset = self.locals.len(); - self.frame_offsets.push(frame_offset as u32); - - // Make space for arguments - let n_args = arg_type_bytes.len(); - self.locals.extend(repeat(Value::I64(0)).take(n_args)); - - // Pop arguments off the value stack and into locals - for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() { - let arg = value_stack.pop(); - let ty = ValueType::from(arg); - let expected_type = ValueType::from(type_byte); - if ty != expected_type { - return Err(Error::ValueStackType(expected_type, ty)); - } - self.set_local_help(i as u32, arg); - } - - self.value_stack_bases.push(value_stack.depth() as u32); - - // Parse local variable declarations in the function header. They're grouped by type. - let local_group_count = u32::parse((), code_bytes, pc).unwrap(); - for _ in 0..local_group_count { - let (group_size, ty) = <(u32, ValueType)>::parse((), code_bytes, pc).unwrap(); - let n = group_size as usize; - let zero = match ty { - ValueType::I32 => Value::I32(0), - ValueType::I64 => Value::I64(0), - ValueType::F32 => Value::F32(0.0), - ValueType::F64 => Value::F64(0.0), - }; - self.locals.extend(repeat(zero).take(n)); - } - Ok(()) - } - - /// On returning from a Wasm call, drop its locals and retrieve the return address - pub fn pop_frame(&mut self) -> Option<(u32, u32)> { - let frame_offset = self.frame_offsets.pop()? as usize; - self.value_stack_bases.pop()?; - self.locals.truncate(frame_offset); - self.return_addrs_and_block_depths.pop() - } - - pub fn get_local(&self, local_index: u32) -> Value { - self.get_local_help(self.frame_offsets.len() - 1, local_index) - } - - fn get_local_help(&self, frame_index: usize, local_index: u32) -> Value { - let frame_offset = self.frame_offsets[frame_index]; - let index = (frame_offset + local_index) as usize; - self.locals[index] - } - - pub(crate) fn set_local(&mut self, local_index: u32, value: Value) -> Result<(), Error> { - let expected_type = self.set_local_help(local_index, value); - let actual_type = ValueType::from(value); - if actual_type == expected_type { - Ok(()) - } else { - Err(Error::ValueStackType(expected_type, actual_type)) - } - } - - fn set_local_help(&mut self, local_index: u32, value: Value) -> ValueType { - let frame_offset = *self.frame_offsets.last().unwrap(); - let index = (frame_offset + local_index) as usize; - let old_value = self.locals[index]; - self.locals[index] = value; - ValueType::from(old_value) - } - - pub fn value_stack_base(&self) -> u32 { - *self.value_stack_bases.last().unwrap_or(&0) - } - - pub fn is_empty(&self) -> bool { - self.frame_offsets.is_empty() - } - - /// Dump a stack trace of the WebAssembly program - /// - /// -------------- - /// function 123 - /// address 0x12345 - /// args 0: I64(234), 1: F64(7.15) - /// locals 2: I32(412), 3: F64(3.14) - /// stack [I64(111), F64(3.14)] - /// -------------- - pub fn dump_trace( - &self, - module: &WasmModule<'a>, - value_stack: &ValueStack<'a>, - pc: usize, - buffer: &mut String, - ) -> fmt::Result { - let divider = "-------------------"; - writeln!(buffer, "{}", divider)?; - - let mut value_stack_iter = value_stack.iter(); - - for frame in 0..self.frame_offsets.len() { - let next_frame = frame + 1; - let op_offset = if next_frame < self.frame_offsets.len() { - // return address of next frame = next op in this frame - let next_op = self.return_addrs_and_block_depths[next_frame].0 as usize; - // Call address is more intuitive than the return address when debugging. Search backward for it. - // Skip last byte of function index to avoid a false match with CALL/CALLINDIRECT. - // The more significant bytes won't match because of LEB-128 encoding. - let mut call_op = next_op - 2; - loop { - let byte = module.code.bytes[call_op]; - if byte == OpCode::CALL as u8 || byte == OpCode::CALLINDIRECT as u8 { - break; - } else { - call_op -= 1; - } - } - call_op - } else { - pc - }; - - let fn_index = pc_to_fn_index(op_offset, module); - let address = op_offset + module.code.section_offset as usize; - writeln!(buffer, "function {}", fn_index)?; - writeln!(buffer, " address {:06x}", address)?; // format matches wasm-objdump, for easy search - - write!(buffer, " args ")?; - let arg_count = { - let n_import_fns = module.import.imports.len(); - let signature_index = if fn_index < n_import_fns { - match module.import.imports[fn_index].description { - ImportDesc::Func { signature_index } => signature_index, - _ => unreachable!(), - } - } else { - module.function.signatures[fn_index - n_import_fns] - }; - module.types.look_up_arg_type_bytes(signature_index).len() - }; - let args_and_locals_count = { - let frame_offset = self.frame_offsets[frame] as usize; - let next_frame_offset = if frame == self.frame_offsets.len() - 1 { - self.locals.len() - } else { - self.frame_offsets[frame + 1] as usize - }; - next_frame_offset - frame_offset - }; - for index in 0..args_and_locals_count { - let value = self.get_local_help(frame, index as u32); - if index != 0 { - write!(buffer, ", ")?; - } - if index == arg_count { - write!(buffer, "\n locals ")?; - } - write!(buffer, "{}: {:?}", index, value)?; - } - write!(buffer, "\n stack [")?; - - let frame_value_count = { - let value_stack_base = self.value_stack_bases[frame]; - let next_value_stack_base = if frame == self.frame_offsets.len() - 1 { - value_stack.depth() as u32 - } else { - self.value_stack_bases[frame + 1] - }; - next_value_stack_base - value_stack_base - }; - for i in 0..frame_value_count { - if i != 0 { - write!(buffer, ", ")?; - } - if let Some(value) = value_stack_iter.next() { - write!(buffer, "{:?}", value)?; - } - } - - writeln!(buffer, "]")?; - writeln!(buffer, "{}", divider)?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use roc_wasm_module::Serialize; - - use super::*; - - const RETURN_ADDR: u32 = 0x12345; - - fn test_get_set(call_stack: &mut CallStack<'_>, index: u32, value: Value) { - call_stack.set_local(index, value).unwrap(); - assert_eq!(call_stack.get_local(index), value); - } - - fn setup<'a>(arena: &'a Bump, call_stack: &mut CallStack<'a>) { - let mut buffer = vec![]; - let mut cursor = 0; - let mut vs = ValueStack::new(arena); - - // Push a other few frames before the test frame, just to make the scenario more typical. - [(1u32, ValueType::I32)].serialize(&mut buffer); - call_stack - .push_frame(0x11111, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); - - [(2u32, ValueType::I32)].serialize(&mut buffer); - call_stack - .push_frame(0x22222, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); - - [(3u32, ValueType::I32)].serialize(&mut buffer); - call_stack - .push_frame(0x33333, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); - - // Create a test call frame with local variables of every type - [ - (8u32, ValueType::I32), - (4u32, ValueType::I64), - (2u32, ValueType::F32), - (1u32, ValueType::F64), - ] - .serialize(&mut buffer); - call_stack - .push_frame(RETURN_ADDR, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); - } - - #[test] - fn test_all() { - let arena = Bump::new(); - let mut call_stack = CallStack::new(&arena); - - setup(&arena, &mut call_stack); - - test_get_set(&mut call_stack, 0, Value::I32(123)); - test_get_set(&mut call_stack, 8, Value::I64(123456)); - test_get_set(&mut call_stack, 12, Value::F32(1.01)); - test_get_set(&mut call_stack, 14, Value::F64(-1.1)); - - test_get_set(&mut call_stack, 0, Value::I32(i32::MIN)); - test_get_set(&mut call_stack, 0, Value::I32(i32::MAX)); - - test_get_set(&mut call_stack, 8, Value::I64(i64::MIN)); - test_get_set(&mut call_stack, 8, Value::I64(i64::MAX)); - - test_get_set(&mut call_stack, 12, Value::F32(f32::MIN)); - test_get_set(&mut call_stack, 12, Value::F32(f32::MAX)); - - test_get_set(&mut call_stack, 14, Value::F64(f64::MIN)); - test_get_set(&mut call_stack, 14, Value::F64(f64::MAX)); - - assert_eq!(call_stack.pop_frame(), Some((RETURN_ADDR, 0))); - } - - #[test] - #[should_panic] - fn test_type_error_i32() { - let arena = Bump::new(); - let mut call_stack = CallStack::new(&arena); - setup(&arena, &mut call_stack); - test_get_set(&mut call_stack, 0, Value::F32(1.01)); - } - - #[test] - #[should_panic] - fn test_type_error_i64() { - let arena = Bump::new(); - let mut call_stack = CallStack::new(&arena); - setup(&arena, &mut call_stack); - test_get_set(&mut call_stack, 8, Value::F32(1.01)); - } - - #[test] - #[should_panic] - fn test_type_error_f32() { - let arena = Bump::new(); - let mut call_stack = CallStack::new(&arena); - setup(&arena, &mut call_stack); - test_get_set(&mut call_stack, 12, Value::I32(123)); - } - - #[test] - #[should_panic] - fn test_type_error_f64() { - let arena = Bump::new(); - let mut call_stack = CallStack::new(&arena); - setup(&arena, &mut call_stack); - test_get_set(&mut call_stack, 14, Value::I32(123)); - } -} diff --git a/crates/wasm_interp/src/frame.rs b/crates/wasm_interp/src/frame.rs new file mode 100644 index 0000000000..cbdd36bf5e --- /dev/null +++ b/crates/wasm_interp/src/frame.rs @@ -0,0 +1,82 @@ +use roc_wasm_module::{parse::Parse, Value, ValueType}; +use std::iter::repeat; + +use crate::value_stack::ValueStack; + +#[derive(Debug)] +pub struct Frame { + /// The function this frame belongs to + pub fn_index: usize, + /// Address in the code section where this frame returns to + pub return_addr: usize, + /// Depth of the "function body block" for this frame + pub body_block_index: usize, + /// Offset in the ValueStack where the args & locals begin + pub locals_start: usize, + /// Number of args & locals in the frame + pub locals_count: usize, + /// Expected return type, if any + pub return_type: Option, +} + +impl Frame { + pub fn new() -> Self { + Frame { + fn_index: 0, + return_addr: 0, + body_block_index: 0, + locals_start: 0, + locals_count: 0, + return_type: None, + } + } + + #[allow(clippy::too_many_arguments)] + pub fn enter( + fn_index: usize, + return_addr: usize, + body_block_index: usize, + n_args: usize, + return_type: Option, + code_bytes: &[u8], + value_stack: &mut ValueStack<'_>, + pc: &mut usize, + ) -> Self { + let locals_start = value_stack.depth() - n_args; + + // Parse local variable declarations in the function header. They're grouped by type. + let local_group_count = u32::parse((), code_bytes, pc).unwrap(); + for _ in 0..local_group_count { + let (group_size, ty) = <(u32, ValueType)>::parse((), code_bytes, pc).unwrap(); + let n = group_size as usize; + let zero = match ty { + ValueType::I32 => Value::I32(0), + ValueType::I64 => Value::I64(0), + ValueType::F32 => Value::F32(0.0), + ValueType::F64 => Value::F64(0.0), + }; + value_stack.extend(repeat(zero).take(n)); + } + + let locals_count = value_stack.depth() - locals_start; + + Frame { + fn_index, + return_addr, + body_block_index, + locals_start, + locals_count, + return_type, + } + } + + pub fn get_local(&self, values: &ValueStack<'_>, index: u32) -> Value { + debug_assert!((index as usize) < self.locals_count); + *values.get(self.locals_start + index as usize).unwrap() + } + + pub fn set_local(&self, values: &mut ValueStack<'_>, index: u32, value: Value) { + debug_assert!((index as usize) < self.locals_count); + values.set(self.locals_start + index as usize, value) + } +} diff --git a/crates/wasm_interp/src/instance.rs b/crates/wasm_interp/src/instance.rs index 2c0705b154..51c1625383 100644 --- a/crates/wasm_interp/src/instance.rs +++ b/crates/wasm_interp/src/instance.rs @@ -1,16 +1,16 @@ use bumpalo::{collections::Vec, Bump}; use std::fmt::{self, Write}; -use std::iter; +use std::iter::{self, once, Iterator}; use roc_wasm_module::opcodes::OpCode; use roc_wasm_module::parse::{Parse, SkipBytes}; -use roc_wasm_module::sections::{ImportDesc, MemorySection}; +use roc_wasm_module::sections::{ImportDesc, MemorySection, SignatureParamsIter}; use roc_wasm_module::{ExportType, WasmModule}; use roc_wasm_module::{Value, ValueType}; -use crate::call_stack::CallStack; +use crate::frame::Frame; use crate::value_stack::ValueStack; -use crate::{pc_to_fn_index, Error, ImportDispatcher}; +use crate::{Error, ImportDispatcher}; #[derive(Debug)] pub enum Action { @@ -18,10 +18,18 @@ pub enum Action { Break, } -#[derive(Debug)] -enum Block { - Loop { vstack: usize, start_addr: usize }, - Normal { vstack: usize }, +#[derive(Debug, Clone, Copy)] +enum BlockType { + Loop(usize), // Loop block, with start address to loop back to + Normal, // Block created by `block` instruction + Locals(usize), // Special "block" for locals. Holds function index for debug + FunctionBody(usize), // Special block surrounding the function body. Holds function index for debug +} + +#[derive(Debug, Clone, Copy)] +struct Block { + ty: BlockType, + vstack: usize, } #[derive(Debug, Clone)] @@ -33,23 +41,21 @@ struct BranchCacheEntry { #[derive(Debug)] pub struct Instance<'a, I: ImportDispatcher> { - module: &'a WasmModule<'a>, + pub(crate) module: &'a WasmModule<'a>, /// Contents of the WebAssembly instance's memory pub memory: Vec<'a, u8>, - /// Metadata for every currently-active function call - pub call_stack: CallStack<'a>, + /// The current call frame + pub(crate) current_frame: Frame, + /// Previous call frames + previous_frames: Vec<'a, Frame>, /// The WebAssembly stack machine's stack of values - pub value_stack: ValueStack<'a>, + pub(crate) value_stack: ValueStack<'a>, /// Values of any global variables - pub globals: Vec<'a, Value>, + pub(crate) globals: Vec<'a, Value>, /// Index in the code section of the current instruction - pub program_counter: usize, + pub(crate) program_counter: usize, /// One entry per nested block. For loops, stores the address of the first instruction. blocks: Vec<'a, Block>, - /// Outermost block depth for the currently-executing function. - outermost_block: u32, - /// Current function index - current_function: usize, /// Cache for branching instructions, split into buckets for each function. branch_cache: Vec<'a, Vec<'a, BranchCacheEntry>>, /// Number of imports in the module @@ -78,14 +84,13 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { Instance { module: arena.alloc(WasmModule::new(arena)), memory: Vec::from_iter_in(iter::repeat(0).take(mem_bytes as usize), arena), - call_stack: CallStack::new(arena), + current_frame: Frame::new(), + previous_frames: Vec::new_in(arena), value_stack: ValueStack::new(arena), globals: Vec::from_iter_in(globals, arena), program_counter, blocks: Vec::new_in(arena), - outermost_block: 0, branch_cache: bumpalo::vec![in arena; bumpalo::vec![in arena]], - current_function: 0, import_count: 0, import_dispatcher, import_arguments: Vec::new_in(arena), @@ -130,7 +135,6 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { ); let value_stack = ValueStack::new(arena); - let call_stack = CallStack::new(arena); let debug_string = if is_debug_mode { Some(String::new()) @@ -148,13 +152,12 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { Ok(Instance { module, memory, - call_stack, + current_frame: Frame::new(), + previous_frames: Vec::new_in(arena), value_stack, globals, program_counter: usize::MAX, blocks: Vec::new_in(arena), - outermost_block: 0, - current_function: usize::MAX, branch_cache, import_count, import_dispatcher, @@ -167,14 +170,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { where A: IntoIterator, { - let arg_type_bytes = self.prepare_to_call_export(self.module, fn_name)?; + let (fn_index, param_type_iter, ret_type) = + self.call_export_help_before_arg_load(self.module, fn_name)?; + let n_args = param_type_iter.len(); - for (i, (value, type_byte)) in arg_values - .into_iter() - .zip(arg_type_bytes.iter().copied()) - .enumerate() - { - let expected_type = ValueType::from(type_byte); + for (i, (value, expected_type)) in arg_values.into_iter().zip(param_type_iter).enumerate() { let actual_type = ValueType::from(value); if actual_type != expected_type { return Err(format!( @@ -185,7 +185,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.value_stack.push(value); } - self.call_export_help(self.module, arg_type_bytes) + self.call_export_help_after_arg_load(self.module, fn_index, n_args, ret_type) } pub fn call_export_from_cli( @@ -207,15 +207,17 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { // Implement the "basic numbers" CLI // Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI. - let arg_type_bytes = self.prepare_to_call_export(module, fn_name)?; - for (value_bytes, type_byte) in arg_strings + let (fn_index, param_type_iter, ret_type) = + self.call_export_help_before_arg_load(module, fn_name)?; + let n_args = param_type_iter.len(); + for (value_bytes, value_type) in arg_strings .iter() .skip(1) // first string is the .wasm filename - .zip(arg_type_bytes.iter().copied()) + .zip(param_type_iter) { use ValueType::*; let value_str = String::from_utf8_lossy(value_bytes); - let value = match ValueType::from(type_byte) { + let value = match value_type { I32 => Value::I32(value_str.parse::().map_err(|e| e.to_string())?), I64 => Value::I64(value_str.parse::().map_err(|e| e.to_string())?), F32 => Value::F32(value_str.parse::().map_err(|e| e.to_string())?), @@ -224,15 +226,15 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.value_stack.push(value); } - self.call_export_help(module, arg_type_bytes) + self.call_export_help_after_arg_load(module, fn_index, n_args, ret_type) } - fn prepare_to_call_export<'m>( + fn call_export_help_before_arg_load<'m>( &mut self, module: &'m WasmModule<'a>, fn_name: &str, - ) -> Result<&'m [u8], String> { - self.current_function = { + ) -> Result<(usize, SignatureParamsIter<'m>, Option), String> { + let fn_index = { let mut export_iter = module.export.exports.iter(); export_iter // First look up the name in exports @@ -266,7 +268,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { })? as usize }; - let internal_fn_index = self.current_function - self.import_count; + let internal_fn_index = fn_index - self.import_count; self.program_counter = { let mut cursor = module.code.function_offsets[internal_fn_index] as usize; @@ -274,38 +276,50 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { cursor }; - let arg_type_bytes = { + let (param_type_iter, return_type) = { let signature_index = module.function.signatures[internal_fn_index]; - module.types.look_up_arg_type_bytes(signature_index) + module.types.look_up(signature_index) }; if self.debug_string.is_some() { println!( "Calling export func[{}] '{}' at address {:#x}", - self.current_function, + fn_index, fn_name, self.program_counter + module.code.section_offset as usize ); } - Ok(arg_type_bytes) + Ok((fn_index, param_type_iter, return_type)) } - fn call_export_help( + fn call_export_help_after_arg_load( &mut self, module: &WasmModule<'a>, - arg_type_bytes: &[u8], + fn_index: usize, + n_args: usize, + return_type: Option, ) -> Result, String> { - self.call_stack - .push_frame( - 0, // return_addr - 0, // return_block_depth - arg_type_bytes, - &mut self.value_stack, - &module.code.bytes, - &mut self.program_counter, - ) - .map_err(|e| e.to_string_at(self.program_counter))?; + self.previous_frames.clear(); + self.blocks.clear(); + self.blocks.push(Block { + ty: BlockType::Locals(fn_index), + vstack: self.value_stack.depth(), + }); + self.current_frame = Frame::enter( + fn_index, + 0, // return_addr + self.blocks.len(), + n_args, + return_type, + &module.code.bytes, + &mut self.value_stack, + &mut self.program_counter, + ); + self.blocks.push(Block { + ty: BlockType::FunctionBody(fn_index), + vstack: self.value_stack.depth(), + }); loop { match self.execute_next_instruction(module) { @@ -316,14 +330,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { Err(e) => { let file_offset = self.program_counter + module.code.section_offset as usize; let mut message = e.to_string_at(file_offset); - self.call_stack - .dump_trace( - module, - &self.value_stack, - self.program_counter, - &mut message, - ) - .unwrap(); + self.debug_stack_trace(&mut message).unwrap(); return Err(message); } }; @@ -347,18 +354,39 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { } fn do_return(&mut self) -> Action { - self.blocks.truncate(self.outermost_block as usize); - if let Some((return_addr, block_depth)) = self.call_stack.pop_frame() { - if self.call_stack.is_empty() { - // We just popped the stack frame for the entry function. Terminate the program. - Action::Break - } else { - self.program_counter = return_addr as usize; - self.outermost_block = block_depth; - Action::Continue - } + // self.debug_values_and_blocks("start do_return"); + + let Frame { + return_addr, + body_block_index, + return_type, + .. + } = self.current_frame; + + // Throw away all locals and values except the return value + let locals_block_index = body_block_index - 1; + let locals_block = &self.blocks[locals_block_index]; + let new_stack_depth = if return_type.is_some() { + self.value_stack + .set(locals_block.vstack, self.value_stack.peek()); + locals_block.vstack + 1 } else { - // We should never get here with real programs, but maybe in tests. Terminate the program. + locals_block.vstack + }; + self.value_stack.truncate(new_stack_depth); + + // Resume executing at the next instruction in the caller function + let new_block_len = locals_block_index; // don't need a -1 because one is a length and the other is an index! + self.blocks.truncate(new_block_len); + self.program_counter = return_addr; + + // self.debug_values_and_blocks("end do_return"); + + if let Some(caller_frame) = self.previous_frames.pop() { + self.current_frame = caller_frame; + Action::Continue + } else { + // We just popped the stack frame for the entry function. Terminate the program. Action::Break } } @@ -393,16 +421,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) { let block_index = self.blocks.len() - 1 - relative_blocks_outward as usize; - match self.blocks[block_index] { - Block::Loop { start_addr, vstack } => { + let Block { ty, vstack } = self.blocks[block_index]; + match ty { + BlockType::Loop(start_addr) => { self.blocks.truncate(block_index + 1); self.value_stack.truncate(vstack); self.program_counter = start_addr; } - Block::Normal { vstack } => { + BlockType::FunctionBody(_) | BlockType::Normal => { self.break_forward(relative_blocks_outward, module); self.value_stack.truncate(vstack); } + BlockType::Locals(_) => unreachable!(), } } @@ -411,7 +441,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { use OpCode::*; let addr = self.program_counter as u32; - let cache_result = self.branch_cache[self.current_function] + let cache_result = self.branch_cache[self.current_frame.fn_index] .iter() .find(|entry| entry.addr == addr && entry.argument == relative_blocks_outward); @@ -437,7 +467,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { _ => {} } } - self.branch_cache[self.current_function].push(BranchCacheEntry { + self.branch_cache[self.current_frame.fn_index].push(BranchCacheEntry { addr, argument: relative_blocks_outward, target: self.program_counter as u32, @@ -452,6 +482,8 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { fn_index: usize, module: &WasmModule<'a>, ) -> Result<(), Error> { + // self.debug_values_and_blocks(&format!("start do_call {}", fn_index)); + let (signature_index, opt_import) = if fn_index < self.import_count { // Imported non-Wasm function let import = &module.import.imports[fn_index]; @@ -474,15 +506,22 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { ); } - let arg_type_bytes = module.types.look_up_arg_type_bytes(signature_index); + let (arg_type_iter, ret_type) = module.types.look_up(signature_index); + let n_args = arg_type_iter.len(); + if self.debug_string.is_some() { + self.debug_call(n_args, ret_type); + } if let Some(import) = opt_import { self.import_arguments.clear(); self.import_arguments - .extend(std::iter::repeat(Value::I64(0)).take(arg_type_bytes.len())); - for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() { + .extend(std::iter::repeat(Value::I64(0)).take(n_args)); + for (i, expected) in arg_type_iter.enumerate().rev() { let arg = self.value_stack.pop(); - assert_eq!(ValueType::from(arg), ValueType::from(type_byte)); + let actual = ValueType::from(arg); + if actual != expected { + return Err(Error::ValueStackType(expected, actual)); + } self.import_arguments[i] = arg; } @@ -499,28 +538,62 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { write!(debug_string, " {}.{}", import.module, import.name).unwrap(); } } else { - let return_addr = self.program_counter as u32; + let return_addr = self.program_counter; + // set PC to start of function bytes let internal_fn_index = fn_index - self.import_count; self.program_counter = module.code.function_offsets[internal_fn_index] as usize; + // advance PC to the start of the local variable declarations + u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap(); - let return_block_depth = self.outermost_block; - self.outermost_block = self.blocks.len() as u32; + self.blocks.push(Block { + ty: BlockType::Locals(fn_index), + vstack: self.value_stack.depth() - n_args, + }); + let body_block_index = self.blocks.len(); - let _function_byte_length = - u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap(); - self.call_stack.push_frame( + let mut swap_frame = Frame::enter( + fn_index, return_addr, - return_block_depth, - arg_type_bytes, - &mut self.value_stack, + body_block_index, + n_args, + ret_type, &module.code.bytes, + &mut self.value_stack, &mut self.program_counter, - )?; + ); + std::mem::swap(&mut swap_frame, &mut self.current_frame); + self.previous_frames.push(swap_frame); + + self.blocks.push(Block { + ty: BlockType::FunctionBody(fn_index), + vstack: self.value_stack.depth(), + }); } - self.current_function = fn_index; + // self.debug_values_and_blocks("end do_call"); + Ok(()) } + fn debug_call(&mut self, n_args: usize, return_type: Option) { + if let Some(debug_string) = self.debug_string.as_mut() { + write!(debug_string, " args=[").unwrap(); + let arg_iter = self + .value_stack + .iter() + .skip(self.value_stack.depth() - n_args); + let mut first = true; + for arg in arg_iter { + if first { + first = false; + } else { + write!(debug_string, ", ").unwrap(); + } + write!(debug_string, "{:x?}", arg).unwrap(); + } + writeln!(debug_string, "] return_type={:?}", return_type).unwrap(); + } + } + pub(crate) fn execute_next_instruction( &mut self, module: &WasmModule<'a>, @@ -546,26 +619,28 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { NOP => {} BLOCK => { self.fetch_immediate_u32(module); // blocktype (ignored) - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Normal, vstack: self.value_stack.depth(), }); } LOOP => { self.fetch_immediate_u32(module); // blocktype (ignored) - self.blocks.push(Block::Loop { + self.blocks.push(Block { + ty: BlockType::Loop(self.program_counter), vstack: self.value_stack.depth(), - start_addr: self.program_counter, }); } IF => { self.fetch_immediate_u32(module); // blocktype (ignored) let condition = self.value_stack.pop_i32()?; - self.blocks.push(Block::Normal { + self.blocks.push(Block { + ty: BlockType::Normal, vstack: self.value_stack.depth(), }); if condition == 0 { let addr = self.program_counter as u32; - let cache_result = self.branch_cache[self.current_function] + let cache_result = self.branch_cache[self.current_frame.fn_index] .iter() .find(|entry| entry.addr == addr); if let Some(entry) = cache_result { @@ -598,7 +673,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { _ => {} } } - self.branch_cache[self.current_function].push(BranchCacheEntry { + self.branch_cache[self.current_frame.fn_index].push(BranchCacheEntry { addr, argument: 0, target: self.program_counter as u32, @@ -613,7 +688,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { self.do_break(0, module); } END => { - if self.blocks.len() == self.outermost_block as usize { + if self.blocks.len() == (self.current_frame.body_block_index + 1) { // implicit RETURN at end of function action = self.do_return(); implicit_return = true; @@ -692,18 +767,20 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { } GETLOCAL => { let index = self.fetch_immediate_u32(module); - let value = self.call_stack.get_local(index); + let value = self.current_frame.get_local(&self.value_stack, index); self.value_stack.push(value); } SETLOCAL => { let index = self.fetch_immediate_u32(module); let value = self.value_stack.pop(); - self.call_stack.set_local(index, value)?; + self.current_frame + .set_local(&mut self.value_stack, index, value); } TEELOCAL => { let index = self.fetch_immediate_u32(module); let value = self.value_stack.peek(); - self.call_stack.set_local(index, value)?; + self.current_frame + .set_local(&mut self.value_stack, index, value); } GETGLOBAL => { let index = self.fetch_immediate_u32(module); @@ -1618,17 +1695,156 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> { } if let Some(debug_string) = &self.debug_string { - let base = self.call_stack.value_stack_base(); - let slice = self.value_stack.get_slice(base as usize); - eprintln!("{:06x} {:17} {:?}", file_offset, debug_string, slice); - if op_code == RETURN || (op_code == END && implicit_return) { - let fn_index = pc_to_fn_index(self.program_counter, module); - eprintln!("returning to function {}\n", fn_index); - } else if op_code == CALL || op_code == CALLINDIRECT { - eprintln!(); + if matches!(op_code, CALL | CALLINDIRECT) { + eprintln!("\n{:06x} {}", file_offset, debug_string); + } else { + // For calls, we print special debug stuff in do_call + let base = self.current_frame.locals_start + self.current_frame.locals_count; + let slice = self.value_stack.get_slice(base as usize); + eprintln!("{:06x} {:17} {:x?}", file_offset, debug_string, slice); + } + let is_return = op_code == RETURN || (op_code == END && implicit_return); + let is_program_end = self.program_counter == 0; + if is_return && !is_program_end { + eprintln!( + "returning to function {} at {:06x}", + self.current_frame.fn_index, + self.program_counter + self.module.code.section_offset as usize, + ); } } Ok(action) } + + #[allow(dead_code)] + fn debug_values_and_blocks(&self, label: &str) { + eprintln!("\n========== {} ==========", label); + + let mut block_str = String::new(); + let mut block_iter = self.blocks.iter().enumerate(); + let mut block = block_iter.next(); + + let mut print_blocks = |i| { + block_str.clear(); + while let Some((b, Block { vstack, ty })) = block { + if *vstack > i { + break; + } + write!(block_str, "{}:{:?} ", b, ty).unwrap(); + block = block_iter.next(); + } + if !block_str.is_empty() { + eprintln!("--------------- {}", block_str); + } + }; + + for (i, v) in self.value_stack.iter().enumerate() { + print_blocks(i); + eprintln!("{:3} {:x?}", i, v); + } + print_blocks(self.value_stack.depth()); + + eprintln!(); + } + + /// Dump a stack trace when an error occurs + /// -------------- + /// function 123 + /// address 0x12345 + /// args 0: I64(234), 1: F64(7.15) + /// locals 2: I32(412), 3: F64(3.14) + /// stack [I64(111), F64(3.14)] + /// -------------- + fn debug_stack_trace(&self, buffer: &mut String) -> fmt::Result { + let divider = "-------------------"; + writeln!(buffer, "{}", divider)?; + + let frames = self.previous_frames.iter().chain(once(&self.current_frame)); + let next_frames = frames.clone().skip(1); + + // Find the code address to display for each frame + // For previous frames, show the address of the CALL instruction + // For the current frame, show the program counter value + let mut execution_addrs = { + // for each previous_frame, find return address of the *next* frame + let return_addrs = next_frames.clone().map(|f| f.return_addr); + // roll back to the CALL instruction before that return address, it's more meaningful. + let call_addrs = return_addrs.map(|ra| self.debug_return_addr_to_call_addr(ra)); + // For the current frame, show the program_counter + call_addrs.chain(once(self.program_counter)) + }; + + let mut frame_ends = next_frames.map(|f| f.locals_start); + + for frame in frames { + let Frame { + fn_index, + locals_count, + locals_start, + .. + } = frame; + + let arg_count = { + let signature_index = if *fn_index < self.import_count { + match self.module.import.imports[*fn_index].description { + ImportDesc::Func { signature_index } => signature_index, + _ => unreachable!(), + } + } else { + self.module.function.signatures[fn_index - self.import_count] + }; + self.module.types.look_up(signature_index).0.len() + }; + + // Try to match formatting to wasm-objdump where possible, for easy copy & find + writeln!(buffer, "func[{}]", fn_index)?; + writeln!(buffer, " address {:06x}", execution_addrs.next().unwrap())?; + + write!(buffer, " args ")?; + for local_index in 0..*locals_count { + let value = self.value_stack.get(locals_start + local_index).unwrap(); + if local_index == arg_count { + write!(buffer, "\n locals ")?; + } else if local_index != 0 { + write!(buffer, ", ")?; + } + write!(buffer, "{}: {:?}", local_index, value)?; + } + + write!(buffer, "\n stack [")?; + let frame_end = frame_ends + .next() + .unwrap_or_else(|| self.value_stack.depth()); + let stack_start = locals_start + locals_count; + for i in stack_start..frame_end { + let value = self.value_stack.get(i).unwrap(); + if i != stack_start { + write!(buffer, ", ")?; + } + write!(buffer, "{:?}", value)?; + } + writeln!(buffer, "]")?; + writeln!(buffer, "{}", divider)?; + } + + Ok(()) + } + + // Call address is more intuitive than the return address in the stack trace. Search backward for it. + fn debug_return_addr_to_call_addr(&self, return_addr: usize) -> usize { + // return_addr is pointing at the next instruction after the CALL/CALLINDIRECT. + // Just before that is the LEB-128 function index or type index. + // The last LEB-128 byte is <128, but the others are >=128 so we can't mistake them for CALL/CALLINDIRECT + let mut call_addr = return_addr - 2; + loop { + let byte = self.module.code.bytes[call_addr]; + if byte == OpCode::CALL as u8 || byte == OpCode::CALLINDIRECT as u8 { + break; + } else { + call_addr -= 1; + } + } + call_addr + } } diff --git a/crates/wasm_interp/src/lib.rs b/crates/wasm_interp/src/lib.rs index e6af99c2a8..ef9e59cbdd 100644 --- a/crates/wasm_interp/src/lib.rs +++ b/crates/wasm_interp/src/lib.rs @@ -1,4 +1,4 @@ -mod call_stack; +mod frame; mod instance; mod tests; mod value_stack; @@ -9,8 +9,7 @@ pub use instance::Instance; pub use wasi::{WasiDispatcher, WasiFile}; pub use roc_wasm_module::Value; -use roc_wasm_module::{ValueType, WasmModule}; -use value_stack::ValueStack; +use roc_wasm_module::ValueType; pub trait ImportDispatcher { /// Dispatch a call from WebAssembly to your own code, based on module and function name. @@ -101,22 +100,3 @@ impl From<(ValueType, ValueType)> for Error { Error::ValueStackType(expected, actual) } } - -// Determine which function the program counter is in -pub(crate) fn pc_to_fn_index(program_counter: usize, module: &WasmModule<'_>) -> usize { - if module.code.function_offsets.is_empty() { - 0 - } else { - // Find the first function that starts *after* the given program counter - let next_internal_fn_index = module - .code - .function_offsets - .iter() - .position(|o| *o as usize > program_counter) - .unwrap_or(module.code.function_offsets.len()); - // Go back 1 - let internal_fn_index = next_internal_fn_index - 1; - // Adjust for imports, whose indices come before the code section - module.import.imports.len() + internal_fn_index - } -} diff --git a/crates/wasm_interp/src/tests/mod.rs b/crates/wasm_interp/src/tests/mod.rs index ebc318adb0..a78e4031cc 100644 --- a/crates/wasm_interp/src/tests/mod.rs +++ b/crates/wasm_interp/src/tests/mod.rs @@ -11,7 +11,8 @@ mod test_mem; use crate::{DefaultImportDispatcher, Instance}; use bumpalo::{collections::Vec, Bump}; use roc_wasm_module::{ - opcodes::OpCode, Export, ExportType, SerialBuffer, Signature, Value, ValueType, WasmModule, + opcodes::OpCode, Export, ExportType, SerialBuffer, Serialize, Signature, Value, ValueType, + WasmModule, }; pub fn default_state(arena: &Bump) -> Instance { @@ -92,7 +93,7 @@ where } let mut inst = - Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap(); + Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), true).unwrap(); let return_val = inst.call_export("test", []).unwrap().unwrap(); @@ -126,3 +127,32 @@ pub fn create_exported_function_no_locals<'a, F>( module.code.function_count += 1; module.code.function_offsets.push(offset as u32); } + +pub fn create_exported_function_with_locals<'a, F>( + module: &mut WasmModule<'a>, + name: &'a str, + signature: Signature<'a>, + local_types: &[(u32, ValueType)], + write_instructions: F, +) where + F: FnOnce(&mut Vec<'a, u8>), +{ + let internal_fn_index = module.code.function_offsets.len(); + let fn_index = module.import.function_count() + internal_fn_index; + module.export.exports.push(Export { + name, + ty: ExportType::Func, + index: fn_index as u32, + }); + module.add_function_signature(signature); + + let offset = module.code.bytes.encode_padded_u32(0); + let start = module.code.bytes.len(); + local_types.serialize(&mut module.code.bytes); + write_instructions(&mut module.code.bytes); + let len = module.code.bytes.len() - start; + module.code.bytes.overwrite_padded_u32(offset, len as u32); + + module.code.function_count += 1; + module.code.function_offsets.push(offset as u32); +} diff --git a/crates/wasm_interp/src/tests/test_basics.rs b/crates/wasm_interp/src/tests/test_basics.rs index 1a0095252c..00d5665de9 100644 --- a/crates/wasm_interp/src/tests/test_basics.rs +++ b/crates/wasm_interp/src/tests/test_basics.rs @@ -1,7 +1,11 @@ #![cfg(test)] -use super::{const_value, create_exported_function_no_locals, default_state}; -use crate::{instance::Action, DefaultImportDispatcher, ImportDispatcher, Instance, ValueStack}; +use crate::frame::Frame; +use crate::tests::{ + const_value, create_exported_function_no_locals, create_exported_function_with_locals, + default_state, +}; +use crate::{DefaultImportDispatcher, ImportDispatcher, Instance}; use bumpalo::{collections::Vec, Bump}; use roc_wasm_module::sections::{Import, ImportDesc}; use roc_wasm_module::{ @@ -17,88 +21,95 @@ fn test_loop() { fn test_loop_help(end: i32, expected: i32) { let arena = Bump::new(); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; + { + let buf = &mut module.code.bytes; - // Loop from 0 to end, adding the loop variable to a total - let var_i = 0; - let var_total = 1; + // Loop from 0 to end, adding the loop variable to a total + let var_i = 0; + let var_total = 1; - // (local i32 i32) - buf.push(1); // one group of the given type - buf.push(2); // two locals in the group - buf.push(ValueType::I32 as u8); + let fn_len_index = buf.encode_padded_u32(0); - // loop - buf.push(OpCode::LOOP as u8); - buf.push(ValueType::VOID as u8); + // (local i32 i32) + buf.push(1); // one group of the given type + buf.push(2); // two locals in the group + buf.push(ValueType::I32 as u8); - // local.get $i - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(var_i); + // loop + buf.push(OpCode::LOOP as u8); + buf.push(ValueType::VOID as u8); - // i32.const 1 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(1); + // local.get $i + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_i); - // i32.add - buf.push(OpCode::I32ADD as u8); + // i32.const 1 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(1); - // local.tee $i - buf.push(OpCode::TEELOCAL as u8); - buf.encode_u32(var_i); + // i32.add + buf.push(OpCode::I32ADD as u8); - // local.get $total - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(var_total); + // local.tee $i + buf.push(OpCode::TEELOCAL as u8); + buf.encode_u32(var_i); - // i32.add - buf.push(OpCode::I32ADD as u8); + // local.get $total + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_total); - // local.set $total - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(var_total); + // i32.add + buf.push(OpCode::I32ADD as u8); - // local.get $i - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(var_i); + // local.set $total + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(var_total); - // i32.const $end - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(end); + // local.get $i + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_i); - // i32.lt_s - buf.push(OpCode::I32LTS as u8); + // i32.const $end + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(end); - // br_if 0 - buf.push(OpCode::BRIF as u8); - buf.encode_u32(0); + // i32.lt_s + buf.push(OpCode::I32LTS as u8); - // end - buf.push(OpCode::END as u8); + // br_if 0 + buf.push(OpCode::BRIF as u8); + buf.encode_u32(0); - // local.get $total - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(var_total); + // end + buf.push(OpCode::END as u8); - // end function - buf.push(OpCode::END as u8); + // local.get $total + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(var_total); - let mut state = default_state(&arena); - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); + // end function + buf.push(OpCode::END as u8); - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} + buf.overwrite_padded_u32(fn_len_index, (buf.len() - fn_len_index) as u32); + } + module.code.function_offsets.push(0); + module.code.function_count = 1; - assert_eq!(state.value_stack.pop_i32(), Ok(expected)); + module.add_function_signature(Signature { + param_types: Vec::new_in(&arena), + ret_type: Some(ValueType::I32), + }); + module.export.append(Export { + name: "test", + ty: ExportType::Func, + index: 0, + }); + + let mut inst = + Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap(); + let return_val = inst.call_export("test", []).unwrap().unwrap(); + + assert_eq!(return_val, Value::I32(expected)); } #[test] @@ -111,157 +122,157 @@ fn test_if_else() { fn test_if_else_help(condition: i32, expected: i32) { let arena = Bump::new(); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; - buf.push(1); // one group of the given type - buf.push(1); // one local in the group - buf.push(ValueType::I32 as u8); + let signature = Signature { + param_types: bumpalo::vec![in &arena], + ret_type: Some(ValueType::I32), + }; + let local_types = [(1, ValueType::I32)]; + create_exported_function_with_locals(&mut module, "test", signature, &local_types, |buf| { + // i32.const + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(condition); - // i32.const - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(condition); + // if + buf.push(OpCode::IF as u8); + buf.push(ValueType::VOID as u8); - // if - buf.push(OpCode::IF as u8); - buf.push(ValueType::VOID as u8); + // i32.const 111 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(111); - // i32.const 111 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(111); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // else + buf.push(OpCode::ELSE as u8); - // else - buf.push(OpCode::ELSE as u8); + // i32.const 222 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(222); - // i32.const 222 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(222); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // local.get 0 + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(0); - // local.get 0 - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(0); + // end function + buf.push(OpCode::END as u8); + }); - // end function - buf.push(OpCode::END as u8); + let is_debug_mode = false; + let mut inst = Instance::for_module( + &arena, + &module, + DefaultImportDispatcher::default(), + is_debug_mode, + ) + .unwrap(); + let result = inst.call_export("test", []).unwrap().unwrap(); - let mut state = default_state(&arena); - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); - - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} - - assert_eq!(state.value_stack.pop_i32(), Ok(expected)); + assert_eq!(result, Value::I32(expected)); } #[test] fn test_br() { + let start_fn_name = "test"; let arena = Bump::new(); - let mut state = default_state(&arena); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; - // (local i32) - buf.encode_u32(1); - buf.encode_u32(1); - buf.push(ValueType::I32 as u8); + let signature = Signature { + param_types: bumpalo::vec![in &arena], + ret_type: Some(ValueType::I32), + }; + let local_types = [(1, ValueType::I32)]; + create_exported_function_with_locals( + &mut module, + start_fn_name, + signature, + &local_types, + |buf| { + // i32.const 111 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(111); - // i32.const 111 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(111); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // block ;; label = @1 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @1 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @2 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @2 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @3 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @3 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // br 2 (;@1;) + buf.push(OpCode::BR as u8); + buf.encode_u32(2); - // br 2 (;@1;) - buf.push(OpCode::BR as u8); - buf.encode_u32(2); + // i32.const 444 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(444); - // i32.const 444 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(444); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 333 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(333); - // i32.const 333 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(333); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 222 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(222); - // i32.const 222 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(222); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // local.get 0) + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(0); - // local.get 0) - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(0); + buf.push(OpCode::END as u8); + }, + ); - buf.push(OpCode::END as u8); + let is_debug_mode = false; + let mut inst = Instance::for_module( + &arena, + &module, + DefaultImportDispatcher::default(), + is_debug_mode, + ) + .unwrap(); + let result = inst.call_export(start_fn_name, []).unwrap().unwrap(); - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); - - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} - - assert_eq!(state.value_stack.pop(), Value::I32(111)) + assert_eq!(result, Value::I32(111)) } #[test] @@ -271,98 +282,101 @@ fn test_br_if() { } fn test_br_if_help(condition: i32, expected: i32) { + let start_fn_name = "test"; let arena = Bump::new(); - let mut state = default_state(&arena); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; - // (local i32) - buf.encode_u32(1); - buf.encode_u32(1); - buf.push(ValueType::I32 as u8); + let signature = Signature { + param_types: bumpalo::vec![in &arena], + ret_type: Some(ValueType::I32), + }; + let local_types = [(1, ValueType::I32)]; + create_exported_function_with_locals( + &mut module, + start_fn_name, + signature, + &local_types, + |buf| { + // i32.const 111 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(111); - // i32.const 111 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(111); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // block ;; label = @1 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @1 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @2 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @2 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @3 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @3 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // i32.const + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(condition); - // i32.const - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(condition); + // br_if 2 (;@1;) + buf.push(OpCode::BRIF as u8); + buf.encode_u32(2); - // br_if 2 (;@1;) - buf.push(OpCode::BRIF as u8); - buf.encode_u32(2); + // i32.const 444 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(444); - // i32.const 444 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(444); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 333 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(333); - // i32.const 333 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(333); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 222 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(222); - // i32.const 222 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(222); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // local.get 0) + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(0); - // local.get 0) - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(0); + buf.push(OpCode::END as u8); + }, + ); - buf.push(OpCode::END as u8); + let is_debug_mode = true; + let mut inst = Instance::for_module( + &arena, + &module, + DefaultImportDispatcher::default(), + is_debug_mode, + ) + .unwrap(); + let result = inst.call_export(start_fn_name, []).unwrap().unwrap(); - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); - - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} - - assert_eq!(state.value_stack.pop(), Value::I32(expected)) + assert_eq!(result, Value::I32(expected)) } #[test] @@ -373,103 +387,104 @@ fn test_br_table() { } fn test_br_table_help(condition: i32, expected: i32) { + let start_fn_name = "test"; let arena = Bump::new(); - let mut state = default_state(&arena); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; - // (local i32) - buf.encode_u32(1); - buf.encode_u32(1); - buf.push(ValueType::I32 as u8); + let signature = Signature { + param_types: bumpalo::vec![in &arena], + ret_type: Some(ValueType::I32), + }; + let local_types = [(1, ValueType::I32)]; + create_exported_function_with_locals( + &mut module, + start_fn_name, + signature, + &local_types, + |buf| { + // i32.const 111 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(111); - // i32.const 111 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(111); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // block ;; label = @1 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @1 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @2 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @2 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // block ;; label = @3 + buf.push(OpCode::BLOCK as u8); + buf.push(ValueType::VOID); - // block ;; label = @3 - buf.push(OpCode::BLOCK as u8); - buf.push(ValueType::VOID); + // i32.const + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(condition); - // i32.const - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(condition); + // br_table 0 1 2 (;@1;) + buf.push(OpCode::BRTABLE as u8); + buf.encode_u32(2); // number of non-fallback branches + buf.encode_u32(0); + buf.encode_u32(1); + buf.encode_u32(2); - // br_table 0 1 2 (;@1;) - buf.push(OpCode::BRTABLE as u8); - buf.encode_u32(2); // number of non-fallback branches - buf.encode_u32(0); - buf.encode_u32(1); - buf.encode_u32(2); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 333 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(333); - // i32.const 333 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(333); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // br 1 + buf.push(OpCode::BR as u8); + buf.encode_u32(1); - // br 1 - buf.push(OpCode::BR as u8); - buf.encode_u32(1); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // i32.const 222 + buf.push(OpCode::I32CONST as u8); + buf.encode_i32(222); - // i32.const 222 - buf.push(OpCode::I32CONST as u8); - buf.encode_i32(222); + // local.set 0 + buf.push(OpCode::SETLOCAL as u8); + buf.encode_u32(0); - // local.set 0 - buf.push(OpCode::SETLOCAL as u8); - buf.encode_u32(0); + // br 0 + buf.push(OpCode::BR as u8); + buf.encode_u32(0); - // br 0 - buf.push(OpCode::BR as u8); - buf.encode_u32(0); + // end + buf.push(OpCode::END as u8); - // end - buf.push(OpCode::END as u8); + // local.get 0) + buf.push(OpCode::GETLOCAL as u8); + buf.encode_u32(0); - // local.get 0) - buf.push(OpCode::GETLOCAL as u8); - buf.encode_u32(0); + buf.push(OpCode::END as u8); + }, + ); - buf.push(OpCode::END as u8); + let is_debug_mode = false; + let mut inst = Instance::for_module( + &arena, + &module, + DefaultImportDispatcher::default(), + is_debug_mode, + ) + .unwrap(); + let result = inst.call_export(start_fn_name, []).unwrap().unwrap(); - println!("{:02x?}", buf); - - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); - - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} - - assert_eq!(state.value_stack.pop(), Value::I32(expected)) + assert_eq!(result, Value::I32(expected)) } struct TestDispatcher { @@ -489,7 +504,6 @@ impl ImportDispatcher for TestDispatcher { assert_eq!(arguments.len(), 1); let val = arguments[0].expect_i32().unwrap(); self.internal_state += val; - dbg!(val, self.internal_state); Some(Value::I32(self.internal_state)) } } @@ -631,28 +645,22 @@ fn test_call_return_no_args() { #[test] fn test_call_return_with_args() { let arena = Bump::new(); - let mut state = default_state(&arena); let mut module = WasmModule::new(&arena); // Function 0: calculate 2+2 - let func0_offset = module.code.bytes.len() as u32; - module.code.function_offsets.push(func0_offset); - module.add_function_signature(Signature { - param_types: bumpalo::vec![in &arena;], + let signature0 = Signature { + param_types: bumpalo::vec![in &arena], ret_type: Some(ValueType::I32), + }; + create_exported_function_no_locals(&mut module, "two_plus_two", signature0, |buf| { + buf.push(OpCode::I32CONST as u8); + buf.push(2); + buf.push(OpCode::I32CONST as u8); + buf.push(2); + buf.push(OpCode::CALL as u8); + buf.push(1); + buf.push(OpCode::END as u8); }); - [ - 0, // no locals - OpCode::I32CONST as u8, - 2, - OpCode::I32CONST as u8, - 2, - OpCode::CALL as u8, - 1, - OpCode::END as u8, - ] - .serialize(&mut module.code.bytes); - let func0_first_instruction = func0_offset + 2; // skip function length and locals length // Function 1: add two numbers let func1_offset = module.code.bytes.len() as u32; @@ -672,11 +680,24 @@ fn test_call_return_with_args() { ] .serialize(&mut module.code.bytes); - state.program_counter = func0_first_instruction as usize; + let signature0 = Signature { + param_types: bumpalo::vec![in &arena; ValueType::I32, ValueType::I32], + ret_type: Some(ValueType::I32), + }; + create_exported_function_no_locals(&mut module, "add", signature0, |buf| { + buf.push(OpCode::GETLOCAL as u8); + buf.push(0); + buf.push(OpCode::GETLOCAL as u8); + buf.push(1); + buf.push(OpCode::I32ADD as u8); + buf.push(OpCode::END as u8); + }); - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} + let mut inst = + Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap(); + let result = inst.call_export("two_plus_two", []).unwrap().unwrap(); - assert_eq!(state.value_stack.peek(), Value::I32(4)); + assert_eq!(result, Value::I32(4)); } #[test] @@ -750,11 +771,9 @@ fn test_call_indirect_help(table_index: u32, elem_index: u32) -> Value { if false { let mut outfile_buf = Vec::new_in(&arena); module.serialize(&mut outfile_buf); - std::fs::write( - format!("/tmp/roc/call_indirect_{}_{}.wasm", table_index, elem_index), - outfile_buf, - ) - .unwrap(); + let filename = format!("/tmp/roc/call_indirect_{}_{}.wasm", table_index, elem_index); + std::fs::write(&filename, outfile_buf).unwrap(); + println!("\nWrote to {}\n", filename); } let mut inst = Instance::for_module( @@ -779,40 +798,32 @@ fn test_select() { fn test_select_help(first: Value, second: Value, condition: i32, expected: Value) { let arena = Bump::new(); let mut module = WasmModule::new(&arena); - let buf = &mut module.code.bytes; - buf.push(0); // no locals + // Function 0: calculate 2+2 + let signature0 = Signature { + param_types: bumpalo::vec![in &arena], + ret_type: Some(ValueType::from(expected)), + }; + create_exported_function_no_locals(&mut module, "test", signature0, |buf| { + const_value(buf, first); + const_value(buf, second); + const_value(buf, Value::I32(condition)); + buf.push(OpCode::SELECT as u8); + buf.push(OpCode::END as u8); + }); - const_value(buf, first); - const_value(buf, second); - const_value(buf, Value::I32(condition)); - buf.push(OpCode::SELECT as u8); - buf.push(OpCode::END as u8); + let mut inst = + Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap(); + let result = inst.call_export("test", []).unwrap().unwrap(); - let mut state = default_state(&arena); - state - .call_stack - .push_frame( - 0, - 0, - &[], - &mut state.value_stack, - &module.code.bytes, - &mut state.program_counter, - ) - .unwrap(); - - while let Ok(Action::Continue) = state.execute_next_instruction(&module) {} - - assert_eq!(state.value_stack.pop(), expected); + assert_eq!(result, expected); } #[test] fn test_set_get_local() { let arena = Bump::new(); - let mut state = default_state(&arena); + let mut inst = default_state(&arena); let mut module = WasmModule::new(&arena); - let mut vs = ValueStack::new(&arena); let mut buffer = vec![]; let mut cursor = 0; @@ -823,10 +834,22 @@ fn test_set_get_local() { (1u32, ValueType::I64), ] .serialize(&mut buffer); - state - .call_stack - .push_frame(0x1234, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); + + let fn_index = 0; + let return_addr = 0x1234; + let return_block_depth = 0; + let n_args = 0; + let ret_type = Some(ValueType::I32); + inst.current_frame = Frame::enter( + fn_index, + return_addr, + return_block_depth, + n_args, + ret_type, + &buffer, + &mut inst.value_stack, + &mut cursor, + ); module.code.bytes.push(OpCode::I32CONST as u8); module.code.bytes.encode_i32(12345); @@ -836,19 +859,18 @@ fn test_set_get_local() { module.code.bytes.push(OpCode::GETLOCAL as u8); module.code.bytes.encode_u32(2); - state.execute_next_instruction(&module).unwrap(); - state.execute_next_instruction(&module).unwrap(); - state.execute_next_instruction(&module).unwrap(); - assert_eq!(state.value_stack.depth(), 1); - assert_eq!(state.value_stack.pop(), Value::I32(12345)); + inst.execute_next_instruction(&module).unwrap(); + inst.execute_next_instruction(&module).unwrap(); + inst.execute_next_instruction(&module).unwrap(); + assert_eq!(inst.value_stack.depth(), 5); + assert_eq!(inst.value_stack.pop(), Value::I32(12345)); } #[test] fn test_tee_get_local() { let arena = Bump::new(); - let mut state = default_state(&arena); + let mut inst = default_state(&arena); let mut module = WasmModule::new(&arena); - let mut vs = ValueStack::new(&arena); let mut buffer = vec![]; let mut cursor = 0; @@ -859,10 +881,22 @@ fn test_tee_get_local() { (1u32, ValueType::I64), ] .serialize(&mut buffer); - state - .call_stack - .push_frame(0x1234, 0, &[], &mut vs, &buffer, &mut cursor) - .unwrap(); + + let fn_index = 0; + let return_addr = 0x1234; + let return_block_depth = 0; + let n_args = 0; + let ret_type = Some(ValueType::I32); + inst.current_frame = Frame::enter( + fn_index, + return_addr, + return_block_depth, + n_args, + ret_type, + &buffer, + &mut inst.value_stack, + &mut cursor, + ); module.code.bytes.push(OpCode::I32CONST as u8); module.code.bytes.encode_i32(12345); @@ -872,12 +906,12 @@ fn test_tee_get_local() { module.code.bytes.push(OpCode::GETLOCAL as u8); module.code.bytes.encode_u32(2); - state.execute_next_instruction(&module).unwrap(); - state.execute_next_instruction(&module).unwrap(); - state.execute_next_instruction(&module).unwrap(); - assert_eq!(state.value_stack.depth(), 2); - assert_eq!(state.value_stack.pop(), Value::I32(12345)); - assert_eq!(state.value_stack.pop(), Value::I32(12345)); + inst.execute_next_instruction(&module).unwrap(); + inst.execute_next_instruction(&module).unwrap(); + inst.execute_next_instruction(&module).unwrap(); + assert_eq!(inst.value_stack.depth(), 6); + assert_eq!(inst.value_stack.pop(), Value::I32(12345)); + assert_eq!(inst.value_stack.pop(), Value::I32(12345)); } #[test] diff --git a/crates/wasm_interp/src/tests/test_mem.rs b/crates/wasm_interp/src/tests/test_mem.rs index 66ccc3f559..1f08dce8ff 100644 --- a/crates/wasm_interp/src/tests/test_mem.rs +++ b/crates/wasm_interp/src/tests/test_mem.rs @@ -249,7 +249,6 @@ fn test_store<'a>( offset: u32, value: Value, ) -> Vec<'a, u8> { - let is_debug_mode = false; let start_fn_name = "test"; module.memory = MemorySection::new(arena, MemorySection::PAGE_SIZE); @@ -286,6 +285,7 @@ fn test_store<'a>( buf.append_u8(OpCode::END as u8); }); + let is_debug_mode = false; let mut inst = Instance::for_module( arena, module, diff --git a/crates/wasm_interp/src/value_stack.rs b/crates/wasm_interp/src/value_stack.rs index d45389b90c..8ce8423fb8 100644 --- a/crates/wasm_interp/src/value_stack.rs +++ b/crates/wasm_interp/src/value_stack.rs @@ -38,6 +38,18 @@ impl<'a> ValueStack<'a> { *self.values.last().unwrap() } + pub(crate) fn get(&self, index: usize) -> Option<&Value> { + self.values.get(index) + } + + pub(crate) fn set(&mut self, index: usize, value: Value) { + self.values[index] = value; + } + + pub(crate) fn extend>(&mut self, values: I) { + self.values.extend(values) + } + /// Memory addresses etc pub(crate) fn pop_u32(&mut self) -> Result { match self.values.pop() { diff --git a/crates/wasm_module/src/sections.rs b/crates/wasm_module/src/sections.rs index be3d22a329..6b6723f836 100644 --- a/crates/wasm_module/src/sections.rs +++ b/crates/wasm_module/src/sections.rs @@ -212,6 +212,46 @@ impl<'a> Serialize for Signature<'a> { } } +#[derive(Debug)] +pub struct SignatureParamsIter<'a> { + bytes: &'a [u8], + index: usize, + end: usize, +} + +impl<'a> Iterator for SignatureParamsIter<'a> { + type Item = ValueType; + + fn next(&mut self) -> Option { + if self.index >= self.end { + None + } else { + self.bytes.get(self.index).map(|b| { + self.index += 1; + ValueType::from(*b) + }) + } + } + + fn size_hint(&self) -> (usize, Option) { + let size = self.end - self.index; + (size, Some(size)) + } +} + +impl<'a> ExactSizeIterator for SignatureParamsIter<'a> {} + +impl<'a> DoubleEndedIterator for SignatureParamsIter<'a> { + fn next_back(&mut self) -> Option { + if self.end == 0 { + None + } else { + self.end -= 1; + self.bytes.get(self.end).map(|b| ValueType::from(*b)) + } + } +} + #[derive(Debug)] pub struct TypeSection<'a> { /// Private. See WasmModule::add_function_signature @@ -258,11 +298,23 @@ impl<'a> TypeSection<'a> { self.bytes.is_empty() } - pub fn look_up_arg_type_bytes(&self, sig_index: u32) -> &[u8] { + pub fn look_up(&'a self, sig_index: u32) -> (SignatureParamsIter<'a>, Option) { let mut offset = self.offsets[sig_index as usize]; offset += 1; // separator - let count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize; - &self.bytes[offset..][..count] + let param_count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize; + let params_iter = SignatureParamsIter { + bytes: &self.bytes[offset..][..param_count], + index: 0, + end: param_count, + }; + offset += param_count; + + let return_type = if self.bytes[offset] == 0 { + None + } else { + Some(ValueType::from(self.bytes[offset + 1])) + }; + (params_iter, return_type) } }