diff --git a/compiler/gen_wasm/src/lib.rs b/compiler/gen_wasm/src/lib.rs index 8fef9ab562..63872f83e2 100644 --- a/compiler/gen_wasm/src/lib.rs +++ b/compiler/gen_wasm/src/lib.rs @@ -91,7 +91,7 @@ pub fn build_app_module<'a>( // Adjust Wasm function indices to account for functions from the object file let fn_index_offset: u32 = - host_module.import.fn_signatures.len() as u32 + host_module.code.preloaded_count; + host_module.import.function_signature_count() as u32 + host_module.code.preloaded_count; // Collect the symbols & names for the procedures, // and filter out procs we're going to inline diff --git a/compiler/gen_wasm/src/wasm32_result.rs b/compiler/gen_wasm/src/wasm32_result.rs index 341cd26619..b5515ec80b 100644 --- a/compiler/gen_wasm/src/wasm32_result.rs +++ b/compiler/gen_wasm/src/wasm32_result.rs @@ -88,7 +88,7 @@ fn insert_wrapper_metadata<'a>( module: &mut WasmModule<'a>, wrapper_name: &'static str, ) { - let index = (module.import.fn_signatures.len() as u32) + let index = (module.import.function_signature_count() as u32) + module.code.preloaded_count + module.code.code_builders.len() as u32; diff --git a/compiler/gen_wasm/src/wasm_module/mod.rs b/compiler/gen_wasm/src/wasm_module/mod.rs index 832f6bb81e..389b40f002 100644 --- a/compiler/gen_wasm/src/wasm_module/mod.rs +++ b/compiler/gen_wasm/src/wasm_module/mod.rs @@ -111,11 +111,12 @@ impl<'a> WasmModule<'a> { let element = ElementSection::parse(arena, bytes, &mut cursor)?; let indirect_callees = element.indirect_callees(arena); + let imported_fn_signatures = import.function_signatures(arena); let code = CodeSection::parse( arena, bytes, &mut cursor, - &import.fn_signatures, + &imported_fn_signatures, &function.signatures, &indirect_callees, )?; @@ -194,7 +195,7 @@ impl<'a> WasmModule<'a> { self.code.remove_dead_preloads( arena, - self.import.fn_signatures.len(), + self.import.function_signature_count(), &function_indices, called_preload_fns, ) diff --git a/compiler/gen_wasm/src/wasm_module/parse.rs b/compiler/gen_wasm/src/wasm_module/parse.rs index eff8c08140..968c65c6c7 100644 --- a/compiler/gen_wasm/src/wasm_module/parse.rs +++ b/compiler/gen_wasm/src/wasm_module/parse.rs @@ -47,6 +47,14 @@ impl Parse<()> for u32 { } } +impl Parse<()> for u8 { + fn parse(_ctx: (), bytes: &[u8], cursor: &mut usize) -> Result { + let byte = bytes[*cursor]; + *cursor += 1; + Ok(byte) + } +} + /// Decode a signed 32-bit integer from the provided buffer in LEB-128 format /// Return the integer itself and the offset after it ends fn decode_i32(bytes: &[u8]) -> Result<(i32, usize), ()> { diff --git a/compiler/gen_wasm/src/wasm_module/sections.rs b/compiler/gen_wasm/src/wasm_module/sections.rs index 08de8e3286..29262080d9 100644 --- a/compiler/gen_wasm/src/wasm_module/sections.rs +++ b/compiler/gen_wasm/src/wasm_module/sections.rs @@ -315,10 +315,58 @@ pub enum ImportDesc { Global { ty: GlobalType }, } +impl Parse<()> for ImportDesc { + fn parse(_: (), bytes: &[u8], cursor: &mut usize) -> Result { + let type_id = ImportTypeId::from(bytes[*cursor]); + *cursor += 1; + match type_id { + ImportTypeId::Func => { + let signature_index = u32::parse((), bytes, cursor)?; + Ok(ImportDesc::Func { signature_index }) + } + ImportTypeId::Table => { + let ty = TableType::parse((), bytes, cursor)?; + Ok(ImportDesc::Table { ty }) + } + ImportTypeId::Mem => { + let limits = Limits::parse((), bytes, cursor)?; + Ok(ImportDesc::Mem { limits }) + } + ImportTypeId::Global => { + let ty = GlobalType::parse((), bytes, cursor)?; + Ok(ImportDesc::Global { ty }) + } + } + } +} + +impl Serialize for ImportDesc { + fn serialize(&self, buffer: &mut T) { + match self { + Self::Func { signature_index } => { + buffer.append_u8(ImportTypeId::Func as u8); + signature_index.serialize(buffer); + } + Self::Table { ty } => { + buffer.append_u8(ImportTypeId::Table as u8); + ty.serialize(buffer); + } + Self::Mem { limits } => { + buffer.append_u8(ImportTypeId::Mem as u8); + limits.serialize(buffer); + } + Self::Global { ty } => { + buffer.append_u8(ImportTypeId::Global as u8); + ty.serialize(buffer); + } + } + } +} + #[derive(Debug)] -pub struct Import { - pub module: &'static str, - pub name: String, +pub struct Import<'a> { + pub module: &'a str, + pub name: &'a str, pub description: ImportDesc, } @@ -346,97 +394,83 @@ impl From for ImportTypeId { } } -impl Serialize for Import { +impl<'a> Import<'a> { + fn size(&self) -> usize { + self.module.len() + + self.name.len() + + match self.description { + ImportDesc::Func { .. } => MAX_SIZE_ENCODED_U32, + ImportDesc::Table { .. } => 4, + ImportDesc::Mem { .. } => 3, + ImportDesc::Global { .. } => 2, + } + } +} + +impl<'a> Serialize for Import<'a> { fn serialize(&self, buffer: &mut T) { self.module.serialize(buffer); self.name.serialize(buffer); - match &self.description { - ImportDesc::Func { signature_index } => { - buffer.append_u8(ImportTypeId::Func as u8); - buffer.encode_u32(*signature_index); - } - ImportDesc::Table { ty } => { - buffer.append_u8(ImportTypeId::Table as u8); - ty.serialize(buffer); - } - ImportDesc::Mem { limits } => { - buffer.append_u8(ImportTypeId::Mem as u8); - limits.serialize(buffer); - } - ImportDesc::Global { ty } => { - buffer.append_u8(ImportTypeId::Global as u8); - ty.serialize(buffer); - } - } + self.description.serialize(buffer); } } #[derive(Debug)] pub struct ImportSection<'a> { - pub count: u32, - pub fn_signatures: Vec<'a, u32>, - pub bytes: Vec<'a, u8>, + pub imports: Vec<'a, Import<'a>>, } impl<'a> ImportSection<'a> { const ID: SectionId = SectionId::Import; pub fn size(&self) -> usize { - self.bytes.len() + self.imports.iter().map(|imp| imp.size()).sum() + } + + pub fn function_signatures(&self, arena: &'a Bump) -> Vec<'a, u32> { + let sig_iter = self.imports.iter().filter_map(|imp| match imp.description { + ImportDesc::Func { signature_index } => Some(signature_index), + _ => None, + }); + Vec::from_iter_in(sig_iter, arena) + } + + pub fn function_signature_count(&self) -> usize { + self.imports + .iter() + .filter(|imp| matches!(imp.description, ImportDesc::Func { .. })) + .count() } } impl<'a> Parse<&'a Bump> for ImportSection<'a> { fn parse(arena: &'a Bump, module_bytes: &[u8], cursor: &mut usize) -> Result { - let (mut count, range) = parse_section(Self::ID, module_bytes, cursor)?; - let mut bytes = Vec::with_capacity_in(range.len() * 2, arena); - let mut fn_signatures = Vec::with_capacity_in(range.len() / 8, arena); + let (count, range) = parse_section(Self::ID, module_bytes, cursor)?; + let mut imports = Vec::with_capacity_in(count as usize, arena); let end = range.end; while *cursor < end { - let import_start = *cursor; - String::skip_bytes(module_bytes, cursor)?; // import namespace - String::skip_bytes(module_bytes, cursor)?; // import name + let module = <&'a str>::parse(arena, module_bytes, cursor)?; + let name = <&'a str>::parse(arena, module_bytes, cursor)?; + let description = ImportDesc::parse((), module_bytes, cursor)?; - let type_id = ImportTypeId::from(module_bytes[*cursor]); - *cursor += 1; - - match type_id { - ImportTypeId::Func => { - let sig = u32::parse((), module_bytes, cursor)?; - fn_signatures.push(sig); - bytes.extend_from_slice(&module_bytes[import_start..*cursor]); - } - ImportTypeId::Table => { - TableType::skip_bytes(module_bytes, cursor)?; - count -= 1; - } - ImportTypeId::Mem => { - Limits::skip_bytes(module_bytes, cursor)?; - count -= 1; - } - ImportTypeId::Global => { - GlobalType::skip_bytes(module_bytes, cursor)?; - count -= 1; - } - } + imports.push(Import { + module, + name, + description, + }); } - Ok(ImportSection { - count, - fn_signatures, - bytes, - }) + Ok(ImportSection { imports }) } } impl<'a> Serialize for ImportSection<'a> { fn serialize(&self, buffer: &mut B) { - if !self.bytes.is_empty() { + if !self.imports.is_empty() { let header_indices = write_section_header(buffer, Self::ID); - buffer.encode_u32(self.count); - buffer.append_slice(&self.bytes); + self.imports.serialize(buffer); update_section_size(buffer, header_indices); } } @@ -509,6 +543,20 @@ pub enum RefType { Extern = 0x6f, } +impl Parse<()> for RefType { + fn parse(_: (), bytes: &[u8], cursor: &mut usize) -> Result { + let byte = bytes[*cursor]; + *cursor += 1; + match byte { + 0x70 => Ok(Self::Func), + 0x6f => Ok(Self::Extern), + _ => Err(ParseError { + offset: *cursor - 1, + message: format!("Invalid RefType 0x{:2x}", byte), + }), + } + } +} #[derive(Debug)] pub struct TableType { pub ref_type: RefType, @@ -530,6 +578,14 @@ impl SkipBytes for TableType { } } +impl Parse<()> for TableType { + fn parse(_: (), bytes: &[u8], cursor: &mut usize) -> Result { + let ref_type = RefType::parse((), bytes, cursor)?; + let limits = Limits::parse((), bytes, cursor)?; + Ok(TableType { ref_type, limits }) + } +} + #[derive(Debug)] pub struct TableSection { pub function_table: TableType, @@ -719,6 +775,19 @@ impl SkipBytes for GlobalType { } } +impl Parse<()> for GlobalType { + fn parse(_: (), bytes: &[u8], cursor: &mut usize) -> Result { + let value_type = ValueType::from(bytes[*cursor]); + *cursor += 1; + let is_mutable = bytes[*cursor] != 0; + *cursor += 1; + Ok(GlobalType { + value_type, + is_mutable, + }) + } +} + /// Constant expression for initialising globals or data segments /// Note: This is restricted for simplicity, but the spec allows arbitrary constant expressions #[derive(Debug)]