diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 529715115d..74d76c0260 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -170,6 +170,10 @@ impl<'a> WasmBackend<'a> { #[cfg(not(debug_assertions))] pub fn register_symbol_debug_names(&self) {} + pub fn get_fn_table_index(&mut self, fn_index: u32) -> i32 { + self.module.element.get_fn_table_index(fn_index) + } + /// Create an IR Symbol for an anonymous value (such as ListLiteral) pub fn create_symbol(&mut self, debug_name: &str) -> Symbol { let ident_ids = self diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index 2fc2fcefc3..20f719b6bd 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -8,8 +8,7 @@ use roc_mono::layout::{Builtin, Layout, UnionLayout}; use roc_mono::low_level::HigherOrder; use crate::backend::{ProcLookupData, ProcSource, WasmBackend}; -use crate::layout::CallConv; -use crate::layout::{StackMemoryFormat, WasmLayout}; +use crate::layout::{CallConv, StackMemoryFormat, WasmLayout}; use crate::storage::{StackMemoryLocation, StoredValue}; use crate::wasm_module::{Align, ValueType}; use crate::TARGET_INFO; @@ -1010,6 +1009,9 @@ pub fn call_higher_order_lowlevel<'a>( let wrapper_fn_idx = backend.register_helper_proc(wrapper_sym, wrapper_layout, source); let inc_fn_idx = backend.gen_refcount_inc_for_zig(closure_data_layout); + let wrapper_fn_ptr = backend.get_fn_table_index(wrapper_fn_idx); + let inc_fn_ptr = backend.get_fn_table_index(inc_fn_idx); + match op { // List.map : List elem_old, (elem_old -> elem_new) -> List elem_new ListMap { xs } => { @@ -1030,13 +1032,13 @@ pub fn call_higher_order_lowlevel<'a>( // Load return pointer & argument values backend.storage.load_symbols(cb, &[return_sym]); backend.storage.load_symbol_zig(cb, *xs); - cb.i32_const(wrapper_fn_idx as i32); + cb.i32_const(wrapper_fn_ptr); if closure_data_exists { backend.storage.load_symbols(cb, &[*captured_environment]); } else { cb.i32_const(0); // null pointer } - cb.i32_const(inc_fn_idx as i32); + cb.i32_const(inc_fn_ptr); cb.i32_const(*owns_captured_environment as i32); cb.i32_const(elem_new_align as i32); // used for allocating the new list cb.i32_const(elem_old_size as i32); diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index 4d30809e1f..f622bef579 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -805,7 +805,7 @@ enum ElementSegmentFormatId { #[derive(Debug)] struct ElementSegment<'a> { - offset: ConstExpr, + offset: ConstExpr, // The starting table index for the segment fn_indices: Vec<'a, u32>, } @@ -856,6 +856,9 @@ impl<'a> Serialize for ElementSegment<'a> { } } +/// An "element" represents an indirectly-callable function the Wasm runtime's function table. +/// Future Wasm versions might have tables where the elements are DOM references or other things. +/// Elements can be initialised in groups called "segments". Normally there's just one. #[derive(Debug)] pub struct ElementSection<'a> { segments: Vec<'a, ElementSegment<'a>>, @@ -867,15 +870,41 @@ impl<'a> ElementSection<'a> { pub fn preload(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Self { let (num_segments, body_bytes) = parse_section(Self::ID, module_bytes, cursor); - let mut segments = Vec::with_capacity_in(num_segments as usize, arena); + if num_segments == 0 { + let seg = ElementSegment { + offset: ConstExpr::I32(1), + fn_indices: bumpalo::vec![in arena], + }; + ElementSection { + segments: bumpalo::vec![in arena; seg], + } + } else { + let mut segments = Vec::with_capacity_in(num_segments as usize, arena); - let mut body_cursor = 0; - for _ in 0..num_segments { - let seg = ElementSegment::parse(arena, body_bytes, &mut body_cursor); - segments.push(seg); + let mut body_cursor = 0; + for _ in 0..num_segments { + let seg = ElementSegment::parse(arena, body_bytes, &mut body_cursor); + segments.push(seg); + } + ElementSection { segments } } + } - ElementSection { segments } + /// Get a table index for a function (equivalent to a function pointer) + /// The function will be inserted into the table if it's not already there. + /// This index is what the call_indirect instruction expects + /// (It works mostly the same as with pointers, except you can't jump to arbitrary code) + pub fn get_fn_table_index(&mut self, fn_index: u32) -> i32 { + // In practice there is always one segment. We allow a bit more generality by using the last one. + let segment = self.segments.last_mut().unwrap(); + let pos = segment.fn_indices.iter().position(|f| *f == fn_index); + if let Some(existing_table_index) = pos { + existing_table_index as i32 + } else { + let new_table_index = segment.fn_indices.len(); + segment.fn_indices.push(fn_index); + new_table_index as i32 + } } pub fn size(&self) -> usize {