From 63c33d82e38d7bc8f7ec4da2d089175489780219 Mon Sep 17 00:00:00 2001 From: Brian Carroll Date: Sat, 12 Feb 2022 00:09:35 +0000 Subject: [PATCH] wasm: Improve dead code elimination to handle indirect calls --- .../gen_wasm/src/wasm_module/dead_code.rs | 33 +++++++++--- compiler/gen_wasm/src/wasm_module/mod.rs | 21 +++++++- compiler/gen_wasm/src/wasm_module/sections.rs | 51 ++++++++++++++----- 3 files changed, 83 insertions(+), 22 deletions(-) diff --git a/compiler/gen_wasm/src/wasm_module/dead_code.rs b/compiler/gen_wasm/src/wasm_module/dead_code.rs index 382f306322..946b84cf84 100644 --- a/compiler/gen_wasm/src/wasm_module/dead_code.rs +++ b/compiler/gen_wasm/src/wasm_module/dead_code.rs @@ -39,16 +39,16 @@ pub struct PreloadsCallGraph<'a> { } impl<'a> PreloadsCallGraph<'a> { - pub fn new(arena: &'a Bump, import_fn_count: u32, fn_count: u32) -> Self { - let num_preloads = (import_fn_count + fn_count) as usize; + pub fn new(arena: &'a Bump, import_fn_count: usize, fn_count: usize) -> Self { + let num_preloads = import_fn_count + fn_count; let mut code_offsets = Vec::with_capacity_in(num_preloads, arena); let calls = Vec::with_capacity_in(2 * num_preloads, arena); let mut calls_offsets = Vec::with_capacity_in(1 + num_preloads, arena); // Imported functions have zero code length and no calls - code_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize)); - calls_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize)); + code_offsets.extend(std::iter::repeat(0).take(import_fn_count)); + calls_offsets.extend(std::iter::repeat(0).take(import_fn_count)); PreloadsCallGraph { num_preloads, @@ -65,11 +65,18 @@ impl<'a> PreloadsCallGraph<'a> { /// use this backend without a linker. pub fn parse_preloads_call_graph<'a>( arena: &'a Bump, - fn_count: u32, code_section_body: &[u8], - import_fn_count: u32, + import_signatures: &[u32], + function_signatures: &[u32], + indirect_callees: &[u32], ) -> PreloadsCallGraph<'a> { - let mut call_graph = PreloadsCallGraph::new(arena, import_fn_count, fn_count); + let mut call_graph = + PreloadsCallGraph::new(arena, import_signatures.len(), function_signatures.len()); + + let mut signatures = + Vec::with_capacity_in(import_signatures.len() + function_signatures.len(), arena); + signatures.extend_from_slice(import_signatures); + signatures.extend_from_slice(function_signatures); // Iterate over the bytes of the Code section let mut cursor: usize = 0; @@ -88,13 +95,23 @@ pub fn parse_preloads_call_graph<'a>( cursor += 1; // ValueType } - // Parse `call` instructions and skip over all other instructions + // Parse `call` and `call_indirect` instructions, skip over everything else while cursor < func_end { let opcode_byte: u8 = code_section_body[cursor]; if opcode_byte == OpCode::CALL as u8 { cursor += 1; let call_index = parse_u32_or_panic(code_section_body, &mut cursor); call_graph.calls.push(call_index as u32); + } else if opcode_byte == OpCode::CALLINDIRECT as u8 { + cursor += 1; + // Insert all indirect callees with a matching type signature + let sig = parse_u32_or_panic(code_section_body, &mut cursor); + call_graph.calls.extend( + indirect_callees + .iter() + .filter(|f| signatures[**f as usize] == sig), + ); + u32::skip_bytes(code_section_body, &mut cursor); // table_idx } else { OpCode::skip_bytes(code_section_body, &mut cursor); } diff --git a/compiler/gen_wasm/src/wasm_module/mod.rs b/compiler/gen_wasm/src/wasm_module/mod.rs index 619edfb5ab..4badf3aefc 100644 --- a/compiler/gen_wasm/src/wasm_module/mod.rs +++ b/compiler/gen_wasm/src/wasm_module/mod.rs @@ -133,18 +133,35 @@ impl<'a> WasmModule<'a> { let mut types = TypeSection::preload(arena, bytes, &mut cursor); types.parse_offsets(); - let import = ImportSection::preload(arena, bytes, &mut cursor); + let mut import = ImportSection::preload(arena, bytes, &mut cursor); + let import_signatures = import.parse(arena); + let function = FunctionSection::preload(arena, bytes, &mut cursor); + let function_signatures = function.parse(arena); + let table = OpaqueSection::preload(SectionId::Table, arena, bytes, &mut cursor); + let memory = MemorySection::preload(arena, bytes, &mut cursor); + let global = GlobalSection::preload(arena, bytes, &mut cursor); ExportSection::skip_bytes(bytes, &mut cursor); + let export = ExportSection::empty(arena); let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor); + let element = ElementSection::preload(arena, bytes, &mut cursor); - let code = CodeSection::preload(arena, bytes, &mut cursor, import.function_count); + let indirect_callees = element.indirect_callees(arena); + + let code = CodeSection::preload( + arena, + bytes, + &mut cursor, + &import_signatures, + &function_signatures, + &indirect_callees, + ); let data = DataSection::preload(arena, bytes, &mut cursor); diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index 8a42862dcd..f49ce674e7 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -377,8 +377,8 @@ impl<'a> ImportSection<'a> { self.count += 1; } - fn update_function_count(&mut self) { - let mut f_count = 0; + pub fn parse(&mut self, arena: &'a Bump) -> Vec<'a, u32> { + let mut fn_signatures = bumpalo::vec![in arena]; let mut cursor = 0; while cursor < self.bytes.len() { String::skip_bytes(&self.bytes, &mut cursor); @@ -389,8 +389,7 @@ impl<'a> ImportSection<'a> { match type_id { ImportTypeId::Func => { - f_count += 1; - u32::skip_bytes(&self.bytes, &mut cursor); + fn_signatures.push(parse_u32_or_panic(&self.bytes, &mut cursor)); } ImportTypeId::Table => { TableType::skip_bytes(&self.bytes, &mut cursor); @@ -404,17 +403,16 @@ impl<'a> ImportSection<'a> { } } - self.function_count = f_count; + self.function_count = fn_signatures.len() as u32; + fn_signatures } pub fn from_count_and_bytes(count: u32, bytes: Vec<'a, u8>) -> Self { - let mut created = ImportSection { + ImportSection { bytes, count, function_count: 0, - }; - created.update_function_count(); - created + } } } @@ -442,6 +440,16 @@ impl<'a> FunctionSection<'a> { self.bytes.encode_u32(sig_id); self.count += 1; } + + pub fn parse(&self, arena: &'a Bump) -> Vec<'a, u32> { + let count = self.count as usize; + let mut signatures = Vec::with_capacity_in(count, arena); + let mut cursor = 0; + for _ in 0..count { + signatures.push(parse_u32_or_panic(&self.bytes, &mut cursor)); + } + signatures + } } section_impl!(FunctionSection, SectionId::Function); @@ -887,6 +895,18 @@ impl<'a> ElementSection<'a> { pub fn size(&self) -> usize { self.segments.iter().map(|seg| seg.size()).sum() } + + pub fn indirect_callees(&self, arena: &'a Bump) -> Vec<'a, u32> { + let mut result = bumpalo::vec![in arena]; + for segment in self.segments.iter() { + if let ElementSegment::ActiveImplicitTableIndex { fn_indices, .. } = segment { + result.extend_from_slice(fn_indices); + } else { + internal_error!("Unsupported ElementSegment {:?}", self) + } + } + result + } } impl<'a> Serialize for ElementSection<'a> { @@ -940,14 +960,21 @@ impl<'a> CodeSection<'a> { arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize, - import_fn_count: u32, + import_signatures: &[u32], + function_signatures: &[u32], + indirect_callees: &[u32], ) -> Self { let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor); let preloaded_bytes = arena.alloc_slice_copy(initial_bytes); // TODO: Try to move this call_graph preparation to platform build time - let dead_code_metadata = - parse_preloads_call_graph(arena, preloaded_count, initial_bytes, import_fn_count); + let dead_code_metadata = parse_preloads_call_graph( + arena, + initial_bytes, + import_signatures, + function_signatures, + indirect_callees, + ); CodeSection { preloaded_count,