diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index c89c6b16d2..af834e6233 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -2,7 +2,7 @@ use bumpalo::{self, collections::Vec}; use code_builder::Align; use roc_builtins::bitcode::{self, IntWidth}; -use roc_collections::all::MutMap; +use roc_collections::all::{MutMap, MutSet}; use roc_module::ident::Ident; use roc_module::low_level::{LowLevel, LowLevelWrapperType}; use roc_module::symbol::{Interns, Symbol}; @@ -42,6 +42,7 @@ pub struct WasmBackend<'a> { next_constant_addr: u32, fn_index_offset: u32, preloaded_functions_map: MutMap<&'a [u8], u32>, + called_preload_fns: MutSet, proc_symbols: Vec<'a, (Symbol, u32)>, helper_proc_gen: CodeGenHelp<'a>, @@ -79,6 +80,7 @@ impl<'a> WasmBackend<'a> { next_constant_addr: CONST_SEGMENT_BASE_ADDR, fn_index_offset, preloaded_functions_map, + called_preload_fns: MutSet::default(), proc_symbols, helper_proc_gen, @@ -115,7 +117,12 @@ impl<'a> WasmBackend<'a> { self.module.linking.symbol_table.push(linker_symbol); } - pub fn into_module(self) -> WasmModule<'a> { + pub fn into_module(mut self, remove_dead_preloads: bool) -> WasmModule<'a> { + if remove_dead_preloads { + self.module + .code + .remove_dead_preloads(self.env.arena, self.called_preload_fns) + } self.module } @@ -1466,6 +1473,7 @@ impl<'a> WasmBackend<'a> { let num_wasm_args = param_types.len(); let has_return_val = ret_type.is_some(); let fn_index = self.preloaded_functions_map[name.as_bytes()]; + self.called_preload_fns.insert(fn_index); let linker_symbol_index = u32::MAX; self.code_builder diff --git a/compiler/gen_wasm/src/lib.rs b/compiler/gen_wasm/src/lib.rs index ca591f95fb..d084e75260 100644 --- a/compiler/gen_wasm/src/lib.rs +++ b/compiler/gen_wasm/src/lib.rs @@ -13,7 +13,6 @@ use roc_module::symbol::{Interns, ModuleId, Symbol}; use roc_mono::code_gen_help::CodeGenHelp; use roc_mono::ir::{Proc, ProcLayout}; use roc_mono::layout::LayoutIds; 
-use roc_reporting::internal_error; use crate::backend::WasmBackend; use crate::wasm_module::{ @@ -42,7 +41,7 @@ pub fn build_module<'a>( procedures: MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, ) -> Result<std::vec::Vec<u8>, String> { let (wasm_module, _) = build_module_help(env, interns, preload_bytes, procedures)?; - let mut buffer = std::vec::Vec::with_capacity(4096); + let mut buffer = std::vec::Vec::with_capacity(wasm_module.size()); wasm_module.serialize(&mut buffer); Ok(buffer) } @@ -59,6 +58,7 @@ pub fn build_module_help<'a>( let mut linker_symbols = Vec::with_capacity_in(procedures.len() * 2, env.arena); let mut exports = Vec::with_capacity_in(4, env.arena); let mut maybe_main_fn_index = None; + let eliminate_dead_preloads = true; // Collect the symbols & names for the procedures, // and filter out procs we're going to inline @@ -146,7 +146,7 @@ pub fn build_module_help<'a>( backend.build_proc(proc); } - let module = backend.into_module(); + let module = backend.into_module(eliminate_dead_preloads); let main_function_index = maybe_main_fn_index.unwrap() + fn_index_offset; Ok((module, main_function_index)) @@ -214,10 +214,6 @@ macro_rules! 
round_up_to_alignment { }; } -pub fn debug_panic<E: std::fmt::Debug>(error: E) { - internal_error!("{:?}", error); -} - pub struct WasmDebugLogSettings { proc_start_end: bool, user_procs_ir: bool, @@ -233,5 +229,5 @@ pub const DEBUG_LOG_SETTINGS: WasmDebugLogSettings = WasmDebugLogSettings { helper_procs_ir: false && cfg!(debug_assertions), let_stmt_ir: false && cfg!(debug_assertions), instructions: false && cfg!(debug_assertions), - keep_test_binary: true && cfg!(debug_assertions), + keep_test_binary: false && cfg!(debug_assertions), }; diff --git a/compiler/gen_wasm/src/wasm_module/code_builder.rs b/compiler/gen_wasm/src/wasm_module/code_builder.rs index 5676508287..db51b2d22e 100644 --- a/compiler/gen_wasm/src/wasm_module/code_builder.rs +++ b/compiler/gen_wasm/src/wasm_module/code_builder.rs @@ -37,6 +37,18 @@ impl Serialize for ValueType { } } +impl From<u8> for ValueType { + fn from(x: u8) -> Self { + match x { + 0x7f => Self::I32, + 0x7e => Self::I64, + 0x7d => Self::F32, + 0x7c => Self::F64, + _ => internal_error!("Invalid ValueType 0x{:02x}", x), + } + } +} + const BLOCK_NO_RESULT: u8 = 0x40; /// A control block in our model of the VM diff --git a/compiler/gen_wasm/src/wasm_module/dead_code.rs b/compiler/gen_wasm/src/wasm_module/dead_code.rs index e258bb8d52..af0e1494da 100644 --- a/compiler/gen_wasm/src/wasm_module/dead_code.rs +++ b/compiler/gen_wasm/src/wasm_module/dead_code.rs @@ -1,14 +1,20 @@ +use bumpalo::collections::vec::Vec; +use bumpalo::Bump; + +use super::opcodes::OpCode; +use super::serialize::{parse_u32_or_panic, SerialBuffer, Serialize, SkipBytes}; +use super::{CodeBuilder, ValueType}; #[derive(Debug)] pub struct DeadCodeMetadata<'a> { - /// Byte offset (in the module) where each function body can be found + /// Byte offset where each function body can be found code_offsets: Vec<'a, u32>, /// Vector with one entry per *call*, containing the called function's index calls: Vec<'a, u32>, /// Vector with one entry per *function*, indicating its offset in `calls` 
calls_offsets: Vec<'a, u32>, /// Return types of each function (for making almost-empty dummy replacements) - ret_types: Vec<'a, u8>, + ret_types: Vec<'a, Option>, } impl<'a> DeadCodeMetadata<'a> { @@ -28,69 +34,110 @@ impl<'a> DeadCodeMetadata<'a> { /// use this backend without a linker. pub fn parse_dead_code_metadata<'a>( arena: &'a Bump, - module_bytes: &[u8], - cursor: &mut usize, + func_count: u32, + code_section_body: &[u8], + ret_types: Vec<'a, Option>, + signature_ids: Vec<'a, u32>, ) -> DeadCodeMetadata<'a> { - if module_bytes[*cursor] != SectionId::Code as u8 { - internal_error!("Expected Code section in object file at offset {}", *cursor); - } - *cursor += 1; - - let section_size = parse_u32_or_panic(module_bytes, cursor); - let count_start = *cursor; - let section_end = count_start + section_size as usize; - let func_count = parse_u32_or_panic(module_bytes, cursor); - let mut metadata = DeadCodeMetadata::new(arena, func_count as usize); + metadata + .ret_types + .extend(signature_ids.iter().map(|sig| ret_types[*sig as usize])); - while *cursor < section_end { - metadata.code_offsets.push(*cursor as u32); + let mut cursor: usize = 0; + while cursor < code_section_body.len() { + metadata.code_offsets.push(cursor as u32); metadata.calls_offsets.push(metadata.calls.len() as u32); - let func_size = parse_u32_or_panic(module_bytes, cursor); - let func_end = *cursor + func_size as usize; + let func_size = parse_u32_or_panic(code_section_body, &mut cursor); + let func_end = cursor + func_size as usize; // Local variable declarations - let local_groups_count = parse_u32_or_panic(module_bytes, cursor); + let local_groups_count = parse_u32_or_panic(code_section_body, &mut cursor); for _ in 0..local_groups_count { - let _group_len = parse_u32_or_panic(module_bytes, cursor); - *cursor += 1; // ValueType + parse_u32_or_panic(code_section_body, &mut cursor); + cursor += 1; // ValueType } // Instructions - while *cursor < func_end { - let opcode_byte: u8 = 
module_bytes[*cursor]; + while cursor < func_end { + let opcode_byte: u8 = code_section_body[cursor]; if opcode_byte == OpCode::CALL as u8 { - *cursor += 1; - let call_index = parse_u32_or_panic(module_bytes, cursor); + cursor += 1; + let call_index = parse_u32_or_panic(code_section_body, &mut cursor); metadata.calls.push(call_index as u32); } else { - OpCode::skip_bytes(module_bytes, cursor); + OpCode::skip_bytes(code_section_body, &mut cursor); } } } // Extra entries to mark the end of the last function - metadata.code_offsets.push(*cursor as u32); + metadata.code_offsets.push(cursor as u32); metadata.calls_offsets.push(metadata.calls.len() as u32); metadata } -/// Copy used functions (and their dependencies!) from an external module into our Code section -/// Replace unused functions with very small dummies, to avoid changing any indices -pub fn copy_used_functions<'a, T: SerialBuffer>( +/// Trace the dependencies of a list of functions +/// We've already collected metadata saying which functions call each other +/// Now we need to trace the dependency graphs of a specific subset of them +/// Result is the full set of builtins and platform functions used in the app. +/// The rest are "dead code" and can be eliminated. 
+pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>( arena: &'a Bump, - buffer: &mut T, - metadata: DeadCodeMetadata<'a>, - external_module: &[u8], - sorted_used_func_indices: &[u32], -) { - let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena); + metadata: &DeadCodeMetadata<'a>, + called_from_app: Indices, +) -> Vec<'a, u32> { + let mut live_fn_indices: Vec<'a, u32> = Vec::with_capacity_in(metadata.calls.len(), arena); + live_fn_indices.extend(called_from_app); + + let num_funcs = metadata.calls_offsets.len(); + + // Current batch of functions whose call graphs we want to trace + let mut current_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena); + current_trace.clone_from(&live_fn_indices); + + // The next batch (don't want to modify the current one while we're iterating over it!) + let mut next_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena); + + // Fast lookup for what's already traced so we don't need to do it again + let mut already_traced: Vec<'a, bool> = Vec::from_iter_in((0..num_funcs).map(|_| false), arena); + + loop { + live_fn_indices.extend_from_slice(&current_trace); + + for func_idx in current_trace.iter() { + let i = *func_idx as usize; + already_traced[i] = true; + let calls_start = metadata.calls_offsets[i] as usize; + let calls_end = metadata.calls_offsets[i + 1] as usize; + let called_indices: &[u32] = &metadata.calls[calls_start..calls_end]; + for called_idx in called_indices { + if !already_traced[*called_idx as usize] { + next_trace.push(*called_idx); + } + } + } + if next_trace.is_empty() { + break; + } + current_trace.clone_from(&next_trace); + next_trace.clear(); + } + + if true { + println!("Hey Brian, don't forget to remove this debug code"); + let unsorted_len = live_fn_indices.len(); + live_fn_indices.dedup(); + debug_assert!(unsorted_len == live_fn_indices.len()); + } + + live_fn_indices } -/// Create a set of dummy functions that just return a constant of each possible type 
-fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] { +/// Create a set of minimum-size dummy functions for each possible return type +fn create_dummy_functions(arena: &Bump) -> [Vec<'_, u8>; 5] { let mut code_builder_i32 = CodeBuilder::new(arena); code_builder_i32.i32_const(0); @@ -111,11 +158,12 @@ fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] { code_builder_f64.build_fn_header_and_footer(&[], 0, None); code_builder_nil.build_fn_header_and_footer(&[], 0, None); - let mut dummy_i32 = Vec::with_capacity_in(code_builder_i32.size(), arena); - let mut dummy_i64 = Vec::with_capacity_in(code_builder_i64.size(), arena); - let mut dummy_f32 = Vec::with_capacity_in(code_builder_f32.size(), arena); - let mut dummy_f64 = Vec::with_capacity_in(code_builder_f64.size(), arena); - let mut dummy_nil = Vec::with_capacity_in(code_builder_nil.size(), arena); + let capacity = code_builder_f64.size(); + let mut dummy_i32 = Vec::with_capacity_in(capacity, arena); + let mut dummy_i64 = Vec::with_capacity_in(capacity, arena); + let mut dummy_f32 = Vec::with_capacity_in(capacity, arena); + let mut dummy_f64 = Vec::with_capacity_in(capacity, arena); + let mut dummy_nil = Vec::with_capacity_in(capacity, arena); code_builder_i32.serialize(&mut dummy_i32); code_builder_i64.serialize(&mut dummy_i64); @@ -125,3 +173,41 @@ fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] { [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] } + +/// Copy used functions from an external module into our Code section +/// Replace unused functions with very small dummies, to avoid changing any indices +pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>( + arena: &'a Bump, + buffer: &mut T, + metadata: &DeadCodeMetadata<'a>, + external_code: &[u8], + live_ext_fn_indices: &'a mut [u32], +) { + live_ext_fn_indices.sort_unstable(); + + let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena); + + let mut prev = 0; + for 
live32 in live_ext_fn_indices.iter() { + let live = *live32 as usize; + + // Replace dead functions with the minimal code body that will pass validation checks + for dead in prev..live { + let dummy_bytes = match metadata.ret_types[dead] { + Some(ValueType::I32) => &dummy_i32, + Some(ValueType::I64) => &dummy_i64, + Some(ValueType::F32) => &dummy_f32, + Some(ValueType::F64) => &dummy_f64, + None => &dummy_nil, + }; + buffer.append_slice(dummy_bytes); + } + + // Copy the body of the live function from the external module + let live_body_start = metadata.code_offsets[live] as usize; + let live_body_end = metadata.code_offsets[live + 1] as usize; + buffer.append_slice(&external_code[live_body_start..live_body_end]); + + prev = live + 1; + } +} diff --git a/compiler/gen_wasm/src/wasm_module/mod.rs b/compiler/gen_wasm/src/wasm_module/mod.rs index cb55201d0d..e273817f72 100644 --- a/compiler/gen_wasm/src/wasm_module/mod.rs +++ b/compiler/gen_wasm/src/wasm_module/mod.rs @@ -1,4 +1,5 @@ pub mod code_builder; +mod dead_code; pub mod linking; pub mod opcodes; pub mod sections; @@ -44,7 +45,6 @@ impl<'a> WasmModule<'a> { } /// Serialize the module to bytes - /// (not using Serialize trait because it's just one more thing to export) pub fn serialize(&self, buffer: &mut T) { buffer.append_u8(0); buffer.append_slice("asm".as_bytes()); @@ -125,16 +125,19 @@ impl<'a> WasmModule<'a> { let mut cursor: usize = 8; let mut types = TypeSection::preload(arena, bytes, &mut cursor); - types.cache_offsets(); + let ret_types = types.parse_preloaded_data(arena); + let import = ImportSection::preload(arena, bytes, &mut cursor); let function = FunctionSection::preload(arena, bytes, &mut cursor); + let signature_ids = function.parse_preloaded_data(arena); + let table = OpaqueSection::preload(SectionId::Table, arena, bytes, &mut cursor); let memory = MemorySection::preload(arena, bytes, &mut cursor); let global = GlobalSection::preload(arena, bytes, &mut cursor); let export = 
ExportSection::preload(arena, bytes, &mut cursor); let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor); let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor); - let code = CodeSection::preload(arena, bytes, &mut cursor); + let code = CodeSection::preload(arena, bytes, &mut cursor, ret_types, signature_ids); let data = DataSection::preload(arena, bytes, &mut cursor); let linking = LinkingSection::new(arena); let relocations = RelocationSection::new(arena, "reloc.CODE"); diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index c314106136..fb19adc2ea 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -3,6 +3,9 @@ use bumpalo::Bump; use roc_collections::all::MutMap; use roc_reporting::internal_error; +use super::dead_code::{ + copy_live_and_replace_dead, parse_dead_code_metadata, trace_function_deps, DeadCodeMetadata, +}; use super::linking::RelocationEntry; use super::opcodes::OpCode; use super::serialize::{ @@ -212,8 +215,10 @@ impl<'a> TypeSection<'a> { sig_id as u32 } - pub fn cache_offsets(&mut self) { + pub fn parse_preloaded_data(&mut self, arena: &'a Bump) -> Vec<'a, Option> { self.offsets.clear(); + let mut ret_types = Vec::with_capacity_in(self.offsets.capacity(), arena); + let mut i = 0; while i < self.bytes.len() { self.offsets.push(i); @@ -226,8 +231,18 @@ impl<'a> TypeSection<'a> { i += n_params as usize; // skip over one byte per param type let n_return_values = self.bytes[i]; - i += 1 + n_return_values as usize; + i += 1; + + ret_types.push(if n_return_values == 0 { + None + } else { + Some(ValueType::from(self.bytes[i])) + }); + + i += n_return_values as usize; } + + ret_types } } @@ -412,6 +427,15 @@ impl<'a> FunctionSection<'a> { self.bytes.encode_u32(sig_id); self.count += 1; } + + pub fn parse_preloaded_data(&self, arena: &'a Bump) -> Vec<'a, u32> { + let mut 
preload_signature_ids = Vec::with_capacity_in(self.count as usize, arena); + let mut cursor = 0; + while cursor < self.bytes.len() { + preload_signature_ids.push(parse_u32_or_panic(&self.bytes, &mut cursor)); + } + preload_signature_ids + } } section_impl!(FunctionSection, SectionId::Function); @@ -663,8 +687,9 @@ section_impl!(ExportSection, SectionId::Export); #[derive(Debug)] pub struct CodeSection<'a> { pub preloaded_count: u32, - pub preloaded_bytes: Vec<'a, u8>, + pub preloaded_bytes: &'a [u8], pub code_builders: Vec<'a, CodeBuilder<'a>>, + dead_code_metadata: DeadCodeMetadata<'a>, } impl<'a> CodeSection<'a> { @@ -677,8 +702,6 @@ impl<'a> CodeSection<'a> { let header_indices = write_section_header(buffer, SectionId::Code); buffer.encode_u32(self.preloaded_count + self.code_builders.len() as u32); - buffer.append_slice(&self.preloaded_bytes); - for code_builder in self.code_builders.iter() { code_builder.serialize_with_relocs(buffer, relocations, header_indices.body_index); } @@ -694,16 +717,52 @@ impl<'a> CodeSection<'a> { MAX_SIZE_SECTION_HEADER + self.preloaded_bytes.len() + builders_size } - pub fn preload(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Self { + pub fn preload( + arena: &'a Bump, + module_bytes: &[u8], + cursor: &mut usize, + ret_types: Vec<'a, Option>, + signature_ids: Vec<'a, u32>, + ) -> Self { let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor); - let mut preloaded_bytes = Vec::with_capacity_in(initial_bytes.len() * 2, arena); - preloaded_bytes.extend_from_slice(initial_bytes); + let preloaded_bytes = arena.alloc_slice_copy(initial_bytes); + + // TODO: Try to move this metadata preparation to platform build time + let dead_code_metadata = parse_dead_code_metadata( + arena, + preloaded_count, + initial_bytes, + ret_types, + signature_ids, + ); + CodeSection { preloaded_count, preloaded_bytes, code_builders: Vec::with_capacity_in(0, arena), + dead_code_metadata, } } + + pub fn 
remove_dead_preloads<T>(&mut self, arena: &'a Bump, called_preload_fns: T) + where + T: IntoIterator<Item = u32>, + { + let mut live_ext_fn_indices = + trace_function_deps(arena, &self.dead_code_metadata, called_preload_fns); + + let mut buffer = Vec::with_capacity_in(self.preloaded_bytes.len(), arena); + + copy_live_and_replace_dead( + arena, + &mut buffer, + &self.dead_code_metadata, + self.preloaded_bytes, + &mut live_ext_fn_indices, + ); + + self.preloaded_bytes = buffer.into_bump_slice(); + } } impl<'a> Serialize for CodeSection<'a> { @@ -711,7 +770,7 @@ let header_indices = write_section_header(buffer, SectionId::Code); buffer.encode_u32(self.preloaded_count + self.code_builders.len() as u32); - buffer.append_slice(&self.preloaded_bytes); + buffer.append_slice(self.preloaded_bytes); for code_builder in self.code_builders.iter() { code_builder.serialize(buffer); @@ -846,7 +905,7 @@ mod tests { // Reconstruct a new TypeSection by "pre-loading" the bytes of the original let mut cursor = 0; let mut preloaded = TypeSection::preload(arena, &original_serialized, &mut cursor); - preloaded.cache_offsets(); + preloaded.parse_preloaded_data(arena); debug_assert_eq!(original.offsets, preloaded.offsets); debug_assert_eq!(original.bytes, preloaded.bytes);