wasm_interp: store return type in frame & create SignatureParamsIter

This commit is contained in:
Brian Carroll 2022-12-17 13:14:13 +00:00
parent d51beb073f
commit b0d2e7a409
No known key found for this signature in database
GPG key ID: 5C7B2EC4101703C0
4 changed files with 153 additions and 61 deletions

View file

@ -4,7 +4,7 @@ use std::iter;
use roc_wasm_module::opcodes::OpCode;
use roc_wasm_module::parse::{Parse, SkipBytes};
use roc_wasm_module::sections::{ImportDesc, MemorySection};
use roc_wasm_module::sections::{ImportDesc, MemorySection, SignatureParamsIter};
use roc_wasm_module::{ExportType, WasmModule};
use roc_wasm_module::{Value, ValueType};
@ -18,10 +18,18 @@ pub enum Action {
Break,
}
#[derive(Debug)]
enum Block {
Loop { vstack: usize, start_addr: usize },
Normal { vstack: usize },
#[derive(Debug, Clone, Copy)]
enum BlockType {
Loop(usize), // Loop block, with start address to loop back to
Normal, // Block created by `block` instruction
Locals(usize), // Special "block" for locals. Holds function index for debug
FunctionBody(usize), // Special block surrounding the function body. Holds function index for debug
}
#[derive(Debug, Clone, Copy)]
struct Block {
ty: BlockType,
vstack: usize,
}
#[derive(Debug, Clone)]
@ -162,14 +170,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
where
A: IntoIterator<Item = Value>,
{
let (fn_index, arg_type_bytes) =
let (fn_index, param_type_iter, ret_type) =
self.call_export_help_before_arg_load(self.module, fn_name)?;
let n_args = param_type_iter.len();
for (i, (value, type_byte)) in arg_values
.into_iter()
.zip(arg_type_bytes.iter().copied())
.enumerate()
{
for (i, (value, type_byte)) in arg_values.into_iter().zip(param_type_iter).enumerate() {
let expected_type = ValueType::from(type_byte);
let actual_type = ValueType::from(value);
if actual_type != expected_type {
@ -181,7 +186,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value);
}
self.call_export_help_after_arg_load(self.module, fn_index, arg_type_bytes)
self.call_export_help_after_arg_load(self.module, fn_index, n_args, ret_type)
}
pub fn call_export_from_cli(
@ -203,11 +208,13 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
// Implement the "basic numbers" CLI
// Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI.
let (fn_index, arg_type_bytes) = self.call_export_help_before_arg_load(module, fn_name)?;
let (fn_index, param_type_iter, ret_type) =
self.call_export_help_before_arg_load(module, fn_name)?;
let n_args = param_type_iter.len();
for (value_bytes, type_byte) in arg_strings
.iter()
.skip(1) // first string is the .wasm filename
.zip(arg_type_bytes.iter().copied())
.zip(param_type_iter)
{
use ValueType::*;
let value_str = String::from_utf8_lossy(value_bytes);
@ -220,14 +227,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value);
}
self.call_export_help_after_arg_load(module, fn_index, arg_type_bytes)
self.call_export_help_after_arg_load(module, fn_index, n_args, ret_type)
}
fn call_export_help_before_arg_load<'m>(
&mut self,
module: &'m WasmModule<'a>,
fn_name: &str,
) -> Result<(usize, &'m [u8]), String> {
) -> Result<(usize, SignatureParamsIter<'m>, Option<ValueType>), String> {
let fn_index = {
let mut export_iter = module.export.exports.iter();
export_iter
@ -270,9 +277,9 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
cursor
};
let arg_type_bytes = {
let (param_type_iter, return_type) = {
let signature_index = module.function.signatures[internal_fn_index];
module.types.look_up_arg_type_bytes(signature_index)
module.types.look_up(signature_index)
};
if self.debug_string.is_some() {
@ -284,29 +291,36 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
);
}
Ok((fn_index, arg_type_bytes))
Ok((fn_index, param_type_iter, return_type))
}
fn call_export_help_after_arg_load(
&mut self,
module: &WasmModule<'a>,
fn_index: usize,
arg_type_bytes: &[u8],
n_args: usize,
return_type: Option<ValueType>,
) -> Result<Option<Value>, String> {
self.previous_frames.clear();
self.blocks.clear();
self.blocks.push(Block::Normal {
self.blocks.push(Block {
ty: BlockType::Locals(fn_index),
vstack: self.value_stack.depth(),
});
self.current_frame = Frame::enter(
fn_index,
0, // return_addr
self.blocks.len(),
arg_type_bytes,
n_args,
return_type,
&module.code.bytes,
&mut self.value_stack,
&mut self.program_counter,
);
self.blocks.push(Block {
ty: BlockType::FunctionBody(fn_index),
vstack: self.value_stack.depth(),
});
loop {
match self.execute_next_instruction(module) {
@ -351,14 +365,16 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
fn do_return(&mut self) -> Action {
let Frame {
return_addr,
function_block_depth,
locals_start,
body_block_index,
..
} = self.current_frame;
// Check where in the value stack the current block started
let current_block_base = match self.blocks.last() {
Some(Block::Loop { vstack, .. } | Block::Normal { vstack }) => *vstack,
Some(Block { ty, vstack }) => {
debug_assert!(!matches!(ty, BlockType::Locals(_)));
*vstack
}
_ => 0,
};
@ -369,14 +385,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
None
};
self.value_stack.truncate(locals_start);
// Throw away all values from arg[0] upward
let locals_block_index = body_block_index - 1;
let locals_block = &self.blocks[locals_block_index];
self.value_stack.truncate(locals_block.vstack);
if let Some(val) = return_value {
self.value_stack.push(val);
}
self.blocks.truncate(function_block_depth - 1);
// Resume executing at the next instruction in the caller function
let new_block_len = locals_block_index; // don't need a -1 because one is a length and the other is an index!
self.blocks.truncate(new_block_len);
self.program_counter = return_addr;
if let Some(caller_frame) = self.previous_frames.pop() {
self.current_frame = caller_frame;
Action::Continue
@ -416,16 +436,18 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
fn do_break(&mut self, relative_blocks_outward: u32, module: &WasmModule<'a>) {
let block_index = self.blocks.len() - 1 - relative_blocks_outward as usize;
match self.blocks[block_index] {
Block::Loop { start_addr, vstack } => {
let Block { ty, vstack } = self.blocks[block_index];
match ty {
BlockType::Loop(start_addr) => {
self.blocks.truncate(block_index + 1);
self.value_stack.truncate(vstack);
self.program_counter = start_addr;
}
Block::Normal { vstack } => {
BlockType::FunctionBody(_) | BlockType::Normal => {
self.break_forward(relative_blocks_outward, module);
self.value_stack.truncate(vstack);
}
BlockType::Locals(_) => unreachable!(),
}
}
@ -497,16 +519,14 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
);
}
let arg_type_bytes = module.types.look_up_arg_type_bytes(signature_index);
let (arg_type_iter, ret_type) = module.types.look_up(signature_index);
if let Some(import) = opt_import {
self.import_arguments.clear();
self.import_arguments
.extend(std::iter::repeat(Value::I64(0)).take(arg_type_bytes.len()));
for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() {
.extend(std::iter::repeat(Value::I64(0)).take(arg_type_iter.len()));
for (i, expected) in arg_type_iter.enumerate().rev() {
let arg = self.value_stack.pop();
let expected = ValueType::from(type_byte);
let actual = ValueType::from(arg);
if actual != expected {
return Err(Error::ValueStackType(expected, actual));
@ -535,23 +555,31 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
// advance PC to the start of the local variable declarations
u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap();
self.blocks.push(Block::Normal {
self.blocks.push(Block {
ty: BlockType::Locals(fn_index),
vstack: self.value_stack.depth(),
});
let function_block_depth = self.blocks.len();
let body_block_index = self.blocks.len();
let mut swap_frame = Frame::enter(
fn_index,
return_addr,
function_block_depth,
arg_type_bytes,
body_block_index,
arg_type_iter.len(),
ret_type,
&module.code.bytes,
&mut self.value_stack,
&mut self.program_counter,
);
std::mem::swap(&mut swap_frame, &mut self.current_frame);
self.previous_frames.push(swap_frame);
self.blocks.push(Block {
ty: BlockType::FunctionBody(fn_index),
vstack: self.value_stack.depth(),
});
}
Ok(())
}
@ -580,21 +608,23 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
NOP => {}
BLOCK => {
self.fetch_immediate_u32(module); // blocktype (ignored)
self.blocks.push(Block::Normal {
self.blocks.push(Block {
ty: BlockType::Normal,
vstack: self.value_stack.depth(),
});
}
LOOP => {
self.fetch_immediate_u32(module); // blocktype (ignored)
self.blocks.push(Block::Loop {
self.blocks.push(Block {
ty: BlockType::Loop(self.program_counter),
vstack: self.value_stack.depth(),
start_addr: self.program_counter,
});
}
IF => {
self.fetch_immediate_u32(module); // blocktype (ignored)
let condition = self.value_stack.pop_i32()?;
self.blocks.push(Block::Normal {
self.blocks.push(Block {
ty: BlockType::Normal,
vstack: self.value_stack.depth(),
});
if condition == 0 {
@ -647,7 +677,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.do_break(0, module);
}
END => {
if self.blocks.len() == self.current_frame.function_block_depth {
if self.blocks.len() == (self.current_frame.body_block_index + 1) {
// implicit RETURN at end of function
action = self.do_return();
implicit_return = true;
@ -1661,9 +1691,11 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
let is_program_end = self.program_counter == 0;
if is_return && !is_program_end {
eprintln!(
"returning to function {} at {:06x}\n",
"returning to function {} at {:06x}\nwith stack+locals {:?}\nand block indices {:?}\n",
self.current_frame.fn_index,
self.program_counter + self.module.code.section_offset as usize
self.program_counter + self.module.code.section_offset as usize,
&self.value_stack,
&self.blocks
);
} else if op_code == CALL || op_code == CALLINDIRECT {
eprintln!();