diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 960d99a610..24f71cf832 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -20,11 +20,12 @@ use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind} use crate::wasm_module::linking::{DataSymbol, LinkingSegment, WasmObjectSymbol}; use crate::wasm_module::sections::{DataMode, DataSegment}; use crate::wasm_module::{ - code_builder, CodeBuilder, LocalId, Signature, SymInfo, ValueType, WasmModule, + code_builder, CodeBuilder, Export, ExportType, LocalId, Signature, SymInfo, ValueType, + WasmModule, }; use crate::{ - copy_memory, round_up_to_alignment, CopyMemoryConfig, Env, DEBUG_LOG_SETTINGS, PTR_SIZE, - PTR_TYPE, + copy_memory, round_up_to_alignment, CopyMemoryConfig, Env, DEBUG_LOG_SETTINGS, MEMORY_NAME, + PTR_SIZE, PTR_TYPE, STACK_POINTER_GLOBAL_ID, STACK_POINTER_NAME, }; /// The memory address where the constants data will be loaded during module instantiation. @@ -41,7 +42,6 @@ pub struct WasmBackend<'a> { layout_ids: LayoutIds<'a>, next_constant_addr: u32, fn_index_offset: u32, - preloaded_functions_map: MutMap<&'a [u8], u32>, called_preload_fns: MutSet, proc_symbols: Vec<'a, (Symbol, u32)>, helper_proc_gen: CodeGenHelp<'a>, @@ -64,11 +64,21 @@ impl<'a> WasmBackend<'a> { interns: &'a mut Interns, layout_ids: LayoutIds<'a>, proc_symbols: Vec<'a, (Symbol, u32)>, - module: WasmModule<'a>, + mut module: WasmModule<'a>, fn_index_offset: u32, - preloaded_functions_map: MutMap<&'a [u8], u32>, helper_proc_gen: CodeGenHelp<'a>, ) -> Self { + module.export.append(Export { + name: MEMORY_NAME.as_bytes(), + ty: ExportType::Mem, + index: 0, + }); + module.export.append(Export { + name: STACK_POINTER_NAME.as_bytes(), + ty: ExportType::Global, + index: STACK_POINTER_GLOBAL_ID, + }); + WasmBackend { env, interns, @@ -79,7 +89,6 @@ impl<'a> WasmBackend<'a> { layout_ids, next_constant_addr: CONST_SEGMENT_BASE_ADDR, fn_index_offset, - preloaded_functions_map, called_preload_fns: MutSet::default(), proc_symbols, helper_proc_gen, @@ -117,15 +126,8 @@ impl<'a> WasmBackend<'a> { self.module.linking.symbol_table.push(linker_symbol); } - pub fn into_module(mut self, remove_dead_preloads: bool) -> WasmModule<'a> { - if remove_dead_preloads { - self.module.code.remove_dead_preloads( - self.env.arena, - self.module.import.function_count, - self.called_preload_fns, - ) - } - self.module + pub fn finalize(self) -> (WasmModule<'a>, MutSet) { + (self.module, self.called_preload_fns) } /// Register the debug names of Symbols in a global lookup table @@ -1474,7 +1476,7 @@ impl<'a> WasmBackend<'a> { ) { let num_wasm_args = param_types.len(); let has_return_val = ret_type.is_some(); - let fn_index = self.preloaded_functions_map[name.as_bytes()]; + let fn_index = self.module.names.functions[name.as_bytes()]; self.called_preload_fns.insert(fn_index); let linker_symbol_index = u32::MAX; diff --git a/compiler/gen_wasm/src/lib.rs b/compiler/gen_wasm/src/lib.rs index 791626da06..0849bf2027 100644 --- a/compiler/gen_wasm/src/lib.rs +++ b/compiler/gen_wasm/src/lib.rs @@ -41,9 +41,11 @@ pub fn build_module<'a>( preload_bytes: &[u8], procedures: MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, ) -> Result, String> { - // In production we don't want the test wrapper, just serialize it - let (wasm_module, _) = + let (mut wasm_module, called_preload_fns, _) = build_module_without_test_wrapper(env, interns, preload_bytes, procedures); + + wasm_module.remove_dead_preloads(env.arena, called_preload_fns); + let mut buffer = std::vec::Vec::with_capacity(wasm_module.size()); wasm_module.serialize(&mut buffer); Ok(buffer) @@ -55,14 +57,13 @@ pub fn build_module_without_test_wrapper<'a>( interns: &'a mut Interns, preload_bytes: &[u8], procedures: MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, -) -> (WasmModule<'a>, u32) { +) -> (WasmModule<'a>, MutSet, u32) { let mut layout_ids = LayoutIds::default(); let mut procs = Vec::with_capacity_in(procedures.len(), env.arena); let mut proc_symbols = Vec::with_capacity_in(procedures.len() * 2, env.arena); let mut linker_symbols = Vec::with_capacity_in(procedures.len() * 2, env.arena); let mut exports = Vec::with_capacity_in(4, env.arena); let mut maybe_main_fn_index = None; - let eliminate_dead_preloads = true; // Collect the symbols & names for the procedures, // and filter out procs we're going to inline @@ -103,11 +104,6 @@ pub fn build_module_without_test_wrapper<'a>( let fn_index_offset: u32 = initial_module.import.function_count + initial_module.code.preloaded_count; - // Get a map of name to index for the preloaded functions - // Assumes the preloaded object file has all symbols exported, as per `zig build-lib -dymamic` - let preloaded_functions_map: MutMap<&'a [u8], u32> = - initial_module.export.function_index_map(env.arena); - let mut backend = WasmBackend::new( env, interns, @@ -115,7 +111,6 @@ pub fn build_module_without_test_wrapper<'a>( proc_symbols, initial_module, fn_index_offset, - preloaded_functions_map, CodeGenHelp::new(env.arena, IntWidth::I32, env.module_id), ); @@ -150,10 +145,10 @@ pub fn build_module_without_test_wrapper<'a>( backend.build_proc(proc); } - let module = backend.into_module(eliminate_dead_preloads); - + let (module, called_preload_fns) = backend.finalize(); let main_function_index = maybe_main_fn_index.unwrap() + fn_index_offset; - (module, main_function_index) + + (module, called_preload_fns, main_function_index) } pub struct CopyMemoryConfig { diff --git a/compiler/gen_wasm/src/wasm_module/dead_code.rs b/compiler/gen_wasm/src/wasm_module/dead_code.rs index c09e6b406f..320f72c8bc 100644 --- a/compiler/gen_wasm/src/wasm_module/dead_code.rs +++ b/compiler/gen_wasm/src/wasm_module/dead_code.rs @@ -131,22 +131,32 @@ pub fn parse_dead_code_metadata<'a>( pub fn trace_function_deps<'a, Indices: IntoIterator>( arena: &'a Bump, metadata: &DeadCodeMetadata<'a>, + exported_fns: &[u32], called_from_app: Indices, ) -> Vec<'a, u32> { - let num_funcs = metadata.ret_types.len(); + let num_preloads = metadata.ret_types.len(); // All functions that get called from the app, directly or indirectly - let mut live_fn_indices = Vec::with_capacity_in(num_funcs, arena); + let mut live_fn_indices = Vec::with_capacity_in(num_preloads, arena); // Current & next batch of functions whose call graphs we want to trace through the metadata // (2 separate vectors so that we're not iterating over the same one we're changing) // If the max call depth is N then we will do N traces or less - let mut current_trace = Vec::with_capacity_in(num_funcs, arena); + let mut current_trace = Vec::with_capacity_in(num_preloads, arena); + let mut next_trace = Vec::with_capacity_in(num_preloads, arena); + + // Start with preloaded functions called from the app or exported directly to Wasm host current_trace.extend(called_from_app); - let mut next_trace = Vec::with_capacity_in(num_funcs, arena); + current_trace.extend( + exported_fns + .iter() + .filter(|idx| **idx < num_preloads as u32), + ); + current_trace.sort_unstable(); + current_trace.dedup(); // Fast per-function lookup table to see if its dependencies have already been traced - let mut already_traced = Vec::from_iter_in(std::iter::repeat(false).take(num_funcs), arena); + let mut already_traced = Vec::from_iter_in(std::iter::repeat(false).take(num_preloads), arena); loop { live_fn_indices.extend_from_slice(¤t_trace); diff --git a/compiler/gen_wasm/src/wasm_module/mod.rs b/compiler/gen_wasm/src/wasm_module/mod.rs index c1519bee6f..6ac96253a6 100644 --- a/compiler/gen_wasm/src/wasm_module/mod.rs +++ b/compiler/gen_wasm/src/wasm_module/mod.rs @@ -11,13 +11,17 @@ pub use linking::SymInfo; use roc_reporting::internal_error; pub use sections::{ConstExpr, Export, ExportType, Global, GlobalType, Signature}; +use crate::wasm_module::serialize::SkipBytes; + use self::linking::{LinkingSection, RelocationSection}; use self::sections::{ CodeSection, DataSection, ExportSection, FunctionSection, GlobalSection, ImportSection, - MemorySection, OpaqueSection, Section, SectionId, TypeSection, + MemorySection, NameSection, OpaqueSection, Section, SectionId, TypeSection, }; use self::serialize::{SerialBuffer, Serialize}; +/// A representation of the WebAssembly binary file format +/// https://webassembly.github.io/spec/core/binary/modules.html #[derive(Debug)] pub struct WasmModule<'a> { pub types: TypeSection<'a>, @@ -31,6 +35,7 @@ pub struct WasmModule<'a> { pub element: OpaqueSection<'a>, pub code: CodeSection<'a>, pub data: DataSection<'a>, + pub names: NameSection<'a>, pub linking: LinkingSection<'a>, pub relocations: RelocationSection<'a>, } @@ -128,15 +133,23 @@ impl<'a> WasmModule<'a> { let ret_types = types.parse_preloaded_data(arena); let import = ImportSection::preload(arena, bytes, &mut cursor); + let function = FunctionSection::preload(arena, bytes, &mut cursor); let signature_ids = function.parse_preloaded_data(arena); let table = OpaqueSection::preload(SectionId::Table, arena, bytes, &mut cursor); + let memory = MemorySection::preload(arena, bytes, &mut cursor); + let global = GlobalSection::preload(arena, bytes, &mut cursor); - let export = ExportSection::preload(arena, bytes, &mut cursor); + + ExportSection::skip_bytes(bytes, &mut cursor); + let export = ExportSection::empty(arena); + let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor); + let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor); + let code = CodeSection::preload( arena, bytes, @@ -145,7 +158,11 @@ impl<'a> WasmModule<'a> { signature_ids, import.function_count, ); + let data = DataSection::preload(arena, bytes, &mut cursor); + + // Metadata sections + let names = NameSection::parse(arena, bytes, &mut cursor); let linking = LinkingSection::new(arena); let relocations = RelocationSection::new(arena, "reloc.CODE"); @@ -161,10 +178,24 @@ impl<'a> WasmModule<'a> { element, code, data, + names, linking, relocations, } } + + pub fn remove_dead_preloads>( + &mut self, + arena: &'a Bump, + called_preload_fns: T, + ) { + self.code.remove_dead_preloads( + arena, + self.import.function_count, + &self.export.function_indices, + called_preload_fns, + ) + } } /// Helper struct to count non-empty sections. diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index bfd38347f1..8b9ae099c0 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -668,38 +668,53 @@ impl Serialize for Export<'_> { pub struct ExportSection<'a> { pub count: u32, pub bytes: Vec<'a, u8>, + pub function_indices: Vec<'a, u32>, } impl<'a> ExportSection<'a> { + const ID: SectionId = SectionId::Export; + pub fn append(&mut self, export: Export) { export.serialize(&mut self.bytes); self.count += 1; + if matches!(export.ty, ExportType::Func) { + self.function_indices.push(export.index); + } } - pub fn function_index_map(&self, arena: &'a Bump) -> MutMap<&'a [u8], u32> { - let mut map = MutMap::default(); + pub fn size(&self) -> usize { + let id = 1; + let encoded_length = 5; + let encoded_count = 5; - let mut cursor = 0; - while cursor < self.bytes.len() { - let name_len = parse_u32_or_panic(&self.bytes, &mut cursor); - let name_end = cursor + name_len as usize; - let name_bytes = &self.bytes[cursor..name_end]; - let ty = self.bytes[name_end]; + id + encoded_length + encoded_count + self.bytes.len() + } - cursor = name_end + 1; - let index = parse_u32_or_panic(&self.bytes, &mut cursor); - - if ty == ExportType::Func as u8 { - let name: &'a [u8] = arena.alloc_slice_clone(name_bytes); - map.insert(name, index); - } + pub fn empty(arena: &'a Bump) -> Self { + ExportSection { + count: 0, + bytes: Vec::with_capacity_in(256, arena), + function_indices: Vec::with_capacity_in(4, arena), } - - map } } -section_impl!(ExportSection, SectionId::Export); +impl SkipBytes for ExportSection<'_> { + fn skip_bytes(bytes: &[u8], cursor: &mut usize) { + parse_section(Self::ID, bytes, cursor); + } +} + +impl<'a> Serialize for ExportSection<'a> { + fn serialize(&self, buffer: &mut T) { + if !self.bytes.is_empty() { + let header_indices = write_section_header(buffer, Self::ID); + buffer.encode_u32(self.count); + buffer.append_slice(&self.bytes); + update_section_size(buffer, header_indices); + } + } +} /******************************************************************* * @@ -769,14 +784,19 @@ impl<'a> CodeSection<'a> { } } - pub fn remove_dead_preloads>( + pub(super) fn remove_dead_preloads>( &mut self, arena: &'a Bump, import_fn_count: u32, + exported_fns: &[u32], called_preload_fns: T, ) { - let live_ext_fn_indices = - trace_function_deps(arena, &self.dead_code_metadata, called_preload_fns); + let live_ext_fn_indices = trace_function_deps( + arena, + &self.dead_code_metadata, + exported_fns, + called_preload_fns, + ); let mut buffer = Vec::with_capacity_in(self.preloaded_bytes.len(), arena); @@ -871,9 +891,7 @@ section_impl!(DataSection, SectionId::Data); /******************************************************************* * - * Module - * - * https://webassembly.github.io/spec/core/binary/modules.html + * Opaque section * *******************************************************************/ @@ -920,6 +938,107 @@ impl Serialize for OpaqueSection<'_> { } } +/******************************************************************* + * + * Name section + * https://webassembly.github.io/spec/core/appendix/custom.html#name-section + * + *******************************************************************/ + +#[repr(u8)] +#[allow(dead_code)] +enum NameSubSections { + ModuleName = 0, + FunctionNames = 1, + LocalNames = 2, +} + +#[derive(Debug, Default)] +pub struct NameSection<'a> { + pub functions: MutMap<&'a [u8], u32>, +} + +impl<'a> NameSection<'a> { + const ID: SectionId = SectionId::Custom; + const NAME: &'static str = "name"; + + pub fn parse(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Self { + let functions = MutMap::default(); + let mut section = NameSection { functions }; + section.parse_help(arena, module_bytes, cursor); + section + } + + fn parse_help(&mut self, arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) { + // Custom section ID + let section_id_byte = module_bytes[*cursor]; + if section_id_byte != Self::ID as u8 { + internal_error!( + "Expected section ID 0x{:x}, but found 0x{:x} at offset 0x{:x}", + Self::ID as u8, + section_id_byte, + *cursor + ); + } + *cursor += 1; + + // Section size + let section_size = parse_u32_or_panic(module_bytes, cursor); + let section_end = *cursor + section_size as usize; + + // Custom section name + let section_name_len = parse_u32_or_panic(module_bytes, cursor); + let section_name_end = *cursor + section_name_len as usize; + let section_name = &module_bytes[*cursor..section_name_end]; + if section_name != Self::NAME.as_bytes() { + internal_error!( + "Expected Custon section {:?}, found {:?}", + Self::NAME, + std::str::from_utf8(section_name) + ); + } + *cursor = section_name_end; + + // Find function names subsection + let mut found_function_names = false; + for _possible_subsection_id in 0..2 { + let subsection_id = module_bytes[*cursor]; + *cursor += 1; + let subsection_size = parse_u32_or_panic(module_bytes, cursor); + if subsection_id == NameSubSections::FunctionNames as u8 { + found_function_names = true; + break; + } + *cursor += subsection_size as usize; + if *cursor >= section_end { + internal_error!("Failed to parse Name section"); + } + } + if !found_function_names { + internal_error!("Failed to parse Name section"); + } + + // Function names + let num_entries = parse_u32_or_panic(module_bytes, cursor) as usize; + for _ in 0..num_entries { + let fn_index = parse_u32_or_panic(module_bytes, cursor); + let name_len = parse_u32_or_panic(module_bytes, cursor); + let name_end = *cursor + name_len as usize; + let name_bytes: &[u8] = &module_bytes[*cursor..name_end]; + *cursor = name_end; + + self.functions + .insert(arena.alloc_slice_copy(name_bytes), fn_index); + } + } +} + +/******************************************************************* + * + * Unit tests + * + *******************************************************************/ + #[cfg(test)] mod tests { use super::*; diff --git a/compiler/test_gen/src/helpers/wasm.rs b/compiler/test_gen/src/helpers/wasm.rs index 0869214de8..21d6c6b36c 100644 --- a/compiler/test_gen/src/helpers/wasm.rs +++ b/compiler/test_gen/src/helpers/wasm.rs @@ -1,4 +1,5 @@ use core::cell::Cell; +use roc_gen_wasm::wasm_module::{Export, ExportType}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; @@ -16,6 +17,7 @@ const PLATFORM_FILENAME: &str = "wasm_test_platform"; const OUT_DIR_VAR: &str = "TEST_GEN_OUT"; const TEST_WRAPPER_NAME: &str = "test_wrapper"; +const INIT_REFCOUNT_NAME: &str = "init_refcount_test"; fn promote_expr_to_module(src: &str) -> String { let mut buffer = String::from("app \"test\" provides [ main ] to \"./platform\"\n\nmain =\n"); @@ -125,17 +127,29 @@ fn compile_roc_to_wasm_bytes<'a, T: Wasm32TestResult>( exposed_to_host, }; - let (mut wasm_module, main_fn_index) = roc_gen_wasm::build_module_without_test_wrapper( - &env, - &mut interns, - preload_bytes, - procedures, - ); + let (mut module, called_preload_fns, main_fn_index) = + roc_gen_wasm::build_module_without_test_wrapper( + &env, + &mut interns, + preload_bytes, + procedures, + ); - T::insert_test_wrapper(arena, &mut wasm_module, TEST_WRAPPER_NAME, main_fn_index); + T::insert_test_wrapper(arena, &mut module, TEST_WRAPPER_NAME, main_fn_index); - let mut app_module_bytes = std::vec::Vec::with_capacity(4096); - wasm_module.serialize(&mut app_module_bytes); + // Export the initialiser function for refcount tests + let init_refcount_bytes = INIT_REFCOUNT_NAME.as_bytes(); + let init_refcount_idx = module.names.functions[init_refcount_bytes]; + module.export.append(Export { + name: arena.alloc_slice_copy(init_refcount_bytes), + ty: ExportType::Func, + index: init_refcount_idx, + }); + + module.remove_dead_preloads(env.arena, called_preload_fns); + + let mut app_module_bytes = std::vec::Vec::with_capacity(module.size()); + module.serialize(&mut app_module_bytes); app_module_bytes } @@ -235,7 +249,7 @@ where let memory = instance.exports.get_memory(MEMORY_NAME).unwrap(); let expected_len = num_refcounts as i32; - let init_refcount_test = instance.exports.get_function("init_refcount_test").unwrap(); + let init_refcount_test = instance.exports.get_function(INIT_REFCOUNT_NAME).unwrap(); let init_result = init_refcount_test.call(&[wasmer::Value::I32(expected_len)]); let refcount_vector_addr = match init_result { Err(e) => return Err(format!("{:?}", e)),