wasm: Move some of the DCE code from CodeSection to WasmModule

This commit is contained in:
Brian Carroll 2022-06-07 23:52:33 +01:00
parent dfda992d93
commit c00b43c36b
No known key found for this signature in database
GPG key ID: 9CF4E3BF9C4722C7
6 changed files with 98 additions and 116 deletions

View file

@ -67,7 +67,7 @@ pub fn build_app_binary<'a>(
let (mut wasm_module, called_preload_fns, _) =
build_app_module(env, interns, host_module, procedures);
wasm_module.remove_dead_preloads(env.arena, called_preload_fns);
wasm_module.eliminate_dead_code(env.arena, called_preload_fns);
let mut buffer = std::vec::Vec::with_capacity(wasm_module.size());
wasm_module.serialize(&mut buffer);

View file

@ -192,6 +192,8 @@ pub fn trace_call_graph<'a, Indices: IntoIterator<Item = u32>>(
next_trace.clear();
}
live_fn_indices.sort_unstable();
live_fn_indices.dedup();
live_fn_indices
}

View file

@ -11,6 +11,9 @@ pub use code_builder::{Align, CodeBuilder, LocalId, ValueType, VmSymbolState};
pub use linking::{OffsetRelocType, RelocationEntry, SymInfo};
pub use sections::{ConstExpr, Export, ExportType, Global, GlobalType, Signature};
use self::dead_code::{
copy_preloads_shrinking_dead_fns, parse_preloads_call_graph, trace_call_graph,
};
use self::linking::{LinkingSection, RelocationSection};
use self::parse::{Parse, ParseError};
use self::sections::{
@ -109,20 +112,8 @@ impl<'a> WasmModule<'a> {
let export = ExportSection::parse(arena, bytes, &mut cursor)?;
let start = OpaqueSection::parse((arena, SectionId::Start), bytes, &mut cursor)?;
let element = ElementSection::parse(arena, bytes, &mut cursor)?;
let indirect_callees = element.indirect_callees(arena);
let imported_fn_signatures = import.function_signatures(arena);
let code = CodeSection::parse(
arena,
bytes,
&mut cursor,
&imported_fn_signatures,
&function.signatures,
&indirect_callees,
)?;
let code = CodeSection::parse(arena, bytes, &mut cursor)?;
let data = DataSection::parse(arena, bytes, &mut cursor)?;
let linking = LinkingSection::parse(arena, bytes, &mut cursor)?;
let reloc_code = RelocationSection::parse((arena, "reloc.CODE"), bytes, &mut cursor)?;
let reloc_data = RelocationSection::parse((arena, "reloc.DATA"), bytes, &mut cursor)?;
@ -180,14 +171,29 @@ impl<'a> WasmModule<'a> {
})
}
pub fn remove_dead_preloads<T: IntoIterator<Item = u32>>(
pub fn eliminate_dead_code<T: IntoIterator<Item = u32>>(
&mut self,
arena: &'a Bump,
called_preload_fns: T,
) {
let host_import_count =
self.import.imports.len() + self.code.dead_import_dummy_count as usize;
//
// Parse the host's call graph
//
let indirect_callees = self.element.indirect_callees(arena);
let import_signatures = self.import.function_signatures(arena);
let preloads_call_graph = parse_preloads_call_graph(
arena,
&self.code.preloaded_bytes,
&import_signatures,
&self.function.signatures,
&indirect_callees,
)
.unwrap();
//
// Trace all live host functions, using the call graph
// Start with the functions called from Roc, and those exported to JS
//
let exported_fn_iter = self
.export
.exports
@ -195,17 +201,29 @@ impl<'a> WasmModule<'a> {
.filter(|ex| ex.ty == ExportType::Func)
.map(|ex| ex.index);
let exported_fn_indices = Vec::from_iter_in(exported_fn_iter, arena);
let live_import_fns = self.code.remove_dead_preloads(
let live_preload_fns = trace_call_graph(
arena,
self.import.function_count(),
&preloads_call_graph,
&exported_fn_indices,
called_preload_fns,
&self.reloc_code,
&self.linking,
);
// Retain any imported functions whose index appears in live_import_fns
//
// Categorise the live functions as either imports from JS, or internal Wasm functions
//
let host_import_count =
self.import.imports.len() + self.code.dead_import_dummy_count as usize;
let split_at = live_preload_fns
.iter()
.position(|f| *f as usize >= host_import_count)
.unwrap_or(live_preload_fns.len());
let mut live_import_fns = live_preload_fns;
let live_wasm_fns = live_import_fns.split_off(split_at);
//
// Remove all unused JS imports
// We don't want to force the web page to provide dummy JS functions, it's a pain!
//
let mut fn_index = 0;
let mut live_index = 0;
self.import.imports.retain(|import| {
@ -223,21 +241,66 @@ impl<'a> WasmModule<'a> {
}
});
// Update function signatures
//
// Update function signatures & debug names for imports that changed index
//
for (new_index, old_index) in live_import_fns.iter().enumerate() {
// Safe because `old_index >= new_index`
self.function.signatures[new_index] = self.function.signatures[*old_index as usize];
}
// Update debug names
for (new_index, old_index) in live_import_fns.iter().enumerate() {
// Safe because `old_index >= new_index`
self.names.function_names[new_index] = self.names.function_names[*old_index as usize];
}
let first_dead_import_index = live_import_fns.last().map(|x| x + 1).unwrap_or(0) as usize;
for i in first_dead_import_index..host_import_count {
self.names.function_names[i] = (i as u32, "unused_host_import");
}
//
// Relocate Wasm calls to JS imports
// This must happen *before* we run dead code elimination on the code section,
// so that the host's linking data will still be valid.
//
for (new_index, old_index) in live_import_fns.iter().enumerate() {
if new_index == *old_index as usize {
continue;
}
let sym_index = self
.linking
.find_imported_function_symbol(*old_index)
.unwrap_or_else(|| {
panic!(
"Linking failed! Can't find fn #{} in host symbol table",
old_index
)
});
self.reloc_code.apply_relocs_u32(
&mut self.code.preloaded_bytes,
self.code.preloaded_reloc_offset,
sym_index,
new_index as u32,
);
}
//
// For every eliminated JS import, insert a dummy Wasm function at the same index.
// This avoids shifting the indices of Wasm functions, which would require more linking work.
//
let dead_import_count = host_import_count - live_import_fns.len();
self.code.dead_import_dummy_count += dead_import_count as u32;
//
// Dead code elimination. Replace dead functions with tiny dummies.
// This avoids changing function indices, which would require more linking work.
//
let mut buffer = Vec::with_capacity_in(self.code.preloaded_bytes.len(), arena);
copy_preloads_shrinking_dead_fns(
arena,
&mut buffer,
&preloads_call_graph,
&self.code.preloaded_bytes,
host_import_count,
live_wasm_fns,
);
self.code.preloaded_bytes = buffer;
}
pub fn get_exported_global_u32(&self, name: &str) -> Option<u32> {

View file

@ -4,11 +4,6 @@ use bumpalo::collections::vec::Vec;
use bumpalo::Bump;
use roc_error_macros::internal_error;
use super::dead_code::{
copy_preloads_shrinking_dead_fns, parse_preloads_call_graph, trace_call_graph,
PreloadsCallGraph,
};
use super::linking::{LinkingSection, RelocationSection};
use super::opcodes::OpCode;
use super::parse::{Parse, ParseError, SkipBytes};
use super::serialize::{SerialBuffer, Serialize, MAX_SIZE_ENCODED_U32};
@ -1179,7 +1174,6 @@ pub struct CodeSection<'a> {
/// Dead imports are replaced with dummy functions in CodeSection
pub dead_import_dummy_count: u32,
pub code_builders: Vec<'a, CodeBuilder<'a>>,
dead_code_metadata: PreloadsCallGraph<'a>,
}
impl<'a> CodeSection<'a> {
@ -1193,9 +1187,6 @@ impl<'a> CodeSection<'a> {
arena: &'a Bump,
module_bytes: &[u8],
cursor: &mut usize,
import_signatures: &[u32],
function_signatures: &[u32],
indirect_callees: &[u32],
) -> Result<Self, ParseError> {
if module_bytes[*cursor] != SectionId::Code as u8 {
return Err(ParseError {
@ -1211,9 +1202,8 @@ impl<'a> CodeSection<'a> {
let next_section_start = count_start + section_size as usize;
*cursor = next_section_start;
// Relocation offsets are based from the start of the section body, which includes function count
// But preloaded_bytes does not include the function count, only the function bodies!
// When we do relocations, we need to account for this
// `preloaded_bytes` is offset from the start of the section, since we skip the function count.
// When we do relocations, we need to account for this offset, so let's record it here.
let preloaded_reloc_offset = (function_bodies_start - count_start) as u32;
let preloaded_bytes = Vec::from_iter_in(
@ -1223,87 +1213,14 @@ impl<'a> CodeSection<'a> {
arena,
);
let dead_code_metadata = parse_preloads_call_graph(
arena,
&preloaded_bytes,
import_signatures,
function_signatures,
indirect_callees,
)?;
Ok(CodeSection {
preloaded_count: count,
preloaded_reloc_offset,
preloaded_bytes,
dead_import_dummy_count: 0,
code_builders: Vec::with_capacity_in(0, arena),
dead_code_metadata,
})
}
/// Strip unused host ("preloaded") functions out of this code section.
///
/// Liveness is traced from `exported_fns` and `called_preload_fns` through the
/// stored call graph. Calls to surviving imports whose index changed are re-pointed
/// using the host's relocation and linking data, and `preloaded_bytes` is then
/// rebuilt by `copy_preloads_shrinking_dead_fns`.
///
/// Returns the sorted, de-duplicated original indices of the imports that are
/// still live, so the caller can update the sections it owns.
///
/// Panics if a live import that changed index has no symbol in `linking`.
pub(super) fn remove_dead_preloads<T: IntoIterator<Item = u32>>(
    &mut self,
    arena: &'a Bump,
    import_fn_count: usize,
    exported_fns: &[u32],
    called_preload_fns: T,
    reloc_code: &RelocationSection<'a>,
    linking: &LinkingSection<'a>,
) -> Vec<u32> {
    // Host-side import slots: current imports plus the dummies recorded so far.
    let host_import_count = import_fn_count + self.dead_import_dummy_count as usize;

    let mut live_import_fns = trace_call_graph(
        arena,
        &self.dead_code_metadata,
        exported_fns,
        called_preload_fns,
    );
    live_import_fns.sort_unstable();
    live_import_fns.dedup();

    // Indices are sorted, so everything from the first index at or past
    // `host_import_count` onwards is a function defined in the code section.
    let first_defined = live_import_fns
        .iter()
        .position(|&fn_index| fn_index as usize >= host_import_count)
        .unwrap_or(live_import_fns.len());
    let live_defined_fns = live_import_fns.split_off(first_defined);

    self.dead_import_dummy_count += (host_import_count - live_import_fns.len()) as u32;

    // Re-point calls to imports that changed index. This runs before the bytes
    // are rebuilt below.
    for (new_index, &old_index) in live_import_fns.iter().enumerate() {
        if new_index != old_index as usize {
            let sym_index = match linking.find_imported_function_symbol(old_index) {
                Some(found) => found,
                None => panic!(
                    "Linking failed! Can't find fn #{} in host symbol table",
                    old_index
                ),
            };

            reloc_code.apply_relocs_u32(
                &mut self.preloaded_bytes,
                self.preloaded_reloc_offset,
                sym_index,
                new_index as u32,
            );
        }
    }

    // Build the replacement bytes in a fresh buffer, then swap it in.
    let mut shrunk_bytes = Vec::with_capacity_in(self.preloaded_bytes.len(), arena);
    copy_preloads_shrinking_dead_fns(
        arena,
        &mut shrunk_bytes,
        &self.dead_code_metadata,
        &self.preloaded_bytes,
        host_import_count,
        live_defined_fns,
    );
    self.preloaded_bytes = shrunk_bytes;

    live_import_fns
}
}
impl<'a> Serialize for CodeSection<'a> {

View file

@ -151,7 +151,7 @@ fn compile_roc_to_wasm_bytes<'a, T: Wasm32Result>(
index: init_refcount_idx,
});
module.remove_dead_preloads(env.arena, called_preload_fns);
module.eliminate_dead_code(env.arena, called_preload_fns);
let mut app_module_bytes = std::vec::Vec::with_capacity(module.size());
module.serialize(&mut app_module_bytes);

View file

@ -221,7 +221,7 @@ pub async fn entrypoint_from_js(src: String) -> Result<String, String> {
&main_fn_layout.result,
);
module.remove_dead_preloads(env.arena, called_preload_fns);
module.eliminate_dead_code(env.arena, called_preload_fns);
let mut buffer = Vec::with_capacity_in(module.size(), arena);
module.serialize(&mut buffer);