wasm_interp: store return type in frame & create SignatureParamsIter

This commit is contained in:
Brian Carroll 2022-12-17 13:14:13 +00:00
parent d51beb073f
commit b0d2e7a409
No known key found for this signature in database
GPG key ID: 5C7B2EC4101703C0
4 changed files with 153 additions and 61 deletions

View file

@ -10,12 +10,14 @@ pub struct Frame {
pub fn_index: usize, pub fn_index: usize,
/// Address in the code section where this frame returns to /// Address in the code section where this frame returns to
pub return_addr: usize, pub return_addr: usize,
/// Depth of the "function block" for this frame /// Depth of the "function body block" for this frame
pub function_block_depth: usize, pub body_block_index: usize,
/// Offset in the ValueStack where the args & locals begin /// Offset in the ValueStack where the args & locals begin
pub locals_start: usize, pub locals_start: usize,
/// Number of args & locals in the frame /// Number of args & locals in the frame
pub locals_count: usize, pub locals_count: usize,
/// Expected return type, if any
pub return_type: Option<ValueType>,
} }
impl Frame { impl Frame {
@ -23,22 +25,23 @@ impl Frame {
Frame { Frame {
fn_index: 0, fn_index: 0,
return_addr: 0, return_addr: 0,
function_block_depth: 0, body_block_index: 0,
locals_start: 0, locals_start: 0,
locals_count: 0, locals_count: 0,
return_type: None,
} }
} }
pub fn enter( pub fn enter(
fn_index: usize, fn_index: usize,
return_addr: usize, return_addr: usize,
function_block_depth: usize, body_block_index: usize,
arg_type_bytes: &[u8], n_args: usize,
return_type: Option<ValueType>,
code_bytes: &[u8], code_bytes: &[u8],
value_stack: &mut ValueStack<'_>, value_stack: &mut ValueStack<'_>,
pc: &mut usize, pc: &mut usize,
) -> Self { ) -> Self {
let n_args = arg_type_bytes.len();
let locals_start = value_stack.depth() - n_args; let locals_start = value_stack.depth() - n_args;
// Parse local variable declarations in the function header. They're grouped by type. // Parse local variable declarations in the function header. They're grouped by type.
@ -60,9 +63,10 @@ impl Frame {
Frame { Frame {
fn_index, fn_index,
return_addr, return_addr,
function_block_depth, body_block_index,
locals_start, locals_start,
locals_count, locals_count,
return_type,
} }
} }

View file

@ -4,7 +4,7 @@ use std::iter;
use roc_wasm_module::opcodes::OpCode; use roc_wasm_module::opcodes::OpCode;
use roc_wasm_module::parse::{Parse, SkipBytes}; use roc_wasm_module::parse::{Parse, SkipBytes};
use roc_wasm_module::sections::{ImportDesc, MemorySection}; use roc_wasm_module::sections::{ImportDesc, MemorySection, SignatureParamsIter};
use roc_wasm_module::{ExportType, WasmModule}; use roc_wasm_module::{ExportType, WasmModule};
use roc_wasm_module::{Value, ValueType}; use roc_wasm_module::{Value, ValueType};
@ -18,10 +18,18 @@ pub enum Action {
Break, Break,
} }
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
enum Block { enum BlockType {
Loop { vstack: usize, start_addr: usize }, Loop(usize), // Loop block, with start address to loop back to
Normal { vstack: usize }, Normal, // Block created by `block` instruction
Locals(usize), // Special "block" for locals. Holds function index for debug
FunctionBody(usize), // Special block surrounding the function body. Holds function index for debug
}
#[derive(Debug, Clone, Copy)]
struct Block {
ty: BlockType,
vstack: usize,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -162,14 +170,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
where where
A: IntoIterator<Item = Value>, A: IntoIterator<Item = Value>,
{ {
let (fn_index, arg_type_bytes) = let (fn_index, param_type_iter, ret_type) =
self.call_export_help_before_arg_load(self.module, fn_name)?; self.call_export_help_before_arg_load(self.module, fn_name)?;
let n_args = param_type_iter.len();
for (i, (value, type_byte)) in arg_values for (i, (value, type_byte)) in arg_values.into_iter().zip(param_type_iter).enumerate() {
.into_iter()
.zip(arg_type_bytes.iter().copied())
.enumerate()
{
let expected_type = ValueType::from(type_byte); let expected_type = ValueType::from(type_byte);
let actual_type = ValueType::from(value); let actual_type = ValueType::from(value);
if actual_type != expected_type { if actual_type != expected_type {
@ -181,7 +186,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value); self.value_stack.push(value);
} }
self.call_export_help_after_arg_load(self.module, fn_index, arg_type_bytes) self.call_export_help_after_arg_load(self.module, fn_index, n_args, ret_type)
} }
pub fn call_export_from_cli( pub fn call_export_from_cli(
@ -203,11 +208,13 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
// Implement the "basic numbers" CLI // Implement the "basic numbers" CLI
// Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI. // Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI.
let (fn_index, arg_type_bytes) = self.call_export_help_before_arg_load(module, fn_name)?; let (fn_index, param_type_iter, ret_type) =
self.call_export_help_before_arg_load(module, fn_name)?;
let n_args = param_type_iter.len();
for (value_bytes, type_byte) in arg_strings for (value_bytes, type_byte) in arg_strings
.iter() .iter()
.skip(1) // first string is the .wasm filename .skip(1) // first string is the .wasm filename
.zip(arg_type_bytes.iter().copied()) .zip(param_type_iter)
{ {
use ValueType::*; use ValueType::*;
let value_str = String::from_utf8_lossy(value_bytes); let value_str = String::from_utf8_lossy(value_bytes);
@ -220,14 +227,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value); self.value_stack.push(value);
} }
self.call_export_help_after_arg_load(module, fn_index, arg_type_bytes) self.call_export_help_after_arg_load(module, fn_index, n_args, ret_type)
} }
fn call_export_help_before_arg_load<'m>( fn call_export_help_before_arg_load<'m>(
&mut self, &mut self,
module: &'m WasmModule<'a>, module: &'m WasmModule<'a>,
fn_name: &str, fn_name: &str,
) -> Result<(usize, &'m [u8]), String> { ) -> Result<(usize, SignatureParamsIter<'m>, Option<ValueType>), String> {
let fn_index = { let fn_index = {
let mut export_iter = module.export.exports.iter(); let mut export_iter = module.export.exports.iter();
export_iter export_iter
@ -270,9 +277,9 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
cursor cursor
}; };
let arg_type_bytes = { let (param_type_iter, return_type) = {
let signature_index = module.function.signatures[internal_fn_index]; let signature_index = module.function.signatures[internal_fn_index];
module.types.look_up_arg_type_bytes(signature_index) module.types.look_up(signature_index)
}; };
if self.debug_string.is_some() { if self.debug_string.is_some() {
@ -284,29 +291,36 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
); );
} }
Ok((fn_index, arg_type_bytes)) Ok((fn_index, param_type_iter, return_type))
} }
fn call_export_help_after_arg_load( fn call_export_help_after_arg_load(
&mut self, &mut self,
module: &WasmModule<'a>, module: &WasmModule<'a>,
fn_index: usize, fn_index: usize,
arg_type_bytes: &[u8], n_args: usize,
return_type: Option<ValueType>,
) -> Result<Option<Value>, String> { ) -> Result<Option<Value>, String> {
self.previous_frames.clear(); self.previous_frames.clear();
self.blocks.clear(); self.blocks.clear();
self.blocks.push(Block::Normal { self.blocks.push(Block {
ty: BlockType::Locals(fn_index),
vstack: self.value_stack.depth(), vstack: self.value_stack.depth(),
}); });
self.current_frame = Frame::enter( self.current_frame = Frame::enter(
fn_index, fn_index,
0, // return_addr 0, // return_addr
self.blocks.len(), self.blocks.len(),
arg_type_bytes, n_args,
return_type,
&module.code.bytes, &module.code.bytes,
&mut self.value_stack, &mut self.value_stack,
&mut self.program_counter, &mut self.program_counter,
); );
self.blocks.push(Block {
ty: BlockType::FunctionBody(fn_index),
vstack: self.value_stack.depth(),
});
loop { loop {
match self.execute_next_instruction(module) { match self.execute_next_instruction(module) {
@ -351,14 +365,16 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
fn do_return(&mut self) -> Action { fn do_return(&mut self) -> Action {
let Frame { let Frame {
return_addr, return_addr,
function_block_depth, body_block_index,
locals_start,
.. ..
} = self.current_frame; } = self.current_frame;
// Check where in the value stack the current block started // Check where in the value stack the current block started
let current_block_base = match self.blocks.last() { let current_block_base = match self.blocks.last() {
Some(Block::Loop { vstack, .. } | Block::Normal { vstack }) => *vstack, Some(Block { ty, vstack }) => {
debug_assert!(!matches!(ty, BlockType::Locals(_)));
*vstack
}
_ => 0, _ => 0,
}; };
@ -369,14 +385,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
None None
}; };
self.value_stack.truncate(locals_start); // Throw away all values from arg[0] upward
let locals_block_index = body_block_index - 1;
let locals_block = &self.blocks[locals_block_index];
self.value_stack.truncate(locals_block.vstack);
if let Some(val) = return_value { if let Some(val) = return_value {
self.value_stack.push(val); self.value_stack.push(val);
} }
self.blocks.truncate(function_block_depth - 1); // Resume executing at the next instruction in the caller function
let new_block_len = locals_block_index; // don't need a -1 because one is a length and the other is an index!
self.blocks.truncate(new_block_len);
self.program_counter = return_addr; self.program_counter = return_addr;
if let Some(caller_frame) = self.previous_frames.pop() { if let Some(caller_frame) = self.previous_frames.pop() {
self.current_frame = caller_frame; self.current_frame = caller_frame;
Action::Continue Action::Continue
@ -416,16 +436,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) { fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) {
let block_index = self.blocks.len() - 1 - relative_blocks_outward as usize; let block_index = self.blocks.len() - 1 - relative_blocks_outward as usize;
match self.blocks[block_index] { let Block { ty, vstack } = self.blocks[block_index];
Block::Loop { start_addr, vstack } => { match ty {
BlockType::Loop(start_addr) => {
self.blocks.truncate(block_index + 1); self.blocks.truncate(block_index + 1);
self.value_stack.truncate(vstack); self.value_stack.truncate(vstack);
self.program_counter = start_addr; self.program_counter = start_addr;
} }
Block::Normal { vstack } => { BlockType::FunctionBody(_) | BlockType::Normal => {
self.break_forward(relative_blocks_outward, module); self.break_forward(relative_blocks_outward, module);
self.value_stack.truncate(vstack); self.value_stack.truncate(vstack);
} }
BlockType::Locals(_) => unreachable!(),
} }
} }
@ -497,16 +519,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
); );
} }
let arg_type_bytes = module.types.look_up_arg_type_bytes(signature_index); let (arg_type_iter, ret_type) = module.types.look_up(signature_index);
if let Some(import) = opt_import { if let Some(import) = opt_import {
self.import_arguments.clear(); self.import_arguments.clear();
self.import_arguments self.import_arguments
.extend(std::iter::repeat(Value::I64(0)).take(arg_type_bytes.len())); .extend(std::iter::repeat(Value::I64(0)).take(arg_type_iter.len()));
for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() { for (i, expected) in arg_type_iter.enumerate().rev() {
let arg = self.value_stack.pop(); let arg = self.value_stack.pop();
let expected = ValueType::from(type_byte);
let actual = ValueType::from(arg); let actual = ValueType::from(arg);
if actual != expected { if actual != expected {
return Err(Error::ValueStackType(expected, actual)); return Err(Error::ValueStackType(expected, actual));
@ -535,23 +555,31 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
// advance PC to the start of the local variable declarations // advance PC to the start of the local variable declarations
u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap(); u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap();
self.blocks.push(Block::Normal { self.blocks.push(Block {
ty: BlockType::Locals(fn_index),
vstack: self.value_stack.depth(), vstack: self.value_stack.depth(),
}); });
let function_block_depth = self.blocks.len(); let body_block_index = self.blocks.len();
let mut swap_frame = Frame::enter( let mut swap_frame = Frame::enter(
fn_index, fn_index,
return_addr, return_addr,
function_block_depth, body_block_index,
arg_type_bytes, arg_type_iter.len(),
ret_type,
&module.code.bytes, &module.code.bytes,
&mut self.value_stack, &mut self.value_stack,
&mut self.program_counter, &mut self.program_counter,
); );
std::mem::swap(&mut swap_frame, &mut self.current_frame); std::mem::swap(&mut swap_frame, &mut self.current_frame);
self.previous_frames.push(swap_frame); self.previous_frames.push(swap_frame);
self.blocks.push(Block {
ty: BlockType::FunctionBody(fn_index),
vstack: self.value_stack.depth(),
});
} }
Ok(()) Ok(())
} }
@ -580,21 +608,23 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
NOP => {} NOP => {}
BLOCK => { BLOCK => {
self.fetch_immediate_u32(module); // blocktype (ignored) self.fetch_immediate_u32(module); // blocktype (ignored)
self.blocks.push(Block::Normal { self.blocks.push(Block {
ty: BlockType::Normal,
vstack: self.value_stack.depth(), vstack: self.value_stack.depth(),
}); });
} }
LOOP => { LOOP => {
self.fetch_immediate_u32(module); // blocktype (ignored) self.fetch_immediate_u32(module); // blocktype (ignored)
self.blocks.push(Block::Loop { self.blocks.push(Block {
ty: BlockType::Loop(self.program_counter),
vstack: self.value_stack.depth(), vstack: self.value_stack.depth(),
start_addr: self.program_counter,
}); });
} }
IF => { IF => {
self.fetch_immediate_u32(module); // blocktype (ignored) self.fetch_immediate_u32(module); // blocktype (ignored)
let condition = self.value_stack.pop_i32()?; let condition = self.value_stack.pop_i32()?;
self.blocks.push(Block::Normal { self.blocks.push(Block {
ty: BlockType::Normal,
vstack: self.value_stack.depth(), vstack: self.value_stack.depth(),
}); });
if condition == 0 { if condition == 0 {
@ -647,7 +677,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.do_break(0, module); self.do_break(0, module);
} }
END => { END => {
if self.blocks.len() == self.current_frame.function_block_depth { if self.blocks.len() == (self.current_frame.body_block_index + 1) {
// implicit RETURN at end of function // implicit RETURN at end of function
action = self.do_return(); action = self.do_return();
implicit_return = true; implicit_return = true;
@ -1661,9 +1691,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
let is_program_end = self.program_counter == 0; let is_program_end = self.program_counter == 0;
if is_return && !is_program_end { if is_return && !is_program_end {
eprintln!( eprintln!(
"returning to function {} at {:06x}\n", "returning to function {} at {:06x}\nwith stack+locals {:?}\nand block indices {:?}\n",
self.current_frame.fn_index, self.current_frame.fn_index,
self.program_counter + self.module.code.section_offset as usize self.program_counter + self.module.code.section_offset as usize,
&self.value_stack,
&self.blocks
); );
} else if op_code == CALL || op_code == CALLINDIRECT { } else if op_code == CALL || op_code == CALLINDIRECT {
eprintln!(); eprintln!();

View file

@ -838,12 +838,14 @@ fn test_set_get_local() {
let fn_index = 0; let fn_index = 0;
let return_addr = 0x1234; let return_addr = 0x1234;
let return_block_depth = 0; let return_block_depth = 0;
let arg_type_bytes = &[]; let n_args = 0;
let ret_type = Some(ValueType::I32);
inst.current_frame = Frame::enter( inst.current_frame = Frame::enter(
fn_index, fn_index,
return_addr, return_addr,
return_block_depth, return_block_depth,
arg_type_bytes, n_args,
ret_type,
&buffer, &buffer,
&mut inst.value_stack, &mut inst.value_stack,
&mut cursor, &mut cursor,
@ -883,12 +885,14 @@ fn test_tee_get_local() {
let fn_index = 0; let fn_index = 0;
let return_addr = 0x1234; let return_addr = 0x1234;
let return_block_depth = 0; let return_block_depth = 0;
let arg_type_bytes = &[]; let n_args = 0;
let ret_type = Some(ValueType::I32);
inst.current_frame = Frame::enter( inst.current_frame = Frame::enter(
fn_index, fn_index,
return_addr, return_addr,
return_block_depth, return_block_depth,
arg_type_bytes, n_args,
ret_type,
&buffer, &buffer,
&mut inst.value_stack, &mut inst.value_stack,
&mut cursor, &mut cursor,

View file

@ -212,6 +212,46 @@ impl<'a> Serialize for Signature<'a> {
} }
} }
#[derive(Debug)]
pub struct SignatureParamsIter<'a> {
bytes: &'a [u8],
index: usize,
end: usize,
}
impl<'a> Iterator for SignatureParamsIter<'a> {
type Item = ValueType;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.end {
None
} else {
self.bytes.get(self.index).map(|b| {
self.index += 1;
ValueType::from(*b)
})
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let size = self.end - self.index;
(size, Some(size))
}
}
impl<'a> ExactSizeIterator for SignatureParamsIter<'a> {}
impl<'a> DoubleEndedIterator for SignatureParamsIter<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.end == 0 {
None
} else {
self.end -= 1;
self.bytes.get(self.end).map(|b| ValueType::from(*b))
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct TypeSection<'a> { pub struct TypeSection<'a> {
/// Private. See WasmModule::add_function_signature /// Private. See WasmModule::add_function_signature
@ -258,11 +298,23 @@ impl<'a> TypeSection<'a> {
self.bytes.is_empty() self.bytes.is_empty()
} }
pub fn look_up_arg_type_bytes(&self, sig_index: u32) -> &[u8] { pub fn look_up(&'a self, sig_index: u32) -> (SignatureParamsIter<'a>, Option<ValueType>) {
let mut offset = self.offsets[sig_index as usize]; let mut offset = self.offsets[sig_index as usize];
offset += 1; // separator offset += 1; // separator
let count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize; let param_count = u32::parse((), &self.bytes, &mut offset).unwrap() as usize;
&self.bytes[offset..][..count] let params_iter = SignatureParamsIter {
bytes: &self.bytes[offset..][..param_count],
index: 0,
end: param_count,
};
offset += param_count;
let return_type = if self.bytes[offset] == 0 {
None
} else {
Some(ValueType::from(self.bytes[offset + 1]))
};
(params_iter, return_type)
} }
} }