mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 23:04:49 +00:00
Wasm: adjust dead code elimination to account for import function indices
This commit is contained in:
parent
ca2597973e
commit
9dabc2db15
5 changed files with 124 additions and 61 deletions
|
@ -18,12 +18,27 @@ pub struct DeadCodeMetadata<'a> {
|
|||
}
|
||||
|
||||
impl<'a> DeadCodeMetadata<'a> {
|
||||
pub fn new(arena: &'a Bump, func_count: usize) -> Self {
|
||||
pub fn new(arena: &'a Bump, import_fn_count: u32, fn_count: u32) -> Self {
|
||||
let capacity = (import_fn_count + fn_count) as usize;
|
||||
|
||||
let mut code_offsets = Vec::with_capacity_in(capacity, arena);
|
||||
let mut ret_types = Vec::with_capacity_in(capacity, arena);
|
||||
let calls = Vec::with_capacity_in(2 * capacity, arena);
|
||||
let mut calls_offsets = Vec::with_capacity_in(1 + capacity, arena);
|
||||
|
||||
// Imported functions have zero code length and no calls
|
||||
code_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
|
||||
calls_offsets.extend(std::iter::repeat(0).take(import_fn_count as usize));
|
||||
|
||||
// We don't care about import return types
|
||||
// Return types are for replacing dead functions with dummies, which doesn't apply to imports
|
||||
ret_types.extend(std::iter::repeat(None).take(import_fn_count as usize));
|
||||
|
||||
DeadCodeMetadata {
|
||||
code_offsets: Vec::with_capacity_in(func_count, arena),
|
||||
ret_types: Vec::with_capacity_in(func_count, arena),
|
||||
calls: Vec::with_capacity_in(2 * func_count, arena),
|
||||
calls_offsets: Vec::with_capacity_in(1 + func_count, arena),
|
||||
code_offsets,
|
||||
ret_types,
|
||||
calls,
|
||||
calls_offsets,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -34,15 +49,18 @@ impl<'a> DeadCodeMetadata<'a> {
|
|||
/// use this backend without a linker.
|
||||
pub fn parse_dead_code_metadata<'a>(
|
||||
arena: &'a Bump,
|
||||
func_count: u32,
|
||||
fn_count: u32,
|
||||
code_section_body: &[u8],
|
||||
ret_types: Vec<'a, Option<ValueType>>,
|
||||
signature_ids: Vec<'a, u32>,
|
||||
signature_ret_types: Vec<'a, Option<ValueType>>,
|
||||
internal_fn_sig_ids: Vec<'a, u32>,
|
||||
import_fn_count: u32,
|
||||
) -> DeadCodeMetadata<'a> {
|
||||
let mut metadata = DeadCodeMetadata::new(arena, func_count as usize);
|
||||
metadata
|
||||
.ret_types
|
||||
.extend(signature_ids.iter().map(|sig| ret_types[*sig as usize]));
|
||||
let mut metadata = DeadCodeMetadata::new(arena, import_fn_count, fn_count);
|
||||
metadata.ret_types.extend(
|
||||
internal_fn_sig_ids
|
||||
.iter()
|
||||
.map(|sig| signature_ret_types[*sig as usize]),
|
||||
);
|
||||
|
||||
let mut cursor: usize = 0;
|
||||
while cursor < code_section_body.len() {
|
||||
|
@ -89,20 +107,20 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
|
|||
metadata: &DeadCodeMetadata<'a>,
|
||||
called_from_app: Indices,
|
||||
) -> Vec<'a, u32> {
|
||||
let mut live_fn_indices: Vec<'a, u32> = Vec::with_capacity_in(metadata.calls.len(), arena);
|
||||
live_fn_indices.extend(called_from_app);
|
||||
let num_funcs = metadata.ret_types.len();
|
||||
|
||||
let num_funcs = metadata.calls_offsets.len();
|
||||
// All functions that get called from the app, directly or indirectly
|
||||
let mut live_fn_indices = Vec::with_capacity_in(num_funcs, arena);
|
||||
|
||||
// Current batch of functions whose call graphs we want to trace
|
||||
let mut current_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
|
||||
current_trace.clone_from(&live_fn_indices);
|
||||
// Current & next batch of functions whose call graphs we want to trace through the metadata
|
||||
// (2 separate vectors so that we're not iterating over the same one we're changing)
|
||||
// If the max call depth is N then we will do N traces or less
|
||||
let mut current_trace = Vec::with_capacity_in(num_funcs, arena);
|
||||
current_trace.extend(called_from_app);
|
||||
let mut next_trace = Vec::with_capacity_in(num_funcs, arena);
|
||||
|
||||
// The next batch (don't want to modify the current one while we're iterating over it!)
|
||||
let mut next_trace: Vec<'a, u32> = Vec::with_capacity_in(num_funcs, arena);
|
||||
|
||||
// Fast lookup for what's already traced so we don't need to do it again
|
||||
let mut already_traced: Vec<'a, bool> = Vec::from_iter_in((0..num_funcs).map(|_| false), arena);
|
||||
// Fast per-function lookup table to see if its dependencies have already been traced
|
||||
let mut already_traced = Vec::from_iter_in(std::iter::repeat(false).take(num_funcs), arena);
|
||||
|
||||
loop {
|
||||
live_fn_indices.extend_from_slice(¤t_trace);
|
||||
|
@ -126,13 +144,6 @@ pub fn trace_function_deps<'a, Indices: IntoIterator<Item = u32>>(
|
|||
next_trace.clear();
|
||||
}
|
||||
|
||||
if true {
|
||||
println!("Hey Brian, don't forget to remove this debug code");
|
||||
let unsorted_len = live_fn_indices.len();
|
||||
live_fn_indices.dedup();
|
||||
debug_assert!(unsorted_len == live_fn_indices.len());
|
||||
}
|
||||
|
||||
live_fn_indices
|
||||
}
|
||||
|
||||
|
@ -181,15 +192,20 @@ pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>(
|
|||
buffer: &mut T,
|
||||
metadata: &DeadCodeMetadata<'a>,
|
||||
external_code: &[u8],
|
||||
live_ext_fn_indices: &'a mut [u32],
|
||||
import_fn_count: u32,
|
||||
mut live_ext_fn_indices: Vec<'a, u32>,
|
||||
) {
|
||||
live_ext_fn_indices.sort_unstable();
|
||||
|
||||
let [dummy_i32, dummy_i64, dummy_f32, dummy_f64, dummy_nil] = create_dummy_functions(arena);
|
||||
|
||||
let mut prev = 0;
|
||||
for live32 in live_ext_fn_indices.iter() {
|
||||
let live = *live32 as usize;
|
||||
let mut prev = import_fn_count as usize;
|
||||
for live32 in live_ext_fn_indices.into_iter() {
|
||||
if live32 < import_fn_count {
|
||||
continue;
|
||||
}
|
||||
|
||||
let live = live32 as usize;
|
||||
|
||||
// Replace dead functions with the minimal code body that will pass validation checks
|
||||
for dead in prev..live {
|
||||
|
@ -210,4 +226,21 @@ pub fn copy_live_and_replace_dead<'a, T: SerialBuffer>(
|
|||
|
||||
prev = live + 1;
|
||||
}
|
||||
|
||||
let num_preloaded_fns = metadata.ret_types.len();
|
||||
// Replace dead functions with the minimal code body that will pass validation checks
|
||||
for dead in prev..num_preloaded_fns {
|
||||
if dead < import_fn_count as usize {
|
||||
continue;
|
||||
}
|
||||
let ret_type = metadata.ret_types[dead];
|
||||
let dummy_bytes = match ret_type {
|
||||
Some(ValueType::I32) => &dummy_i32,
|
||||
Some(ValueType::I64) => &dummy_i64,
|
||||
Some(ValueType::F32) => &dummy_f32,
|
||||
Some(ValueType::F64) => &dummy_f64,
|
||||
None => &dummy_nil,
|
||||
};
|
||||
buffer.append_slice(dummy_bytes);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -137,7 +137,14 @@ impl<'a> WasmModule<'a> {
|
|||
let export = ExportSection::preload(arena, bytes, &mut cursor);
|
||||
let start = OpaqueSection::preload(SectionId::Start, arena, bytes, &mut cursor);
|
||||
let element = OpaqueSection::preload(SectionId::Element, arena, bytes, &mut cursor);
|
||||
let code = CodeSection::preload(arena, bytes, &mut cursor, ret_types, signature_ids);
|
||||
let code = CodeSection::preload(
|
||||
arena,
|
||||
bytes,
|
||||
&mut cursor,
|
||||
ret_types,
|
||||
signature_ids,
|
||||
import.function_count,
|
||||
);
|
||||
let data = DataSection::preload(arena, bytes, &mut cursor);
|
||||
let linking = LinkingSection::new(arena);
|
||||
let relocations = RelocationSection::new(arena, "reloc.CODE");
|
||||
|
|
|
@ -324,6 +324,7 @@ pub struct Import {
|
|||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Debug)]
|
||||
enum ImportTypeId {
|
||||
Func = 0,
|
||||
Table = 1,
|
||||
|
@ -391,10 +392,10 @@ impl<'a> ImportSection<'a> {
|
|||
String::skip_bytes(&self.bytes, &mut cursor);
|
||||
String::skip_bytes(&self.bytes, &mut cursor);
|
||||
|
||||
let type_id = self.bytes[cursor];
|
||||
let type_id = ImportTypeId::from(self.bytes[cursor]);
|
||||
cursor += 1;
|
||||
|
||||
match ImportTypeId::from(type_id) {
|
||||
match type_id {
|
||||
ImportTypeId::Func => {
|
||||
f_count += 1;
|
||||
u32::skip_bytes(&self.bytes, &mut cursor);
|
||||
|
@ -425,7 +426,11 @@ impl<'a> ImportSection<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
section_impl!(ImportSection, SectionId::Import, ImportSection::from_count_and_bytes);
|
||||
section_impl!(
|
||||
ImportSection,
|
||||
SectionId::Import,
|
||||
ImportSection::from_count_and_bytes
|
||||
);
|
||||
|
||||
/*******************************************************************
|
||||
*
|
||||
|
@ -740,7 +745,8 @@ impl<'a> CodeSection<'a> {
|
|||
module_bytes: &[u8],
|
||||
cursor: &mut usize,
|
||||
ret_types: Vec<'a, Option<ValueType>>,
|
||||
signature_ids: Vec<'a, u32>,
|
||||
internal_fn_sig_ids: Vec<'a, u32>,
|
||||
import_fn_count: u32,
|
||||
) -> Self {
|
||||
let (preloaded_count, initial_bytes) = parse_section(SectionId::Code, module_bytes, cursor);
|
||||
let preloaded_bytes = arena.alloc_slice_copy(initial_bytes);
|
||||
|
@ -751,7 +757,8 @@ impl<'a> CodeSection<'a> {
|
|||
preloaded_count,
|
||||
initial_bytes,
|
||||
ret_types,
|
||||
signature_ids,
|
||||
internal_fn_sig_ids,
|
||||
import_fn_count,
|
||||
);
|
||||
|
||||
CodeSection {
|
||||
|
@ -762,11 +769,13 @@ impl<'a> CodeSection<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn remove_dead_preloads<T>(&mut self, arena: &'a Bump, called_preload_fns: T)
|
||||
where
|
||||
T: IntoIterator<Item = u32>,
|
||||
{
|
||||
let mut live_ext_fn_indices =
|
||||
pub fn remove_dead_preloads<T: IntoIterator<Item = u32>>(
|
||||
&mut self,
|
||||
arena: &'a Bump,
|
||||
import_fn_count: u32,
|
||||
called_preload_fns: T,
|
||||
) {
|
||||
let live_ext_fn_indices =
|
||||
trace_function_deps(arena, &self.dead_code_metadata, called_preload_fns);
|
||||
|
||||
let mut buffer = Vec::with_capacity_in(self.preloaded_bytes.len(), arena);
|
||||
|
@ -776,7 +785,8 @@ impl<'a> CodeSection<'a> {
|
|||
&mut buffer,
|
||||
&self.dead_code_metadata,
|
||||
self.preloaded_bytes,
|
||||
&mut live_ext_fn_indices,
|
||||
import_fn_count,
|
||||
live_ext_fn_indices,
|
||||
);
|
||||
|
||||
self.preloaded_bytes = buffer.into_bump_slice();
|
||||
|
|
|
@ -268,25 +268,27 @@ pub trait SkipBytes {
|
|||
|
||||
impl SkipBytes for u32 {
|
||||
fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
|
||||
let imax = 5;
|
||||
let mut i = *cursor;
|
||||
while (bytes[i] & 0x80 != 0) && (i < imax) {
|
||||
i += 1;
|
||||
const MAX_LEN: usize = 5;
|
||||
for (i, byte) in bytes.iter().enumerate().skip(*cursor).take(MAX_LEN) {
|
||||
if byte & 0x80 == 0 {
|
||||
*cursor = i + 1;
|
||||
return;
|
||||
}
|
||||
}
|
||||
debug_assert!(i < imax);
|
||||
*cursor = i + 1
|
||||
internal_error!("Invalid LEB encoding");
|
||||
}
|
||||
}
|
||||
|
||||
impl SkipBytes for u64 {
|
||||
fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
|
||||
let imax = 10;
|
||||
let mut i = *cursor;
|
||||
while (bytes[i] & 0x80 != 0) && (i < imax) {
|
||||
i += 1;
|
||||
const MAX_LEN: usize = 10;
|
||||
for (i, byte) in bytes.iter().enumerate().skip(*cursor).take(MAX_LEN) {
|
||||
if byte & 0x80 == 0 {
|
||||
*cursor = i + 1;
|
||||
return;
|
||||
}
|
||||
}
|
||||
debug_assert!(i < imax);
|
||||
*cursor = i + 1
|
||||
internal_error!("Invalid LEB encoding");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -299,6 +301,15 @@ impl SkipBytes for u8 {
|
|||
impl SkipBytes for String {
|
||||
fn skip_bytes(bytes: &[u8], cursor: &mut usize) {
|
||||
let len = parse_u32_or_panic(bytes, cursor);
|
||||
|
||||
if false {
|
||||
let str_bytes = &bytes[*cursor..(*cursor + len as usize)];
|
||||
println!(
|
||||
"Skipping String {:?}",
|
||||
String::from_utf8(str_bytes.to_vec()).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
*cursor += len as usize;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue