Wasm: Replace dead function bodies with a single `unreachable` instruction, the same for all return types

Brian Carroll 2022-01-14 18:17:52 +00:00
parent e7dc442af0
commit 4311b5a410
4 changed files with 27 additions and 135 deletions


@@ -3,7 +3,7 @@ use bumpalo::Bump;
 use super::opcodes::OpCode;
 use super::serialize::{parse_u32_or_panic, SerialBuffer, Serialize, SkipBytes};
-use super::{CodeBuilder, ValueType};
+use super::CodeBuilder;

 /*
@@ -14,8 +14,6 @@ Or, more specifically, "dead function replacement"
 - On pre-loading the object file:
   - Analyse its call graph by finding all `call` instructions in the Code section,
     and checking which function index they refer to. Store this in a `DeadCodeMetadata`
-  - Later we will need to know the return type of each function, so scan the Type and Function
-    sections to get that information and store it in `DeadCodeMetadata` too.
 - While compiling Roc code:
   - Run the backend as usual, adding more data into various sections of the Wasm module
   - Whenever a call to a builtin or platform function is made, record its index in a Set.
@@ -24,45 +22,37 @@ Or, more specifically, "dead function replacement"
 - Starting with the set of live preloaded functions, trace their call graphs using the info we
   collected earlier in `DeadCodeMetadata`. Mark all function indices in the call graph as "live".
 - Dead function replacement:
-  - We actually don't want to just *delete* dead functions, because that would change the *indices*
+  - We actually don't want to just *delete* dead functions, because that would change the indices
     of the live functions, invalidating all references to them, such as `call` instructions.
-  - Instead, we replace the dead functions with a tiny but *valid* function that has the same return type!
-    For example the minimal function returning `i32` contains just one instruction: `i32.const 0`
-  - This replacement happens during the final serialization phase
+  - Instead, during serialization, we replace its body with a single `unreachable` instruction
 */

 #[derive(Debug)]
 pub struct DeadCodeMetadata<'a> {
+    num_preloads: usize,

     /// Byte offset where each function body can be found
     code_offsets: Vec<'a, u32>,

     /// Vector with one entry per *call*, containing the called function's index
     calls: Vec<'a, u32>,

     /// Vector with one entry per *function*, indicating its offset in `calls`
     calls_offsets: Vec<'a, u32>,
-
-    /// Return types of each function (for making almost-empty dummy replacements)
-    ret_types: Vec<'a, Option<ValueType>>,
 }

 impl<'a> DeadCodeMetadata<'a> {
     pub fn new(arena: &'a Bump, import_fn_count: u32, fn_count: u32) -> Self {
-        let capacity = (import_fn_count + fn_count) as usize;
+        let num_preloads = (import_fn_count + fn_count) as usize;

-        let mut code_offsets = Vec::with_capacity_in(capacity, arena);
-        let mut ret_types = Vec::with_capacity_in(capacity, arena);
-        let calls = Vec::with_capacity_in(2 * capacity, arena);
-        let mut calls_offsets = Vec::with_capacity_in(1 + capacity, arena);
+        let mut code_offsets = Vec::with_capacity_in(num_preloads, arena);
+        let calls = Vec::with_capacity_in(2 * num_preloads, arena);
+        let mut calls_offsets = Vec::with_capacity_in(1 + num_preloads, arena);

         // Imported functions have zero code length and no calls
         code_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
         calls_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
-
-        // We don't care about import return types
-        // Return types are for replacing dead functions with dummies, which doesn't apply to imports
-        ret_types.extend(std::iter::repeat(None).take(import_fn_count as usize));

         DeadCodeMetadata {
+            num_preloads,
             code_offsets,
-            ret_types,
             calls,
             calls_offsets,
         }
@@ -77,16 +67,9 @@ pub fn parse_dead_code_metadata<'a>(
     arena: &'a Bump,
     fn_count: u32,
     code_section_body: &[u8],
-    signature_ret_types: Vec<'a, Option<ValueType>>,
-    internal_fn_sig_ids: Vec<'a, u32>,
     import_fn_count: u32,
 ) -> DeadCodeMetadata<'a> {
     let mut metadata = DeadCodeMetadata::new(arena, import_fn_count, fn_count);
-
-    metadata.ret_types.extend(
-        internal_fn_sig_ids
-            .iter()
-            .map(|sig| signature_ret_types[*sig as usize]),
-    );

     let mut cursor: usize = 0;
     while cursor < code_section_body.len() {
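
As an aside for reviewers: the loop above walks the Code section body, which per the Wasm binary format is a sequence of size-prefixed function bodies, and the `call` instructions inside them are opcode 0x10 followed by a LEB128 function index. A sketch of one entry, with assumed example values (not bytes from this commit):

// Illustrative only: one Code section entry for a function that calls
// function index 42 and drops its return value.
const EXAMPLE_CODE_ENTRY: [u8; 6] = [
    0x05, // body size in bytes (LEB128 u32); the parser can skip ahead by this
    0x00, // number of local-variable declarations: 0
    0x10, 0x2A, // call: opcode 0x10, then callee index 42 as LEB128
    0x1A, // drop the call's return value
    0x0B, // end of function body
];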
@@ -134,7 +117,7 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
     exported_fns: &[u32],
     called_from_app: Indices,
 ) -> Vec<'a, u32> {
-    let num_preloads = metadata.ret_types.len();
+    let num_preloads = metadata.num_preloads;

     // All functions that get called from the app, directly or indirectly
     let mut live_fn_indices = Vec::with_capacity_in(num_preloads, arena);
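
The tracing itself is elided from this diff. As a reviewer's sketch only, a worklist walk over the `calls`/`calls_offsets` arrays (a CSR-style layout, where the `1 + num_preloads` capacity above leaves room for an end marker) could look like the following; all names here are illustrative, not from the codebase:

// Sketch: mark every function reachable from `roots` as live,
// assuming calls_offsets has num_fns + 1 entries (last one is an end marker).
fn trace_live(num_fns: usize, calls: &[u32], calls_offsets: &[u32], roots: &[u32]) -> Vec<u32> {
    let mut is_live = vec![false; num_fns];
    let mut worklist: Vec<u32> = roots.to_vec();
    while let Some(f) = worklist.pop() {
        if std::mem::replace(&mut is_live[f as usize], true) {
            continue; // already traced this function
        }
        // Callees of `f` live at calls[calls_offsets[f]..calls_offsets[f + 1]]
        let (start, end) = (calls_offsets[f as usize], calls_offsets[f as usize + 1]);
        worklist.extend_from_slice(&calls[start as usize..end as usize]);
    }
    (0..num_fns as u32).filter(|&f| is_live[f as usize]).collect()
}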
@@ -186,44 +169,6 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(

     live_fn_indices
 }

-/// Create a set of minimum-size dummy functions for each possible return type
-fn create_dummy_functions(arena: &Bump) -> [Vec<'_, u8>; 5] {
-    let mut code_builder_i32 = CodeBuilder::new(arena);
-    code_builder_i32.i32_const(0);
-
-    let mut code_builder_i64 = CodeBuilder::new(arena);
-    code_builder_i64.i64_const(0);
-
-    let mut code_builder_f32 = CodeBuilder::new(arena);
-    code_builder_f32.f32_const(0.0);
-
-    let mut code_builder_f64 = CodeBuilder::new(arena);
-    code_builder_f64.f64_const(0.0);
-
-    let mut code_builder_nil = CodeBuilder::new(arena);
-
-    code_builder_i32.build_fn_header_and_footer(&[], 0, None);
-    code_builder_i64.build_fn_header_and_footer(&[], 0, None);
-    code_builder_f32.build_fn_header_and_footer(&[], 0, None);
-    code_builder_f64.build_fn_header_and_footer(&[], 0, None);
-    code_builder_nil.build_fn_header_and_footer(&[], 0, None);
-
-    let capacity = code_builder_f64.size();
-    let mut dummy_i32 = Vec::with_capacity_in(capacity, arena);
-    let mut dummy_i64 = Vec::with_capacity_in(capacity, arena);
-    let mut dummy_f32 = Vec::with_capacity_in(capacity, arena);
-    let mut dummy_f64 = Vec::with_capacity_in(capacity, arena);
-    let mut dummy_nil = Vec::with_capacity_in(capacity, arena);
-
-    code_builder_i32.serialize(&mut dummy_i32);
-    code_builder_i64.serialize(&mut dummy_i64);
-    code_builder_f32.serialize(&mut dummy_f32);
-    code_builder_f64.serialize(&mut dummy_f64);
-    code_builder_nil.serialize(&mut dummy_nil);
-
-    [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil]
-}

 /// Copy used functions from preloaded object file into our Code section
 /// Replace unused functions with very small dummies, to avoid changing any indices
 pub fn copy_live_and_replace_dead_preloads<'a, T: SerialBuffer>(
@@ -235,16 +180,20 @@ pub fn copy_live_and_replace_dead_preloads<'a, T: SerialBuffer>(
     mut live_preload_indices: Vec<'a, u32>,
 ) {
     let preload_idx_start = import_fn_count as usize;
-    let preload_idx_end = metadata.ret_types.len();
-    let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena);
+
+    // Create a dummy function with just a single `unreachable` instruction
+    let mut dummy_builder = CodeBuilder::new(arena);
+    dummy_builder.unreachable_();
+    dummy_builder.build_fn_header_and_footer(&[], 0, None);
+    let mut dummy_bytes = Vec::with_capacity_in(dummy_builder.size(), arena);
+    dummy_builder.serialize(&mut dummy_bytes);

     live_preload_indices.sort_unstable();
     live_preload_indices.dedup();

     let mut live_iter = live_preload_indices.iter();
     let mut next_live_idx = live_iter.next();
-    for i in preload_idx_start..preload_idx_end {
+    for i in preload_idx_start..metadata.num_preloads {
         match next_live_idx {
             Some(live) if *live as usize == i => {
                 next_live_idx = live_iter.next();
@@ -253,15 +202,7 @@ pub fn copy_live_and_replace_dead_preloads<'a, T: SerialBuffer>(
                 buffer.append_slice(&external_code[live_body_start..live_body_end]);
             }
             _ => {
-                let ret_type = metadata.ret_types[i];
-                let dummy_bytes = match ret_type {
-                    Some(ValueType::I32) => &dummy_i32,
-                    Some(ValueType::I64) => &dummy_i64,
-                    Some(ValueType::F32) => &dummy_f32,
-                    Some(ValueType::F64) => &dummy_f64,
-                    None => &dummy_nil,
-                };
-                buffer.append_slice(dummy_bytes);
+                buffer.append_slice(&dummy_bytes);
             }
         }
     }
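
Why one dummy now suffices for every signature: in Wasm validation, `unreachable` makes the rest of the block stack-polymorphic, so an `unreachable`-only body type-checks against any return type. That is what lets this commit delete the per-return-type dummies. A byte-level sketch of what the serialized dummy plausibly looks like (assumed layout; the real bytes come from `CodeBuilder::serialize`):

// Assumed layout, for illustration:
const DUMMY_BODY: [u8; 4] = [
    0x03, // body size in bytes (LEB128 u32)
    0x00, // number of local-variable declarations: 0
    0x00, // opcode: unreachable
    0x0B, // opcode: end
];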


@@ -130,34 +130,20 @@ impl<'a> WasmModule<'a> {
         let mut cursor: usize = 8;

         let mut types = TypeSection::preload(arena, bytes, &mut cursor);
-        let ret_types = types.parse_preloaded_data(arena);
+        types.parse_offsets();

         let import = ImportSection::preload(arena, bytes, &mut cursor);

         let function = FunctionSection::preload(arena, bytes, &mut cursor);
-        let signature_ids = function.parse_preloaded_data(arena);

         let table = OpaqueSection::preload(SectionId::Table, arena, bytes, &mut cursor);
         let memory = MemorySection::preload(arena, bytes, &mut cursor);
         let global = GlobalSection::preload(arena, bytes, &mut cursor);

         ExportSection::skip_bytes(bytes, &mut cursor);
         let export = ExportSection::empty(arena);

         let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor);
         let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor);

-        let code = CodeSection::preload(
-            arena,
-            bytes,
-            &mut cursor,
-            ret_types,
-            signature_ids,
-            import.function_count,
-        );
+        let code = CodeSection::preload(arena, bytes, &mut cursor, import.function_count);

         let data = DataSection::preload(arena, bytes, &mut cursor);


@@ -9,9 +9,7 @@ use super::dead_code::{
 };
 use super::linking::RelocationEntry;
 use super::opcodes::OpCode;
-use super::serialize::{
-    decode_u32_or_panic, parse_u32_or_panic, SerialBuffer, Serialize, SkipBytes,
-};
+use super::serialize::{parse_u32_or_panic, SerialBuffer, Serialize, SkipBytes};
 use super::{CodeBuilder, ValueType};

 /*******************************************************************
@@ -223,9 +221,8 @@ impl<'a> TypeSection<'a> {
         sig_id as u32
     }

-    pub fn parse_preloaded_data(&mut self, arena: &'a Bump) -> Vec<'a, Option<ValueType>> {
+    pub fn parse_offsets(&mut self) {
         self.offsets.clear();
-        let mut ret_types = Vec::with_capacity_in(self.offsets.capacity(), arena);

         let mut i = 0;
         while i < self.bytes.len() {
@@ -234,23 +231,12 @@ impl<'a> TypeSection<'a> {
             debug_assert!(self.bytes[i] == Signature::SEPARATOR);
             i += 1;

-            let (n_params, n_params_size) = decode_u32_or_panic(&self.bytes[i..]);
-            i += n_params_size; // skip over the array length that we just decoded
+            let n_params = parse_u32_or_panic(&self.bytes, &mut i);
             i += n_params as usize; // skip over one byte per param type

             let n_return_values = self.bytes[i];
-            i += 1;
-
-            ret_types.push(if n_return_values == 0 {
-                None
-            } else {
-                Some(ValueType::from(self.bytes[i]))
-            });
-            i += n_return_values as usize;
+            i += 1 + n_return_values as usize;
         }
-
-        ret_types
     }
 }
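
For reference, each Type section entry that `parse_offsets` now merely skips over has this layout: the `0x60` marker (what the `Signature::SEPARATOR` assert checks), a LEB128 parameter count, one type byte per parameter, then a return count and one type byte per return value. An example entry, with the signature assumed for illustration:

// An assumed functype for (i32, i32) -> i32:
const EXAMPLE_FUNC_TYPE: [u8; 6] = [
    0x60,       // functype marker (Signature::SEPARATOR)
    0x02,       // number of params (LEB128 u32)
    0x7F, 0x7F, // param types: i32, i32
    0x01,       // number of return values
    0x7F,       // return type: i32
];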
@@ -451,15 +437,6 @@ impl<'a> FunctionSection<'a> {
         self.bytes.encode_u32(sig_id);
         self.count += 1;
     }
-
-    pub fn parse_preloaded_data(&self, arena: &'a Bump) -> Vec<'a, u32> {
-        let mut preload_signature_ids = Vec::with_capacity_in(self.count as usize, arena);
-        let mut cursor = 0;
-        while cursor < self.bytes.len() {
-            preload_signature_ids.push(parse_u32_or_panic(&self.bytes, &mut cursor));
-        }
-        preload_signature_ids
-    }
 }

 section_impl!(FunctionSection, SectionId::Function);
section_impl!(FunctionSection, SectionId::Function);
@@ -760,22 +737,14 @@ impl<'a> CodeSection<'a> {
         arena: &'a Bump,
         module_bytes: &[u8],
         cursor: &mut usize,
-        ret_types: Vec<'a, Option<ValueType>>,
-        internal_fn_sig_ids: Vec<'a, u32>,
         import_fn_count: u32,
     ) -> Self {
         let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor);
         let preloaded_bytes = arena.alloc_slice_copy(initial_bytes);

         // TODO: Try to move this metadata preparation to platform build time
-        let dead_code_metadata = parse_dead_code_metadata(
-            arena,
-            preloaded_count,
-            initial_bytes,
-            ret_types,
-            internal_fn_sig_ids,
-            import_fn_count,
-        );
+        let dead_code_metadata =
+            parse_dead_code_metadata(arena, preloaded_count, initial_bytes, import_fn_count);

         CodeSection {
             preloaded_count,
@@ -1053,7 +1022,7 @@ mod tests {
         // Reconstruct a new TypeSection by "pre-loading" the bytes of the original
         let mut cursor = 0;
         let mut preloaded = TypeSection::preload(arena, &original_serialized, &mut cursor);
-        preloaded.parse_preloaded_data(arena);
+        preloaded.parse_offsets();

         debug_assert_eq!(original.offsets, preloaded.offsets);
         debug_assert_eq!(original.bytes, preloaded.bytes);


@@ -250,10 +250,6 @@ pub fn decode_u32(bytes: &[u8]) -> Result<(u32, usize), String> {
     ))
 }

-pub fn decode_u32_or_panic(bytes: &[u8]) -> (u32, usize) {
-    decode_u32(bytes).unwrap_or_else(|e| internal_error!("{}", e))
-}

 pub fn parse_u32_or_panic(bytes: &[u8], cursor: &mut usize) -> u32 {
     let (value, len) = decode_u32(&bytes[*cursor..]).unwrap_or_else(|e| internal_error!("{}", e));
     *cursor += len;