diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 9577fd9703..392dd176ab 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -21,7 +21,52 @@ const UNUSED_DATA_SECTION_BYTES: u32 = 1024; struct LabelId(u32); #[derive(Debug)] -struct SymbolStorage(LocalId, WasmLayout); +enum SymbolStorage { + ParamPrimitive { + local_id: LocalId, + value_type: ValueType, + size: u32, + }, + ParamPointer { + local_id: LocalId, + }, + VarPrimitive { + local_id: LocalId, + value_type: ValueType, + size: u32, + }, + VarStackMemory { + local_id: LocalId, + size: u32, + offset: u32, + }, + VarHeapMemory { + local_id: LocalId, + }, +} + +impl SymbolStorage { + fn local_id(&self) -> LocalId { + match self { + Self::ParamPrimitive { local_id, .. } => *local_id, + Self::ParamPointer { local_id, .. } => *local_id, + Self::VarPrimitive { local_id, .. } => *local_id, + Self::VarStackMemory { local_id, .. } => *local_id, + Self::VarHeapMemory { local_id, .. } => *local_id, + } + } + + #[allow(dead_code)] + fn value_type(&self) -> ValueType { + match self { + Self::ParamPrimitive { value_type, .. } => *value_type, + Self::VarPrimitive { value_type, .. } => *value_type, + Self::ParamPointer { .. } => ValueType::I32, + Self::VarStackMemory { .. } => ValueType::I32, + Self::VarHeapMemory { .. } => ValueType::I32, + } + } +} enum LocalKind { Parameter, @@ -164,39 +209,61 @@ impl<'a> WasmBackend<'a> { let local_index = (self.arg_types.len() + self.locals.len()) as u32; let local_id = LocalId(local_index); - match kind { + let storage = match kind { LocalKind::Parameter => { // Already stack-allocated by the caller if needed. self.arg_types.push(wasm_layout.value_type()); + match wasm_layout { + WasmLayout::LocalOnly(value_type, size) => SymbolStorage::ParamPrimitive { + local_id, + value_type, + size, + }, + _ => SymbolStorage::ParamPointer { local_id }, + } } LocalKind::Variable => { self.locals.push(Local::new(1, wasm_layout.value_type())); - if let WasmLayout::StackMemory { - size, - alignment_bytes, - } = wasm_layout - { - let align = alignment_bytes as i32; - let mut offset = self.stack_memory; - offset += align - 1; - offset &= -align; - self.stack_memory = offset + (size - alignment_bytes) as i32; + match wasm_layout { + WasmLayout::LocalOnly(value_type, size) => SymbolStorage::VarPrimitive { + local_id, + value_type, + size, + }, - let frame_pointer = self.get_or_create_frame_pointer(); + WasmLayout::HeapMemory => SymbolStorage::VarHeapMemory { local_id }, - // initialise the local with the appropriate address - self.instructions.extend([ - GetLocal(frame_pointer.0), - I32Const(offset), - I32Add, - SetLocal(local_index), - ]); + WasmLayout::StackMemory { + size, + alignment_bytes, + } => { + let align = alignment_bytes as i32; + let mut offset = self.stack_memory; + offset += align - 1; + offset &= -align; + self.stack_memory = offset + (size - alignment_bytes) as i32; + + let frame_pointer = self.get_or_create_frame_pointer(); + + // initialise the local with the appropriate address + self.instructions.extend([ + GetLocal(frame_pointer.0), + I32Const(offset), + I32Add, + SetLocal(local_index), + ]); + + SymbolStorage::VarStackMemory { + local_id, + size, + offset: offset as u32, + } + } } } - } + }; - let storage = SymbolStorage(local_id, wasm_layout); self.symbol_storage_map.insert(symbol, storage); local_id @@ -225,14 +292,14 @@ impl<'a> WasmBackend<'a> { } fn local_id_from_symbol(&self, sym: &Symbol) -> Result { - let SymbolStorage(local_id, _) = self.get_symbol_storage(sym)?; - Ok(*local_id) + let storage = self.get_symbol_storage(sym)?; + Ok(storage.local_id()) } - fn load_from_symbol(&mut self, sym: &Symbol) -> Result<(), String> { - let SymbolStorage(LocalId(local_id), _) = self.get_symbol_storage(sym)?; - let id: u32 = *local_id; - self.instructions.push(GetLocal(id)); + fn load_symbol(&mut self, sym: &Symbol) -> Result<(), String> { + let storage = self.get_symbol_storage(sym)?; + let index: u32 = storage.local_id().0; + self.instructions.push(GetLocal(index)); Ok(()) } @@ -265,7 +332,9 @@ impl<'a> WasmBackend<'a> { if let WasmLayout::StackMemory { .. } = wasm_layout { // Map this symbol to the first argument (pointer into caller's stack) // Saves us from having to copy it later - let storage = SymbolStorage(LocalId(0), wasm_layout); + let storage = SymbolStorage::ParamPointer { + local_id: LocalId(0), + }; self.symbol_storage_map.insert(*let_sym, storage); } self.build_expr(let_sym, expr, layout)?; @@ -287,30 +356,33 @@ impl<'a> WasmBackend<'a> { Stmt::Ret(sym) => { use crate::layout::WasmLayout::*; - let SymbolStorage(local_id, wasm_layout) = - self.symbol_storage_map.get(sym).unwrap(); + let storage = self.symbol_storage_map.get(sym).unwrap(); - match wasm_layout { - LocalOnly(_, _) | HeapMemory => { + match storage { + SymbolStorage::ParamPrimitive { local_id, .. } + | SymbolStorage::VarPrimitive { local_id, .. } + | SymbolStorage::ParamPointer { local_id, .. } + | SymbolStorage::VarHeapMemory { local_id, .. } => { self.instructions.push(GetLocal(local_id.0)); self.instructions.push(Return); } - StackMemory { - size, - alignment_bytes, - } => { - let from = local_id.clone(); - let to = LocalId(0); - let copy_size: u32 = *size; - let copy_alignment_bytes: u32 = *alignment_bytes; - copy_memory( - &mut self.instructions, - from, - to, - copy_size, - copy_alignment_bytes, - )?; + SymbolStorage::VarStackMemory { local_id, size, .. } => { + let ret_wasm_layout = WasmLayout::new(ret_layout); + if let StackMemory { alignment_bytes, .. } = ret_wasm_layout { + let from = local_id.clone(); + let to = LocalId(0); + let copy_size: u32 = *size; + copy_memory( + &mut self.instructions, + from, + to, + copy_size, + alignment_bytes, + )?; + } else { + panic!("Return layout doesn't match"); + } } } @@ -436,7 +508,7 @@ impl<'a> WasmBackend<'a> { }) => match call_type { CallType::ByName { name: func_sym, .. } => { for arg in *arguments { - self.load_from_symbol(arg)?; + self.load_symbol(arg)?; } let function_location = self.proc_symbol_map.get(func_sym).ok_or(format!( "Cannot find function {:?} called from {:?}", @@ -495,7 +567,7 @@ impl<'a> WasmBackend<'a> { return_layout: &Layout<'a>, ) -> Result<(), String> { for arg in args { - self.load_from_symbol(arg)?; + self.load_symbol(arg)?; } let wasm_layout = WasmLayout::new(return_layout); self.build_instructions_lowlevel(lowlevel, wasm_layout.value_type())?; @@ -513,7 +585,7 @@ impl<'a> WasmBackend<'a> { // For those, we'll need to pre-process each argument before the main op, // so simple arrays of instructions won't work. But there are common patterns. let instructions: &[Instruction] = match lowlevel { - // Wasm type might not be enough, may need to sign-extend i8 etc. Maybe in load_from_symbol? + // Wasm type might not be enough, may need to sign-extend i8 etc. Maybe in load_symbol? LowLevel::NumAdd => match return_value_type { ValueType::I32 => &[I32Add], ValueType::I64 => &[I64Add],