wasm: function pointers for higher order calls

This commit is contained in:
Brian Carroll 2022-04-02 10:15:51 +01:00
parent 5db3ae0227
commit 8620cdf75c
3 changed files with 46 additions and 11 deletions

View file

@ -170,6 +170,10 @@ impl<'a> WasmBackend<'a> {
#[cfg(not(debug_assertions))]
pub fn register_symbol_debug_names(&self) {}
pub fn get_fn_table_index(&mut self, fn_index: u32) -> i32 {
self.module.element.get_fn_table_index(fn_index)
}
/// Create an IR Symbol for an anonymous value (such as ListLiteral)
pub fn create_symbol(&mut self, debug_name: &str) -> Symbol {
let ident_ids = self

View file

@ -8,8 +8,7 @@ use roc_mono::layout::{Builtin, Layout, UnionLayout};
use roc_mono::low_level::HigherOrder;
use crate::backend::{ProcLookupData, ProcSource, WasmBackend};
use crate::layout::CallConv;
use crate::layout::{StackMemoryFormat, WasmLayout};
use crate::layout::{CallConv, StackMemoryFormat, WasmLayout};
use crate::storage::{StackMemoryLocation, StoredValue};
use crate::wasm_module::{Align, ValueType};
use crate::TARGET_INFO;
@ -1010,6 +1009,9 @@ pub fn call_higher_order_lowlevel<'a>(
let wrapper_fn_idx = backend.register_helper_proc(wrapper_sym, wrapper_layout, source);
let inc_fn_idx = backend.gen_refcount_inc_for_zig(closure_data_layout);
let wrapper_fn_ptr = backend.get_fn_table_index(wrapper_fn_idx);
let inc_fn_ptr = backend.get_fn_table_index(inc_fn_idx);
match op {
// List.map : List elem_old, (elem_old -> elem_new) -> List elem_new
ListMap { xs } => {
@ -1030,13 +1032,13 @@ pub fn call_higher_order_lowlevel<'a>(
// Load return pointer & argument values
backend.storage.load_symbols(cb, &[return_sym]);
backend.storage.load_symbol_zig(cb, *xs);
cb.i32_const(wrapper_fn_idx as i32);
cb.i32_const(wrapper_fn_ptr);
if closure_data_exists {
backend.storage.load_symbols(cb, &[*captured_environment]);
} else {
cb.i32_const(0); // null pointer
}
cb.i32_const(inc_fn_idx as i32);
cb.i32_const(inc_fn_ptr);
cb.i32_const(*owns_captured_environment as i32);
cb.i32_const(elem_new_align as i32); // used for allocating the new list
cb.i32_const(elem_old_size as i32);

View file

@ -805,7 +805,7 @@ enum ElementSegmentFormatId {
#[derive(Debug)]
struct ElementSegment<'a> {
offset: ConstExpr,
offset: ConstExpr, // The starting table index for the segment
fn_indices: Vec<'a, u32>,
}
@ -856,6 +856,9 @@ impl<'a> Serialize for ElementSegment<'a> {
}
}
/// An "element" represents an indirectly-callable function the Wasm runtime's function table.
/// Future Wasm versions might have tables where the elements are DOM references or other things.
/// Elements can be initialised in groups called "segments". Normally there's just one.
#[derive(Debug)]
pub struct ElementSection<'a> {
segments: Vec<'a, ElementSegment<'a>>,
@ -867,6 +870,15 @@ impl<'a> ElementSection<'a> {
pub fn preload(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Self {
let (num_segments, body_bytes) = parse_section(Self::ID, module_bytes, cursor);
if num_segments == 0 {
let seg = ElementSegment {
offset: ConstExpr::I32(1),
fn_indices: bumpalo::vec![in arena],
};
ElementSection {
segments: bumpalo::vec![in arena; seg],
}
} else {
let mut segments = Vec::with_capacity_in(num_segments as usize, arena);
let mut body_cursor = 0;
@ -874,9 +886,26 @@ impl<'a> ElementSection<'a> {
let seg = ElementSegment::parse(arena, body_bytes, &mut body_cursor);
segments.push(seg);
}
ElementSection { segments }
}
}
/// Get a table index for a function (equivalent to a function pointer)
/// The function will be inserted into the table if it's not already there.
/// This index is what the call_indirect instruction expects
/// (It works mostly the same as with pointers, except you can't jump to arbitrary code)
pub fn get_fn_table_index(&mut self, fn_index: u32) -> i32 {
// In practice there is always one segment. We allow a bit more generality by using the last one.
let segment = self.segments.last_mut().unwrap();
let pos = segment.fn_indices.iter().position(|f| *f == fn_index);
if let Some(existing_table_index) = pos {
existing_table_index as i32
} else {
let new_table_index = segment.fn_indices.len();
segment.fn_indices.push(fn_index);
new_table_index as i32
}
}
pub fn size(&self) -> usize {
self.segments.iter().map(|seg| seg.size()).sum()