From c00b43c36b645c7858731096a04fa92fccb3e911 Mon Sep 17 00:00:00 2001 From: Brian Carroll Date: Tue, 7 Jun 2022 23:52:33 +0100 Subject: [PATCH] wasm: Move some of the DCE code from CodeSection to WasmModule --- compiler/gen_wasm/src/lib.rs | 2 +- .../gen_wasm/src/wasm_module/dead_code.rs | 2 + compiler/gen_wasm/src/wasm_module/mod.rs | 119 +++++++++++++----- compiler/gen_wasm/src/wasm_module/sections.rs | 87 +------------ compiler/test_gen/src/helpers/wasm.rs | 2 +- repl_wasm/src/repl.rs | 2 +- 6 files changed, 98 insertions(+), 116 deletions(-) diff --git a/compiler/gen_wasm/src/lib.rs b/compiler/gen_wasm/src/lib.rs index 81226668bc..611307adba 100644 --- a/compiler/gen_wasm/src/lib.rs +++ b/compiler/gen_wasm/src/lib.rs @@ -67,7 +67,7 @@ pub fn build_app_binary<'a>( let (mut wasm_module, called_preload_fns, _) = build_app_module(env, interns, host_module, procedures); - wasm_module.remove_dead_preloads(env.arena, called_preload_fns); + wasm_module.eliminate_dead_code(env.arena, called_preload_fns); let mut buffer = std::vec::Vec::with_capacity(wasm_module.size()); wasm_module.serialize(&mut buffer); diff --git a/compiler/gen_wasm/src/wasm_module/dead_code.rs b/compiler/gen_wasm/src/wasm_module/dead_code.rs index dcd54ec9e3..a228b9f1b4 100644 --- a/compiler/gen_wasm/src/wasm_module/dead_code.rs +++ b/compiler/gen_wasm/src/wasm_module/dead_code.rs @@ -192,6 +192,8 @@ pub fn trace_call_graph<'a, Indices: IntoIterator>( next_trace.clear(); } + live_fn_indices.sort_unstable(); + live_fn_indices.dedup(); live_fn_indices } diff --git a/compiler/gen_wasm/src/wasm_module/mod.rs b/compiler/gen_wasm/src/wasm_module/mod.rs index 4a43093e78..9dedc114bf 100644 --- a/compiler/gen_wasm/src/wasm_module/mod.rs +++ b/compiler/gen_wasm/src/wasm_module/mod.rs @@ -11,6 +11,9 @@ pub use code_builder::{Align, CodeBuilder, LocalId, ValueType, VmSymbolState}; pub use linking::{OffsetRelocType, RelocationEntry, SymInfo}; pub use sections::{ConstExpr, Export, ExportType, Global, GlobalType, Signature}; +use self::dead_code::{ + copy_preloads_shrinking_dead_fns, parse_preloads_call_graph, trace_call_graph, +}; use self::linking::{LinkingSection, RelocationSection}; use self::parse::{Parse, ParseError}; use self::sections::{ @@ -109,20 +112,8 @@ impl<'a> WasmModule<'a> { let export = ExportSection::parse(arena, bytes, &mut cursor)?; let start = OpaqueSection::parse((arena, SectionId::Start), bytes, &mut cursor)?; let element = ElementSection::parse(arena, bytes, &mut cursor)?; - let indirect_callees = element.indirect_callees(arena); - - let imported_fn_signatures = import.function_signatures(arena); - let code = CodeSection::parse( - arena, - bytes, - &mut cursor, - &imported_fn_signatures, - &function.signatures, - &indirect_callees, - )?; - + let code = CodeSection::parse(arena, bytes, &mut cursor)?; let data = DataSection::parse(arena, bytes, &mut cursor)?; - let linking = LinkingSection::parse(arena, bytes, &mut cursor)?; let reloc_code = RelocationSection::parse((arena, "reloc.CODE"), bytes, &mut cursor)?; let reloc_data = RelocationSection::parse((arena, "reloc.DATA"), bytes, &mut cursor)?; @@ -180,14 +171,29 @@ impl<'a> WasmModule<'a> { }) } - pub fn remove_dead_preloads>( + pub fn eliminate_dead_code>( &mut self, arena: &'a Bump, called_preload_fns: T, ) { - let host_import_count = - self.import.imports.len() + self.code.dead_import_dummy_count as usize; + // + // Parse the host's call graph + // + let indirect_callees = self.element.indirect_callees(arena); + let import_signatures = self.import.function_signatures(arena); + let preloads_call_graph = parse_preloads_call_graph( + arena, + &self.code.preloaded_bytes, + &import_signatures, + &self.function.signatures, + &indirect_callees, + ) + .unwrap(); + // + // Trace all live host functions, using the call graph + // Start with the functions called from Roc, and those exported to JS + // let exported_fn_iter = self .export .exports @@ -195,17 +201,29 @@ impl<'a> WasmModule<'a> { .filter(|ex| ex.ty == ExportType::Func) .map(|ex| ex.index); let exported_fn_indices = Vec::from_iter_in(exported_fn_iter, arena); - - let live_import_fns = self.code.remove_dead_preloads( + let live_preload_fns = trace_call_graph( arena, - self.import.function_count(), + &preloads_call_graph, &exported_fn_indices, called_preload_fns, - &self.reloc_code, - &self.linking, ); - // Retain any imported functions whose index appears in live_import_fns + // + // Categorise the live functions as either imports from JS, or internal Wasm functions + // + let host_import_count = + self.import.imports.len() + self.code.dead_import_dummy_count as usize; + let split_at = live_preload_fns + .iter() + .position(|f| *f as usize >= host_import_count) + .unwrap_or(live_preload_fns.len()); + let mut live_import_fns = live_preload_fns; + let live_wasm_fns = live_import_fns.split_off(split_at); + + // + // Remove all unused JS imports + // We don't want to force the web page to provide dummy JS functions, it's a pain! + // let mut fn_index = 0; let mut live_index = 0; self.import.imports.retain(|import| { @@ -223,21 +241,66 @@ impl<'a> WasmModule<'a> { } }); - // Update function signatures + // + // Update function signatures & debug names for imports that changed index + // for (new_index, old_index) in live_import_fns.iter().enumerate() { // Safe because `old_index >= new_index` self.function.signatures[new_index] = self.function.signatures[*old_index as usize]; - } - - // Update debug names - for (new_index, old_index) in live_import_fns.iter().enumerate() { - // Safe because `old_index >= new_index` self.names.function_names[new_index] = self.names.function_names[*old_index as usize]; } let first_dead_import_index = live_import_fns.last().map(|x| x + 1).unwrap_or(0) as usize; for i in first_dead_import_index..host_import_count { self.names.function_names[i] = (i as u32, "unused_host_import"); } + + // + // Relocate Wasm calls to JS imports + // This must happen *before* we run dead code elimination on the code section, + // so that the host's linking data will still be valid. + // + for (new_index, old_index) in live_import_fns.iter().enumerate() { + if new_index == *old_index as usize { + continue; + } + let sym_index = self + .linking + .find_imported_function_symbol(*old_index) + .unwrap_or_else(|| { + panic!( + "Linking failed! Can't find fn #{} in host symbol table", + old_index + ) + }); + self.reloc_code.apply_relocs_u32( + &mut self.code.preloaded_bytes, + self.code.preloaded_reloc_offset, + sym_index, + new_index as u32, + ); + } + + // + // For every eliminated JS import, insert a dummy Wasm function at the same index. + // This avoids shifting the indices of Wasm functions, which would require more linking work. + // + let dead_import_count = host_import_count - live_import_fns.len(); + self.code.dead_import_dummy_count += dead_import_count as u32; + + // + // Dead code elimination. Replace dead functions with tiny dummies. + // This avoids changing function indices, which would require more linking work. + // + let mut buffer = Vec::with_capacity_in(self.code.preloaded_bytes.len(), arena); + copy_preloads_shrinking_dead_fns( + arena, + &mut buffer, + &preloads_call_graph, + &self.code.preloaded_bytes, + host_import_count, + live_wasm_fns, + ); + self.code.preloaded_bytes = buffer; } pub fn get_exported_global_u32(&self, name: &str) -> Option { diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index 164e9e70a9..a766f143e7 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -4,11 +4,6 @@ use bumpalo::collections::vec::Vec; use bumpalo::Bump; use roc_error_macros::internal_error; -use super::dead_code::{ - copy_preloads_shrinking_dead_fns, parse_preloads_call_graph, trace_call_graph, - PreloadsCallGraph, -}; -use super::linking::{LinkingSection, RelocationSection}; use super::opcodes::OpCode; use super::parse::{Parse, ParseError, SkipBytes}; use super::serialize::{SerialBuffer, Serialize, MAX_SIZE_ENCODED_U32}; @@ -1179,7 +1174,6 @@ pub struct CodeSection<'a> { /// Dead imports are replaced with dummy functions in CodeSection pub dead_import_dummy_count: u32, pub code_builders: Vec<'a, CodeBuilder<'a>>, - dead_code_metadata: PreloadsCallGraph<'a>, } impl<'a> CodeSection<'a> { @@ -1193,9 +1187,6 @@ impl<'a> CodeSection<'a> { arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize, - import_signatures: &[u32], - function_signatures: &[u32], - indirect_callees: &[u32], ) -> Result { if module_bytes[*cursor] != SectionId::Code as u8 { return Err(ParseError { @@ -1211,9 +1202,8 @@ impl<'a> CodeSection<'a> { let next_section_start = count_start + section_size as usize; *cursor = next_section_start; - // Relocation offsets are based from the start of the section body, which includes function count - // But preloaded_bytes does not include the function count, only the function bodies! - // When we do relocations, we need to account for this + // `preloaded_bytes` is offset from the start of the section, since we skip the function count. + // When we do relocations, we need to account for this offset, so let's record it here. let preloaded_reloc_offset = (function_bodies_start - count_start) as u32; let preloaded_bytes = Vec::from_iter_in( @@ -1223,87 +1213,14 @@ impl<'a> CodeSection<'a> { arena, ); - let dead_code_metadata = parse_preloads_call_graph( - arena, - &preloaded_bytes, - import_signatures, - function_signatures, - indirect_callees, - )?; - Ok(CodeSection { preloaded_count: count, preloaded_reloc_offset, preloaded_bytes, dead_import_dummy_count: 0, code_builders: Vec::with_capacity_in(0, arena), - dead_code_metadata, }) } - - pub(super) fn remove_dead_preloads>( - &mut self, - arena: &'a Bump, - import_fn_count: usize, - exported_fns: &[u32], - called_preload_fns: T, - reloc_code: &RelocationSection<'a>, - linking: &LinkingSection<'a>, - ) -> Vec { - let mut live_preload_fns = trace_call_graph( - arena, - &self.dead_code_metadata, - exported_fns, - called_preload_fns, - ); - live_preload_fns.sort_unstable(); - live_preload_fns.dedup(); - - let host_import_count = import_fn_count + self.dead_import_dummy_count as usize; - let split_at = live_preload_fns - .iter() - .position(|f| *f as usize >= host_import_count) - .unwrap_or(live_preload_fns.len()); - let mut live_import_fns = live_preload_fns; - let live_defined_fns = live_import_fns.split_off(split_at); - - let dead_import_count = host_import_count - live_import_fns.len(); - self.dead_import_dummy_count += dead_import_count as u32; - for (new_index, old_index) in live_import_fns.iter().enumerate() { - if new_index == *old_index as usize { - continue; - } - let sym_index = linking - .find_imported_function_symbol(*old_index) - .unwrap_or_else(|| { - panic!( - "Linking failed! Can't find fn #{} in host symbol table", - old_index - ) - }); - reloc_code.apply_relocs_u32( - &mut self.preloaded_bytes, - self.preloaded_reloc_offset, - sym_index, - new_index as u32, - ); - } - - let mut buffer = Vec::with_capacity_in(self.preloaded_bytes.len(), arena); - - copy_preloads_shrinking_dead_fns( - arena, - &mut buffer, - &self.dead_code_metadata, - &self.preloaded_bytes, - host_import_count, - live_defined_fns, - ); - - self.preloaded_bytes = buffer; - - live_import_fns - } } impl<'a> Serialize for CodeSection<'a> { diff --git a/compiler/test_gen/src/helpers/wasm.rs b/compiler/test_gen/src/helpers/wasm.rs index 110924ccfc..a95070623f 100644 --- a/compiler/test_gen/src/helpers/wasm.rs +++ b/compiler/test_gen/src/helpers/wasm.rs @@ -151,7 +151,7 @@ fn compile_roc_to_wasm_bytes<'a, T: Wasm32Result>( index: init_refcount_idx, }); - module.remove_dead_preloads(env.arena, called_preload_fns); + module.eliminate_dead_code(env.arena, called_preload_fns); let mut app_module_bytes = std::vec::Vec::with_capacity(module.size()); module.serialize(&mut app_module_bytes); diff --git a/repl_wasm/src/repl.rs b/repl_wasm/src/repl.rs index bb98e49973..c454a0eb66 100644 --- a/repl_wasm/src/repl.rs +++ b/repl_wasm/src/repl.rs @@ -221,7 +221,7 @@ pub async fn entrypoint_from_js(src: String) -> Result { &main_fn_layout.result, ); - module.remove_dead_preloads(env.arena, called_preload_fns); + module.eliminate_dead_code(env.arena, called_preload_fns); let mut buffer = Vec::with_capacity_in(module.size(), arena); module.serialize(&mut buffer);