Wasm: adjust dead code elimination to account for import function indices

Brian Carroll 2022-01-12 09:31:00 +00:00
parent ca2597973e
commit 9dabc2db15
5 changed files with 124 additions and 61 deletions

View file

@@ -18,12 +18,27 @@ pub struct DeadCodeMetadata<'a> {
 }
 
 impl<'a> DeadCodeMetadata<'a> {
-    pub fn new(arena: &'a Bump, func_count: usize) -> Self {
+    pub fn new(arena: &'a Bump, import_fn_count: u32, fn_count: u32) -> Self {
+        let capacity = (import_fn_count + fn_count) as usize;
+
+        let mut code_offsets = Vec::with_capacity_in(capacity, arena);
+        let mut ret_types = Vec::with_capacity_in(capacity, arena);
+        let calls = Vec::with_capacity_in(2 * capacity, arena);
+        let mut calls_offsets = Vec::with_capacity_in(1 + capacity, arena);
+
+        // Imported functions have zero code length and no calls
+        code_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
+        calls_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
+
+        // We don't care about import return types
+        // Return types are for replacing dead functions with dummies, which doesn't apply to imports
+        ret_types.extend(std::iter::repeat(None).take(import_fn_count as usize));
+
         DeadCodeMetadata {
-            code_offsets: Vec::with_capacity_in(func_count, arena),
-            ret_types: Vec::with_capacity_in(func_count, arena),
-            calls: Vec::with_capacity_in(2 * func_count, arena),
-            calls_offsets: Vec::with_capacity_in(1 + func_count, arena),
+            code_offsets,
+            ret_types,
+            calls,
+            calls_offsets,
         }
     }
 }
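Background on the change above (a sketch, not part of the commit): in a WebAssembly module, imported and locally defined functions share one index space, and imports always come first, occupying indices 0..import_fn_count. Padding the metadata vectors with placeholder entries for the imports keeps every vector directly indexable by function index. A minimal Rust sketch of that index-space rule, with hypothetical names:

    // Hypothetical helper: which side of the index space a function index falls on.
    enum FnSource {
        Import(u32),  // indices 0..import_fn_count are imports, in declaration order
        Defined(u32), // remaining indices map to entries in the code section
    }

    fn classify_fn_index(fn_index: u32, import_fn_count: u32) -> FnSource {
        if fn_index < import_fn_count {
            FnSource::Import(fn_index)
        } else {
            // Code section entry N belongs to function index import_fn_count + N
            FnSource::Defined(fn_index - import_fn_count)
        }
    }

    fn main() {
        // With 3 imports, code section entry 0 is function index 3
        assert!(matches!(classify_fn_index(2, 3), FnSource::Import(2)));
        assert!(matches!(classify_fn_index(3, 3), FnSource::Defined(0)));
    }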
@@ -34,15 +49,18 @@ impl<'a> DeadCodeMetadata<'a> {
 /// use this backend without a linker.
 pub fn parse_dead_code_metadata<'a>(
     arena: &'a Bump,
-    func_count: u32,
+    fn_count: u32,
     code_section_body: &[u8],
-    ret_types: Vec<'a, Option<ValueType>>,
-    signature_ids: Vec<'a, u32>,
+    signature_ret_types: Vec<'a, Option<ValueType>>,
+    internal_fn_sig_ids: Vec<'a, u32>,
+    import_fn_count: u32,
 ) -> DeadCodeMetadata<'a> {
-    let mut metadata = DeadCodeMetadata::new(arena, func_count as usize);
-    metadata
-        .ret_types
-        .extend(signature_ids.iter().map(|sig| ret_types[*sig as usize]));
+    let mut metadata = DeadCodeMetadata::new(arena, import_fn_count, fn_count);
+    metadata.ret_types.extend(
+        internal_fn_sig_ids
+            .iter()
+            .map(|sig| signature_ret_types[*sig as usize]),
+    );
 
     let mut cursor: usize = 0;
     while cursor < code_section_body.len() {
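A worked example of the mapping above (illustrative values, with a minimal stand-in for the crate's ValueType): each defined function stores a signature id, and the type section gives each signature's return type, so the two lookups compose into one per-function return type.

    #[derive(Debug, Clone, Copy, PartialEq)]
    #[allow(dead_code)]
    enum ValueType { I32, I64, F32, F64 }

    fn main() {
        // Type section: sig 0 returns nothing, sig 1 returns i32, sig 2 returns f64
        let signature_ret_types = [None, Some(ValueType::I32), Some(ValueType::F64)];
        // Three defined functions, using signatures 1, 1, and 0 respectively
        let internal_fn_sig_ids: [u32; 3] = [1, 1, 0];

        let ret_types: Vec<Option<ValueType>> = internal_fn_sig_ids
            .iter()
            .map(|sig| signature_ret_types[*sig as usize])
            .collect();

        // These land after the import placeholders pushed by DeadCodeMetadata::new
        assert_eq!(ret_types, [Some(ValueType::I32), Some(ValueType::I32), None]);
    }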
@@ -89,20 +107,20 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
     metadata: &DeadCodeMetadata<'a>,
     called_from_app: Indices,
 ) -> Vec<'a, u32> {
-    let mut live_fn_indices: Vec<'a, u32> = Vec::with_capacity_in(metadata.calls.len(), arena);
-    live_fn_indices.extend(called_from_app);
-
-    let num_funcs = metadata.ret_types.len();
+    let num_funcs = metadata.calls_offsets.len();
+
+    // All functions that get called from the app, directly or indirectly
+    let mut live_fn_indices = Vec::with_capacity_in(num_funcs, arena);
 
-    // Current batch of functions whose call graphs we want to trace
-    let mut current_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
-    current_trace.clone_from(&live_fn_indices);
-
-    // The next batch (don't want to modify the current one while we're iterating over it!)
-    let mut next_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
+    // Current & next batch of functions whose call graphs we want to trace through the metadata
+    // (2 separate vectors so that we're not iterating over the same one we're changing)
+    // If the max call depth is N then we will do N traces or less
+    let mut current_trace = Vec::with_capacity_in(num_funcs, arena);
+    current_trace.extend(called_from_app);
+    let mut next_trace = Vec::with_capacity_in(num_funcs, arena);
 
-    // Fast lookup for what's already traced so we don't need to do it again
-    let mut already_traced: Vec<'a, bool> = Vec::from_iter_in((0..num_funcs).map(|_| false), arena);
+    // Fast per-function lookup table to see if its dependencies have already been traced
+    let mut already_traced = Vec::from_iter_in(std::iter::repeat(false).take(num_funcs), arena);
 
     loop {
         live_fn_indices.extend_from_slice(&current_trace);
@@ -126,13 +144,6 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
         next_trace.clear();
     }
 
-    if true {
-        println!("Hey Brian, don't forget to remove this debug code");
-        let unsorted_len = live_fn_indices.len();
-        live_fn_indices.dedup();
-        debug_assert!(unsorted_len == live_fn_indices.len());
-    }
-
     live_fn_indices
 }
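The loop above is a breadth-first traversal of the call graph. A standalone sketch of the same two-vector idea, using a plain adjacency list in place of the flat calls/calls_offsets encoding (names hypothetical):

    // `calls[f]` lists the functions that function `f` calls directly.
    fn trace_live(calls: &[Vec<u32>], called_from_app: &[u32]) -> Vec<u32> {
        let mut live: Vec<u32> = Vec::new();
        let mut current: Vec<u32> = called_from_app.to_vec();
        let mut next: Vec<u32> = Vec::new();
        let mut already_traced = vec![false; calls.len()];

        while !current.is_empty() {
            live.extend_from_slice(&current);
            for &f in &current {
                if !already_traced[f as usize] {
                    already_traced[f as usize] = true;
                    next.extend_from_slice(&calls[f as usize]);
                }
            }
            // Swap rather than reallocate; never mutate the batch being iterated
            std::mem::swap(&mut current, &mut next);
            next.clear();
        }
        live // may contain duplicates; the caller can sort and dedup
    }

    fn main() {
        // 0 calls 1, 1 calls 2, and 3 is dead code
        let calls = vec![vec![1], vec![2], vec![], vec![]];
        let mut live = trace_live(&calls, &[0]);
        live.sort_unstable();
        live.dedup();
        assert_eq!(live, [0, 1, 2]);
    }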
@@ -181,15 +192,20 @@ pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>(
     buffer: &mut T,
     metadata: &DeadCodeMetadata<'a>,
     external_code: &[u8],
-    live_ext_fn_indices: &'a mut [u32],
+    import_fn_count: u32,
+    mut live_ext_fn_indices: Vec<'a, u32>,
 ) {
     live_ext_fn_indices.sort_unstable();
 
     let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena);
 
-    let mut prev = 0;
-    for live32 in live_ext_fn_indices.iter() {
-        let live = *live32 as usize;
+    let mut prev = import_fn_count as usize;
+    for live32 in live_ext_fn_indices.into_iter() {
+        if live32 < import_fn_count {
+            continue;
+        }
+
+        let live = live32 as usize;
 
         // Replace dead functions with the minimal code body that will pass validation checks
         for dead in prev..live {
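To see the copy-or-replace scheme in the two hunks around this point concretely, here is a worked example with illustrative numbers (imports own indices 0..=2 and have no bodies in the code section, so emission starts at prev = import_fn_count):

    fn main() {
        let import_fn_count = 3u32;
        let num_preloaded = 9usize; // ret_types.len(), including import placeholders
        let live = [4u32, 7]; // sorted live function indices

        let mut prev = import_fn_count as usize;
        for &fn_index in &live {
            let live_idx = fn_index as usize;
            for dead in prev..live_idx {
                println!("function {dead}: emit dummy body");
            }
            println!("function {live_idx}: copy original body");
            prev = live_idx + 1;
        }
        // Everything after the last live function is also dead
        for dead in prev..num_preloaded {
            println!("function {dead}: emit dummy body");
        }
        // Prints: 3 dummy, 4 copy, 5 dummy, 6 dummy, 7 copy, 8 dummy
    }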
@@ -210,4 +226,21 @@ pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>(
 
         prev = live + 1;
     }
+
+    let num_preloaded_fns = metadata.ret_types.len();
+    // Replace dead functions with the minimal code body that will pass validation checks
+    for dead in prev..num_preloaded_fns {
+        if dead < import_fn_count as usize {
+            continue;
+        }
+        let ret_type = metadata.ret_types[dead];
+        let dummy_bytes = match ret_type {
+            Some(ValueType::I32) => &dummy_i32,
+            Some(ValueType::I64) => &dummy_i64,
+            Some(ValueType::F32) => &dummy_f32,
+            Some(ValueType::F64) => &dummy_f64,
+            None => &dummy_nil,
+        };
+        buffer.append_slice(dummy_bytes);
+    }
 }
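create_dummy_functions itself is outside this diff; for reference, a sketch of the minimal valid code bodies it would need to produce. The byte values are my illustration of the Wasm binary format, where each body is [body size (LEB128), local-group count, instructions..., end opcode 0x0b]:

    fn main() {
        let dummy_nil: &[u8] = &[0x02, 0x00, 0x0b]; // no locals; end
        let dummy_i32: &[u8] = &[0x04, 0x00, 0x41, 0x00, 0x0b]; // i32.const 0; end
        let dummy_i64: &[u8] = &[0x04, 0x00, 0x42, 0x00, 0x0b]; // i64.const 0; end
        let dummy_f32: &[u8] = &[0x07, 0x00, 0x43, 0, 0, 0, 0, 0x0b]; // f32.const 0.0; end
        let dummy_f64: &[u8] = &[0x0b, 0x00, 0x44, 0, 0, 0, 0, 0, 0, 0, 0, 0x0b]; // f64.const 0.0; end

        // The leading size byte counts everything after itself
        for body in [dummy_nil, dummy_i32, dummy_i64, dummy_f32, dummy_f64] {
            assert_eq!(body[0] as usize, body.len() - 1);
        }
    }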

View file

@@ -137,7 +137,14 @@ impl<'a> WasmModule<'a> {
         let export = ExportSection::preload(arena, bytes, &mut cursor);
         let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor);
         let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor);
-        let code = CodeSection::preload(arena, bytes, &mut cursor, ret_types, signature_ids);
+        let code = CodeSection::preload(
+            arena,
+            bytes,
+            &mut cursor,
+            ret_types,
+            signature_ids,
+            import.function_count,
+        );
         let data = DataSection::preload(arena, bytes, &mut cursor);
         let linking = LinkingSection::new(arena);
         let relocations = RelocationSection::new(arena, "reloc.CODE");

View file

@@ -324,6 +324,7 @@ pub struct Import {
 }
 
 #[repr(u8)]
+#[derive(Debug)]
 enum ImportTypeId {
     Func = 0,
     Table = 1,
@@ -391,10 +392,10 @@ impl<'a> ImportSection<'a> {
             String::skip_bytes(&self.bytes, &mut cursor);
             String::skip_bytes(&self.bytes, &mut cursor);
-            let type_id = self.bytes[cursor];
+            let type_id = ImportTypeId::from(self.bytes[cursor]);
 
             cursor += 1;
 
-            match ImportTypeId::from(type_id) {
+            match type_id {
                 ImportTypeId::Func => {
                     f_count += 1;
                     u32::skip_bytes(&self.bytes, &mut cursor);
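The From<u8> conversion used here is outside this hunk. A plausible shape, assuming the enum also carries Memory = 2 and Global = 3 (the import kinds defined by the Wasm spec) and reusing the codebase's internal_error! macro seen elsewhere in this commit:

    impl From<u8> for ImportTypeId {
        fn from(x: u8) -> Self {
            match x {
                0 => ImportTypeId::Func,
                1 => ImportTypeId::Table,
                2 => ImportTypeId::Memory, // assumed variant, per the Wasm spec
                3 => ImportTypeId::Global, // assumed variant, per the Wasm spec
                _ => internal_error!("Invalid import type id {}", x),
            }
        }
    }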
@@ -425,7 +426,11 @@ impl<'a> ImportSection<'a> {
     }
 }
 
-section_impl!(ImportSection, SectionId::Import, ImportSection::from_count_and_bytes);
+section_impl!(
+    ImportSection,
+    SectionId::Import,
+    ImportSection::from_count_and_bytes
+);
 
 /*******************************************************************
  *
@@ -740,7 +745,8 @@ impl<'a> CodeSection<'a> {
         module_bytes: &[u8],
         cursor: &mut usize,
         ret_types: Vec<'a, Option<ValueType>>,
-        signature_ids: Vec<'a, u32>,
+        internal_fn_sig_ids: Vec<'a, u32>,
+        import_fn_count: u32,
     ) -> Self {
         let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor);
         let preloaded_bytes = arena.alloc_slice_copy(initial_bytes);
@@ -751,7 +757,8 @@ impl<'a> CodeSection<'a> {
             preloaded_count,
             initial_bytes,
             ret_types,
-            signature_ids,
+            internal_fn_sig_ids,
+            import_fn_count,
         );
 
         CodeSection {
@@ -762,11 +769,13 @@ impl<'a> CodeSection<'a> {
         }
     }
 
-    pub fn remove_dead_preloads<T>(&mut self, arena: &'a Bump, called_preload_fns: T)
-    where
-        T: IntoIterator<Item = u32>,
-    {
-        let mut live_ext_fn_indices =
+    pub fn remove_dead_preloads<T: IntoIterator<Item = u32>>(
+        &mut self,
+        arena: &'a Bump,
+        import_fn_count: u32,
+        called_preload_fns: T,
+    ) {
+        let live_ext_fn_indices =
             trace_function_deps(arena, &self.dead_code_metadata, called_preload_fns);
 
         let mut buffer = Vec::with_capacity_in(self.preloaded_bytes.len(), arena);
@@ -776,7 +785,8 @@ impl<'a> CodeSection<'a> {
             &mut buffer,
             &self.dead_code_metadata,
             self.preloaded_bytes,
-            &mut live_ext_fn_indices,
+            import_fn_count,
+            live_ext_fn_indices,
         );
 
         self.preloaded_bytes = buffer.into_bump_slice();
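A hypothetical call site (glue names assumed, not shown in this commit), illustrating how the import count would be threaded from the import section into dead-code removal:

    fn remove_dead_code<'a>(
        arena: &'a Bump,
        module: &mut WasmModule<'a>,
        called_preload_fns: Vec<'a, u32>,
    ) {
        // Function indices in `called_preload_fns` are module-wide, so the code
        // section needs the import count to line them up with its own entries
        let import_fn_count = module.import.function_count;
        module.code.remove_dead_preloads(arena, import_fn_count, called_preload_fns);
    }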

View file

@@ -268,25 +268,27 @@ pub trait SkipBytes {
 
 impl SkipBytes for u32 {
     fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
-        let imax = 5;
-        let mut i = *cursor;
-        while (bytes[i] & 0x80 != 0) && (i < imax) {
-            i += 1;
+        const MAX_LEN: usize = 5;
+        for (i, byte) in bytes.iter().enumerate().skip(*cursor).take(MAX_LEN) {
+            if byte & 0x80 == 0 {
+                *cursor = i + 1;
+                return;
+            }
         }
-        debug_assert!(i < imax);
-        *cursor = i + 1
+        internal_error!("Invalid LEB encoding");
     }
 }
 
 impl SkipBytes for u64 {
     fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
-        let imax = 10;
-        let mut i = *cursor;
-        while (bytes[i] & 0x80 != 0) && (i < imax) {
-            i += 1;
+        const MAX_LEN: usize = 10;
+        for (i, byte) in bytes.iter().enumerate().skip(*cursor).take(MAX_LEN) {
+            if byte & 0x80 == 0 {
+                *cursor = i + 1;
+                return;
+            }
        }
-        debug_assert!(i < imax);
-        *cursor = i + 1
+        internal_error!("Invalid LEB encoding");
     }
 }
 
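For context on the rewrite above (a sketch of unsigned LEB128, the variable-length integer encoding used throughout the Wasm binary format): each byte carries 7 payload bits, and the high bit 0x80 marks a continuation, so a u32 needs at most 5 bytes (5 × 7 = 35 ≥ 32) and a u64 at most 10. For example, 300 encodes as [0xAC, 0x02], since 44 + (2 << 7) = 300. A decoder rather than a skipper, to show the bit manipulation:

    fn decode_u32_leb128(bytes: &[u8], cursor: &mut usize) -> Option<u32> {
        let mut value: u32 = 0;
        for step in 0..5 {
            let byte = *bytes.get(*cursor)?;
            *cursor += 1;
            value |= ((byte & 0x7f) as u32) << (7 * step);
            if byte & 0x80 == 0 {
                return Some(value); // high bit clear: final byte
            }
        }
        None // all 5 bytes had the continuation bit set: invalid u32 encoding
    }

    fn main() {
        let mut cursor = 0;
        assert_eq!(decode_u32_leb128(&[0xac, 0x02], &mut cursor), Some(300));
        assert_eq!(cursor, 2);
    }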
@@ -299,6 +301,15 @@ impl SkipBytes for u8 {
 impl SkipBytes for String {
     fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
         let len = parse_u32_or_panic(bytes, cursor);
+
+        if false {
+            let str_bytes = &bytes[*cursor..(*cursor + len as usize)];
+            println!(
+                "Skipping String {:?}",
+                String::from_utf8(str_bytes.to_vec()).unwrap()
+            );
+        }
+
         *cursor += len as usize;
     }
 }
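The strings being skipped here are Wasm names: a LEB128 length followed by that many UTF-8 bytes, so "env" is [0x03, 0x65, 0x6E, 0x76]. A minimal sketch assuming the common case of a single-byte length:

    fn skip_name(bytes: &[u8], cursor: &mut usize) {
        // Assumes the length fits in one LEB128 byte (names shorter than 128 bytes)
        let len = bytes[*cursor] as usize;
        *cursor += 1 + len;
    }

    fn main() {
        // An import's module name "env" followed by the start of a field name
        let import_entry = [0x03, b'e', b'n', b'v', 0x04, b'r', b'o', b'c', b'_'];
        let mut cursor = 0;
        skip_name(&import_entry, &mut cursor);
        assert_eq!(cursor, 4); // now pointing at the next name
    }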