Wasm: implement dead code elimination

This commit is contained in:
Brian Carroll 2022-01-11 21:16:20 +00:00
parent 98400cae1b
commit 8a01c3f98a
6 changed files with 230 additions and 66 deletions

View file

@ -2,7 +2,7 @@ use bumpalo::{self, collections::Vec};
use code_builder::Align;
use roc_builtins::bitcode::{self, IntWidth};
use roc_collections::all::MutMap;
use roc_collections::all::{MutMap, MutSet};
use roc_module::ident::Ident;
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
use roc_module::symbol::{Interns, Symbol};
@ -42,6 +42,7 @@ pub struct WasmBackend<'a> {
next_constant_addr: u32,
fn_index_offset: u32,
preloaded_functions_map: MutMap<&'a [u8], u32>,
called_preload_fns: MutSet<u32>,
proc_symbols: Vec<'a, (Symbol, u32)>,
helper_proc_gen: CodeGenHelp<'a>,
@ -79,6 +80,7 @@ impl<'a> WasmBackend<'a> {
next_constant_addr: CONST_SEGMENT_BASE_ADDR,
fn_index_offset,
preloaded_functions_map,
called_preload_fns: MutSet::default(),
proc_symbols,
helper_proc_gen,
@ -115,7 +117,12 @@ impl<'a> WasmBackend<'a> {
self.module.linking.symbol_table.push(linker_symbol);
}
pub fn into_module(self) -> WasmModule<'a> {
pub fn into_module(mut self, remove_dead_preloads: bool) -> WasmModule<'a> {
if remove_dead_preloads {
self.module
.code
.remove_dead_preloads(self.env.arena, self.called_preload_fns)
}
self.module
}
@ -1466,6 +1473,7 @@ impl<'a> WasmBackend<'a> {
let num_wasm_args = param_types.len();
let has_return_val = ret_type.is_some();
let fn_index = self.preloaded_functions_map[name.as_bytes()];
self.called_preload_fns.insert(fn_index);
let linker_symbol_index = u32::MAX;
self.code_builder

View file

@ -13,7 +13,6 @@ use roc_module::symbol::{Interns, ModuleId, Symbol};
use roc_mono::code_gen_help::CodeGenHelp;
use roc_mono::ir::{Proc, ProcLayout};
use roc_mono::layout::LayoutIds;
use roc_reporting::internal_error;
use crate::backend::WasmBackend;
use crate::wasm_module::{
@ -42,7 +41,7 @@ pub fn build_module<'a>(
procedures: MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
) -> Result<std::vec::Vec<u8>, String> {
let (wasm_module, _) = build_module_help(env, interns, preload_bytes, procedures)?;
let mut buffer = std::vec::Vec::with_capacity(4096);
let mut buffer = std::vec::Vec::with_capacity(wasm_module.size());
wasm_module.serialize(&mut buffer);
Ok(buffer)
}
@ -59,6 +58,7 @@ pub fn build_module_help<'a>(
let mut linker_symbols = Vec::with_capacity_in(procedures.len() * 2, env.arena);
let mut exports = Vec::with_capacity_in(4, env.arena);
let mut maybe_main_fn_index = None;
let eliminate_dead_preloads = true;
// Collect the symbols & names for the procedures,
// and filter out procs we're going to inline
@ -146,7 +146,7 @@ pub fn build_module_help<'a>(
backend.build_proc(proc);
}
let module = backend.into_module();
let module = backend.into_module(eliminate_dead_preloads);
let main_function_index = maybe_main_fn_index.unwrap() + fn_index_offset;
Ok((module, main_function_index))
@ -214,10 +214,6 @@ macro_rules! round_up_to_alignment {
};
}
pub fn debug_panic<E: std::fmt::Debug>(error: E) {
internal_error!("{:?}", error);
}
pub struct WasmDebugLogSettings {
proc_start_end: bool,
user_procs_ir: bool,
@ -233,5 +229,5 @@ pub const DEBUG_LOG_SETTINGS: WasmDebugLogSettings = WasmDebugLogSettings {
helper_procs_ir: false && cfg!(debug_assertions),
let_stmt_ir: false && cfg!(debug_assertions),
instructions: false && cfg!(debug_assertions),
keep_test_binary: true && cfg!(debug_assertions),
keep_test_binary: false && cfg!(debug_assertions),
};

View file

@ -37,6 +37,18 @@ impl Serialize for ValueType {
}
}
impl From<u8> for ValueType {
fn from(x: u8) -> Self {
match x {
0x7f => Self::I32,
0x7e => Self::I64,
0x7d => Self::F32,
0x7c => Self::F64,
_ => internal_error!("Invalid ValueType 0x{:02x}", x),
}
}
}
const BLOCK_NO_RESULT: u8 = 0x40;
/// A control block in our model of the VM

View file

@ -1,14 +1,20 @@
use bumpalo::collections::vec::Vec;
use bumpalo::Bump;
use super::opcodes::OpCode;
use super::serialize::{parse_u32_or_panic, SerialBuffer, Serialize, SkipBytes};
use super::{CodeBuilder, ValueType};
#[derive(Debug)]
pub struct DeadCodeMetadata<'a> {
/// Byte offset (in the module) where each function body can be found
/// Byte offset where each function body can be found
code_offsets: Vec<'a, u32>,
/// Vector with one entry per *call*, containing the called function's index
calls: Vec<'a, u32>,
/// Vector with one entry per *function*, indicating its offset in `calls`
calls_offsets: Vec<'a, u32>,
/// Return types of each function (for making almost-empty dummy replacements)
ret_types: Vec<'a, u8>,
ret_types: Vec<'a, Option<ValueType>>,
}
impl<'a> DeadCodeMetadata<'a> {
@ -28,69 +34,110 @@ impl<'a> DeadCodeMetadata<'a> {
/// use this backend without a linker.
pub fn parse_dead_code_metadata<'a>(
arena: &'a Bump,
module_bytes: &[u8],
cursor: &mut usize,
func_count: u32,
code_section_body: &[u8],
ret_types: Vec<'a, Option<ValueType>>,
signature_ids: Vec<'a, u32>,
) -> DeadCodeMetadata<'a> {
if module_bytes[*cursor] != SectionId::Code as u8 {
internal_error!("Expected Code section in object file at offset {}", *cursor);
}
*cursor += 1;
let section_size = parse_u32_or_panic(module_bytes, cursor);
let count_start = *cursor;
let section_end = count_start + section_size as usize;
let func_count = parse_u32_or_panic(module_bytes, cursor);
let mut metadata = DeadCodeMetadata::new(arena, func_count as usize);
metadata
.ret_types
.extend(signature_ids.iter().map(|sig| ret_types[*sig as usize]));
while *cursor < section_end {
metadata.code_offsets.push(*cursor as u32);
let mut cursor: usize = 0;
while cursor < code_section_body.len() {
metadata.code_offsets.push(cursor as u32);
metadata.calls_offsets.push(metadata.calls.len() as u32);
let func_size = parse_u32_or_panic(module_bytes, cursor);
let func_end = *cursor + func_size as usize;
let func_size = parse_u32_or_panic(code_section_body, &mut cursor);
let func_end = cursor + func_size as usize;
// Local variable declarations
let local_groups_count = parse_u32_or_panic(module_bytes, cursor);
let local_groups_count = parse_u32_or_panic(code_section_body, &mut cursor);
for _ in 0..local_groups_count {
let _group_len = parse_u32_or_panic(module_bytes, cursor);
*cursor += 1; // ValueType
parse_u32_or_panic(code_section_body, &mut cursor);
cursor += 1; // ValueType
}
// Instructions
while *cursor < func_end {
let opcode_byte: u8 = module_bytes[*cursor];
while cursor < func_end {
let opcode_byte: u8 = code_section_body[cursor];
if opcode_byte == OpCode::CALL as u8 {
*cursor += 1;
let call_index = parse_u32_or_panic(module_bytes, cursor);
cursor += 1;
let call_index = parse_u32_or_panic(code_section_body, &mut cursor);
metadata.calls.push(call_index as u32);
} else {
OpCode::skip_bytes(module_bytes, cursor);
OpCode::skip_bytes(code_section_body, &mut cursor);
}
}
}
// Extra entries to mark the end of the last function
metadata.code_offsets.push(*cursor as u32);
metadata.code_offsets.push(cursor as u32);
metadata.calls_offsets.push(metadata.calls.len() as u32);
metadata
}
/// Copy used functions (and their dependencies!) from an external module into our Code section
/// Replace unused functions with very small dummies, to avoid changing any indices
pub fn copy_used_functions<'a, T: SerialBuffer>(
/// Trace the dependencies of a list of functions
/// We've already collected metadata saying which functions call each other
/// Now we need to trace the dependency graphs of a specific subset of them
/// Result is the full set of builtins and platform functions used in the app.
/// The rest are "dead code" and can be eliminated.
pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
arena: &'a Bump,
buffer: &mut T,
metadata: DeadCodeMetadata<'a>,
external_module: &[u8],
sorted_used_func_indices: &[u32],
) {
let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena);
metadata: &DeadCodeMetadata<'a>,
called_from_app: Indices,
) -> Vec<'a, u32> {
let mut live_fn_indices: Vec<'a, u32> = Vec::with_capacity_in(metadata.calls.len(), arena);
live_fn_indices.extend(called_from_app);
let num_funcs = metadata.calls_offsets.len();
// Current batch of functions whose call graphs we want to trace
let mut current_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
current_trace.clone_from(&live_fn_indices);
// The next batch (don't want to modify the current one while we're iterating over it!)
let mut next_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
// Fast lookup for what's already traced so we don't need to do it again
let mut already_traced: Vec<'a, bool> = Vec::from_iter_in((0..num_funcs).map(|_| false), arena);
loop {
live_fn_indices.extend_from_slice(&current_trace);
for func_idx in current_trace.iter() {
let i = *func_idx as usize;
already_traced[i] = true;
let calls_start = metadata.calls_offsets[i] as usize;
let calls_end = metadata.calls_offsets[i + 1] as usize;
let called_indices: &[u32] = &metadata.calls[calls_start..calls_end];
for called_idx in called_indices {
if !already_traced[*called_idx as usize] {
next_trace.push(*called_idx);
}
}
}
if next_trace.is_empty() {
break;
}
current_trace.clone_from(&next_trace);
next_trace.clear();
}
if true {
println!("Hey Brian, don't forget to remove this debug code");
let unsorted_len = live_fn_indices.len();
live_fn_indices.dedup();
debug_assert!(unsorted_len == live_fn_indices.len());
}
live_fn_indices
}
/// Create a set of dummy functions that just return a constant of each possible type
fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] {
/// Create a set of minimum-size dummy functions for each possible return type
fn create_dummy_functions(arena: &Bump) -> [Vec<'_, u8>; 5] {
let mut code_builder_i32 = CodeBuilder::new(arena);
code_builder_i32.i32_const(0);
@ -111,11 +158,12 @@ fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] {
code_builder_f64.build_fn_header_and_footer(&[], 0, None);
code_builder_nil.build_fn_header_and_footer(&[], 0, None);
let mut dummy_i32 = Vec::with_capacity_in(code_builder_i32.size(), arena);
let mut dummy_i64 = Vec::with_capacity_in(code_builder_i64.size(), arena);
let mut dummy_f32 = Vec::with_capacity_in(code_builder_f32.size(), arena);
let mut dummy_f64 = Vec::with_capacity_in(code_builder_f64.size(), arena);
let mut dummy_nil = Vec::with_capacity_in(code_builder_nil.size(), arena);
let capacity = code_builder_f64.size();
let mut dummy_i32 = Vec::with_capacity_in(capacity, arena);
let mut dummy_i64 = Vec::with_capacity_in(capacity, arena);
let mut dummy_f32 = Vec::with_capacity_in(capacity, arena);
let mut dummy_f64 = Vec::with_capacity_in(capacity, arena);
let mut dummy_nil = Vec::with_capacity_in(capacity, arena);
code_builder_i32.serialize(&mut dummy_i32);
code_builder_i64.serialize(&mut dummy_i64);
@ -125,3 +173,41 @@ fn create_dummy_functions<'a>(arena: &'a Bump) -> [Vec<'a, u8>; 5] {
[dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil]
}
/// Copy used functions from an external module into our Code section
/// Replace unused functions with very small dummies, to avoid changing any indices
pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>(
    arena: &'a Bump,
    buffer: &mut T,
    metadata: &DeadCodeMetadata<'a>,
    external_code: &[u8],
    live_ext_fn_indices: &mut [u32],
) {
    live_ext_fn_indices.sort_unstable();

    let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena);

    let mut prev = 0;
    for live32 in live_ext_fn_indices.iter() {
        let live = *live32 as usize;

        // Tolerate duplicate indices in the (sorted) live list: copying the same
        // body twice would shift every later function index and corrupt the module.
        if live < prev {
            continue;
        }

        // Replace dead functions with the minimal code body that will pass validation checks
        for dead in prev..live {
            let dummy_bytes = match metadata.ret_types[dead] {
                Some(ValueType::I32) => &dummy_i32,
                Some(ValueType::I64) => &dummy_i64,
                Some(ValueType::F32) => &dummy_f32,
                Some(ValueType::F64) => &dummy_f64,
                None => &dummy_nil,
            };
            buffer.append_slice(dummy_bytes);
        }

        // Copy the body of the live function from the external module
        let live_body_start = metadata.code_offsets[live] as usize;
        let live_body_end = metadata.code_offsets[live + 1] as usize;
        buffer.append_slice(&external_code[live_body_start..live_body_end]);

        prev = live + 1;
    }

    // Dead functions *after* the last live one also need dummy bodies, otherwise
    // the number of bodies won't match the count in the Code section header.
    for dead in prev..metadata.ret_types.len() {
        let dummy_bytes = match metadata.ret_types[dead] {
            Some(ValueType::I32) => &dummy_i32,
            Some(ValueType::I64) => &dummy_i64,
            Some(ValueType::F32) => &dummy_f32,
            Some(ValueType::F64) => &dummy_f64,
            None => &dummy_nil,
        };
        buffer.append_slice(dummy_bytes);
    }
}

View file

@ -1,4 +1,5 @@
pub mod code_builder;
mod dead_code;
pub mod linking;
pub mod opcodes;
pub mod sections;
@ -44,7 +45,6 @@ impl<'a> WasmModule<'a> {
}
/// Serialize the module to bytes
/// (not using Serialize trait because it's just one more thing to export)
pub fn serialize<T: SerialBuffer>(&self, buffer: &mut T) {
buffer.append_u8(0);
buffer.append_slice("asm".as_bytes());
@ -125,16 +125,19 @@ impl<'a> WasmModule<'a> {
let mut cursor: usize = 8;
let mut types = TypeSection::preload(arena, bytes, &mut cursor);
types.cache_offsets();
let ret_types = types.parse_preloaded_data(arena);
let import = ImportSection::preload(arena, bytes, &mut cursor);
let function = FunctionSection::preload(arena, bytes, &mut cursor);
let signature_ids = function.parse_preloaded_data(arena);
let table = OpaqueSection::preload(SectionId::Table, arena, bytes, &mut cursor);
let memory = MemorySection::preload(arena, bytes, &mut cursor);
let global = GlobalSection::preload(arena, bytes, &mut cursor);
let export = ExportSection::preload(arena, bytes, &mut cursor);
let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor);
let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor);
let code = CodeSection::preload(arena, bytes, &mut cursor);
let code = CodeSection::preload(arena, bytes, &mut cursor, ret_types, signature_ids);
let data = DataSection::preload(arena, bytes, &mut cursor);
let linking = LinkingSection::new(arena);
let relocations = RelocationSection::new(arena, "reloc.CODE");

View file

@ -3,6 +3,9 @@ use bumpalo::Bump;
use roc_collections::all::MutMap;
use roc_reporting::internal_error;
use super::dead_code::{
copy_live_and_replace_dead, parse_dead_code_metadata, trace_function_deps, DeadCodeMetadata,
};
use super::linking::RelocationEntry;
use super::opcodes::OpCode;
use super::serialize::{
@ -212,8 +215,10 @@ impl<'a> TypeSection<'a> {
sig_id as u32
}
pub fn cache_offsets(&mut self) {
pub fn parse_preloaded_data(&mut self, arena: &'a Bump) -> Vec<'a, Option<ValueType>> {
self.offsets.clear();
let mut ret_types = Vec::with_capacity_in(self.offsets.capacity(), arena);
let mut i = 0;
while i < self.bytes.len() {
self.offsets.push(i);
@ -226,8 +231,18 @@ impl<'a> TypeSection<'a> {
i += n_params as usize; // skip over one byte per param type
let n_return_values = self.bytes[i];
i += 1 + n_return_values as usize;
i += 1;
ret_types.push(if n_return_values == 0 {
None
} else {
Some(ValueType::from(self.bytes[i]))
});
i += n_return_values as usize;
}
ret_types
}
}
@ -412,6 +427,15 @@ impl<'a> FunctionSection<'a> {
self.bytes.encode_u32(sig_id);
self.count += 1;
}
/// Parse the signature index of each preloaded function.
/// Returns one entry per function, in function-index order.
pub fn parse_preloaded_data(&self, arena: &'a Bump) -> Vec<'a, u32> {
    let mut signature_ids = Vec::with_capacity_in(self.count as usize, arena);
    let mut offset = 0;
    while offset < self.bytes.len() {
        let sig_id = parse_u32_or_panic(&self.bytes, &mut offset);
        signature_ids.push(sig_id);
    }
    signature_ids
}
}
section_impl!(FunctionSection, SectionId::Function);
@ -663,8 +687,9 @@ section_impl!(ExportSection, SectionId::Export);
#[derive(Debug)]
pub struct CodeSection<'a> {
pub preloaded_count: u32,
pub preloaded_bytes: Vec<'a, u8>,
pub preloaded_bytes: &'a [u8],
pub code_builders: Vec<'a, CodeBuilder<'a>>,
dead_code_metadata: DeadCodeMetadata<'a>,
}
impl<'a> CodeSection<'a> {
@ -677,8 +702,6 @@ impl<'a> CodeSection<'a> {
let header_indices = write_section_header(buffer, SectionId::Code);
buffer.encode_u32(self.preloaded_count + self.code_builders.len() as u32);
buffer.append_slice(&self.preloaded_bytes);
for code_builder in self.code_builders.iter() {
code_builder.serialize_with_relocs(buffer, relocations, header_indices.body_index);
}
@ -694,16 +717,52 @@ impl<'a> CodeSection<'a> {
MAX_SIZE_SECTION_HEADER + self.preloaded_bytes.len() + builders_size
}
pub fn preload(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Self {
pub fn preload(
arena: &'a Bump,
module_bytes: &[u8],
cursor: &mut usize,
ret_types: Vec<'a, Option<ValueType>>,
signature_ids: Vec<'a, u32>,
) -> Self {
let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor);
let mut preloaded_bytes = Vec::with_capacity_in(initial_bytes.len() * 2, arena);
preloaded_bytes.extend_from_slice(initial_bytes);
let preloaded_bytes = arena.alloc_slice_copy(initial_bytes);
// TODO: Try to move this metadata preparation to platform build time
let dead_code_metadata = parse_dead_code_metadata(
arena,
preloaded_count,
initial_bytes,
ret_types,
signature_ids,
);
CodeSection {
preloaded_count,
preloaded_bytes,
code_builders: Vec::with_capacity_in(0, arena),
dead_code_metadata,
}
}
/// Eliminate dead code among the preloaded functions.
/// `called_preload_fns` are the preloaded functions the app calls directly;
/// everything not reachable from them is replaced with a tiny dummy body.
pub fn remove_dead_preloads<T>(&mut self, arena: &'a Bump, called_preload_fns: T)
where
    T: IntoIterator<Item = u32>,
{
    // Find every preloaded function transitively reachable from the app's calls
    let mut live_fn_indices =
        trace_function_deps(arena, &self.dead_code_metadata, called_preload_fns);

    // Rebuild the preloaded bytes: live bodies are copied through unchanged,
    // dead ones become minimal dummies so function indices stay stable.
    let mut replacement = Vec::with_capacity_in(self.preloaded_bytes.len(), arena);
    copy_live_and_replace_dead(
        arena,
        &mut replacement,
        &self.dead_code_metadata,
        self.preloaded_bytes,
        &mut live_fn_indices,
    );
    self.preloaded_bytes = replacement.into_bump_slice();
}
}
impl<'a> Serialize for CodeSection<'a> {
@ -711,7 +770,7 @@ impl<'a> Serialize for CodeSection<'a> {
let header_indices = write_section_header(buffer, SectionId::Code);
buffer.encode_u32(self.preloaded_count + self.code_builders.len() as u32);
buffer.append_slice(&self.preloaded_bytes);
buffer.append_slice(self.preloaded_bytes);
for code_builder in self.code_builders.iter() {
code_builder.serialize(buffer);
@ -846,7 +905,7 @@ mod tests {
// Reconstruct a new TypeSection by "pre-loading" the bytes of the original
let mut cursor = 0;
let mut preloaded = TypeSection::preload(arena, &original_serialized, &mut cursor);
preloaded.cache_offsets();
preloaded.parse_preloaded_data(arena);
debug_assert_eq!(original.offsets, preloaded.offsets);
debug_assert_eq!(original.bytes, preloaded.bytes);