use bumpalo::{self, collections::Vec}; use code_builder::Align; use roc_builtins::bitcode::{self, FloatWidth}; use roc_collections::all::MutMap; use roc_module::low_level::LowLevel; use roc_module::symbol::Symbol; use roc_mono::ir::{CallType, Expr, JoinPointId, Literal, Proc, Stmt}; use roc_mono::layout::{Layout, LayoutIds}; use crate::layout::WasmLayout; use crate::storage::{Storage, StoredValue, StoredValueKind}; use crate::wasm_module::linking::{ DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK, WASM_SYM_UNDEFINED, }; use crate::wasm_module::sections::{ CodeSection, DataMode, DataSection, DataSegment, ExportSection, FunctionSection, GlobalSection, Import, ImportDesc, ImportSection, MemorySection, TypeSection, WasmModule, }; use crate::wasm_module::{ code_builder, BlockType, CodeBuilder, ConstExpr, Export, ExportType, Global, GlobalType, LocalId, Signature, SymInfo, ValueType, }; use crate::{ copy_memory, CopyMemoryConfig, Env, BUILTINS_IMPORT_MODULE_NAME, MEMORY_NAME, PTR_TYPE, STACK_POINTER_NAME, }; /// The memory address where the constants data will be loaded during module instantiation. /// We avoid address zero and anywhere near it. They're valid addresses but maybe bug-prone. /// Follow Emscripten's example by leaving 1kB unused (though 4 bytes would probably do!) const CONST_SEGMENT_BASE_ADDR: u32 = 1024; /// Index of the data segment where we store constants const CONST_SEGMENT_INDEX: usize = 0; pub struct WasmBackend<'a> { env: &'a Env<'a>, // Module-level data pub module: WasmModule<'a>, layout_ids: LayoutIds<'a>, constant_sym_index_map: MutMap<&'a str, usize>, builtin_sym_index_map: MutMap<&'a str, usize>, proc_symbols: Vec<'a, Symbol>, pub linker_symbols: Vec<'a, SymInfo>, // Function-level data code_builder: CodeBuilder<'a>, storage: Storage<'a>, /// how many blocks deep are we (used for jumps) block_depth: u32, joinpoint_label_map: MutMap)>, } impl<'a> WasmBackend<'a> { pub fn new( env: &'a Env<'a>, layout_ids: LayoutIds<'a>, proc_symbols: Vec<'a, Symbol>, mut linker_symbols: Vec<'a, SymInfo>, mut exports: Vec<'a, Export>, ) -> Self { const MEMORY_INIT_SIZE: u32 = 1024 * 1024; let arena = env.arena; let num_procs = proc_symbols.len(); exports.push(Export { name: MEMORY_NAME.to_string(), ty: ExportType::Mem, index: 0, }); let stack_pointer = Global { ty: GlobalType { value_type: ValueType::I32, is_mutable: true, }, init: ConstExpr::I32(MEMORY_INIT_SIZE as i32), }; exports.push(Export { name: STACK_POINTER_NAME.to_string(), ty: ExportType::Global, index: 0, }); linker_symbols.push(SymInfo::Global(WasmObjectSymbol::Defined { flags: WASM_SYM_BINDING_WEAK, index: 0, name: STACK_POINTER_NAME.to_string(), })); let const_segment = DataSegment { mode: DataMode::Active { offset: ConstExpr::I32(CONST_SEGMENT_BASE_ADDR as i32), }, init: Vec::with_capacity_in(64, arena), }; let module = WasmModule { types: TypeSection::new(arena, num_procs), import: ImportSection::new(arena), function: FunctionSection::new(arena, num_procs), table: (), memory: MemorySection::new(MEMORY_INIT_SIZE), global: GlobalSection { entries: bumpalo::vec![in arena; stack_pointer], }, export: ExportSection { entries: exports }, start: (), element: (), code: CodeSection { code_builders: Vec::with_capacity_in(num_procs, arena), }, data: DataSection { segments: bumpalo::vec![in arena; const_segment], }, linking: LinkingSection::new(arena), relocations: RelocationSection::new(arena, "reloc.CODE"), }; WasmBackend { env, // Module-level data module, layout_ids, constant_sym_index_map: MutMap::default(), builtin_sym_index_map: MutMap::default(), proc_symbols, linker_symbols, // Function-level data block_depth: 0, joinpoint_label_map: MutMap::default(), code_builder: CodeBuilder::new(arena), storage: Storage::new(arena), } } /// Reset function-level data fn reset(&mut self) { // Push the completed CodeBuilder into the module and swap it for a new empty one let mut swap_code_builder = CodeBuilder::new(self.env.arena); std::mem::swap(&mut swap_code_builder, &mut self.code_builder); self.module.code.code_builders.push(swap_code_builder); self.storage.clear(); self.joinpoint_label_map.clear(); assert_eq!(self.block_depth, 0); } /********************************************************** PROCEDURE ***********************************************************/ pub fn build_proc(&mut self, proc: Proc<'a>, _sym: Symbol) -> Result<(), String> { // println!("\ngenerating procedure {:?}\n", _sym); self.start_proc(&proc); self.build_stmt(&proc.body, &proc.ret_layout)?; self.finalize_proc()?; self.reset(); // println!("\nfinished generating {:?}\n", _sym); Ok(()) } fn start_proc(&mut self, proc: &Proc<'a>) { let ret_layout = WasmLayout::new(&proc.ret_layout); let ret_type = if ret_layout.is_stack_memory() { self.storage.arg_types.push(PTR_TYPE); self.start_block(BlockType::NoResult); // block to ensure all paths pop stack memory (if any) None } else { let ty = ret_layout.value_type(); self.start_block(BlockType::Value(ty)); // block to ensure all paths pop stack memory (if any) Some(ty) }; for (layout, symbol) in proc.args { let arg_layout = WasmLayout::new(layout); self.storage .allocate(&arg_layout, *symbol, StoredValueKind::Parameter); } self.module.add_function_signature(Signature { param_types: self.storage.arg_types.clone(), ret_type, }); } fn finalize_proc(&mut self) -> Result<(), String> { // end the block from start_proc, to ensure all paths pop stack memory (if any) self.end_block(); // Write local declarations and stack frame push/pop code self.code_builder.build_fn_header( &self.storage.local_types, self.storage.stack_frame_size, self.storage.stack_frame_pointer, ); Ok(()) } /********************************************************** STATEMENTS ***********************************************************/ /// start a loop that leaves a value on the stack fn start_loop_with_return(&mut self, value_type: ValueType) { self.block_depth += 1; self.code_builder.loop_(BlockType::Value(value_type)); } fn start_block(&mut self, block_type: BlockType) { self.block_depth += 1; self.code_builder.block(block_type); } fn end_block(&mut self) { self.block_depth -= 1; self.code_builder.end(); } fn build_stmt(&mut self, stmt: &Stmt<'a>, ret_layout: &Layout<'a>) -> Result<(), String> { match stmt { Stmt::Let(sym, expr, layout, following) => { let wasm_layout = WasmLayout::new(layout); let kind = match following { Stmt::Ret(ret_sym) if *sym == *ret_sym => StoredValueKind::ReturnValue, _ => StoredValueKind::Variable, }; let sym_storage = self.storage.allocate(&wasm_layout, *sym, kind); self.build_expr(sym, expr, layout, &sym_storage)?; // For primitives, we record that this symbol is at the top of the VM stack // (For other values, we wrote to memory and there's nothing on the VM stack) if let WasmLayout::Primitive(value_type, size) = wasm_layout { let vm_state = self.code_builder.set_top_symbol(*sym); self.storage.symbol_storage_map.insert( *sym, StoredValue::VirtualMachineStack { vm_state, value_type, size, }, ); } self.build_stmt(following, ret_layout)?; Ok(()) } Stmt::Ret(sym) => { use crate::storage::StoredValue::*; let storage = self.storage.symbol_storage_map.get(sym).unwrap(); match storage { StackMemory { location, size, alignment_bytes, } => { let (from_ptr, from_offset) = location.local_and_offset(self.storage.stack_frame_pointer); copy_memory( &mut self.code_builder, CopyMemoryConfig { from_ptr, from_offset, to_ptr: LocalId(0), to_offset: 0, size: *size, alignment_bytes: *alignment_bytes, }, ); } _ => { self.storage.load_symbols(&mut self.code_builder, &[*sym]); self.code_builder.br(self.block_depth); // jump to end of function (for stack frame pop) } } Ok(()) } Stmt::Switch { cond_symbol, cond_layout: _, branches, default_branch, ret_layout: _, } => { // NOTE currently implemented as a series of conditional jumps // We may be able to improve this in the future with `Select` // or `BrTable` // Ensure the condition value is not stored only in the VM stack // Otherwise we can't reach it from inside the block let cond_storage = self.storage.get(cond_symbol).to_owned(); self.storage.ensure_value_has_local( &mut self.code_builder, *cond_symbol, cond_storage, ); // create (number_of_branches - 1) new blocks. for _ in 0..branches.len() { self.start_block(BlockType::NoResult) } // then, we jump whenever the value under scrutiny is equal to the value of a branch for (i, (value, _, _)) in branches.iter().enumerate() { // put the cond_symbol on the top of the stack self.storage .load_symbols(&mut self.code_builder, &[*cond_symbol]); self.code_builder.i32_const(*value as i32); // compare the 2 topmost values self.code_builder.i32_eq(); // "break" out of `i` surrounding blocks self.code_builder.br_if(i as u32); } // if we never jumped because a value matched, we're in the default case self.build_stmt(default_branch.1, ret_layout)?; // now put in the actual body of each branch in order // (the first branch would have broken out of 1 block, // hence we must generate its code first) for (_, _, branch) in branches.iter() { self.end_block(); self.build_stmt(branch, ret_layout)?; } Ok(()) } Stmt::Join { id, parameters, body, remainder, } => { // make locals for join pointer parameters let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena); for parameter in parameters.iter() { let wasm_layout = WasmLayout::new(¶meter.layout); let mut param_storage = self.storage.allocate( &wasm_layout, parameter.symbol, StoredValueKind::Variable, ); param_storage = self.storage.ensure_value_has_local( &mut self.code_builder, parameter.symbol, param_storage, ); jp_param_storages.push(param_storage); } self.start_block(BlockType::NoResult); self.joinpoint_label_map .insert(*id, (self.block_depth, jp_param_storages)); self.build_stmt(remainder, ret_layout)?; self.end_block(); // A `return` inside of a `loop` seems to make it so that the `loop` itself // also "returns" (so, leaves on the stack) a value of the return type. let return_wasm_layout = WasmLayout::new(ret_layout); self.start_loop_with_return(return_wasm_layout.value_type()); self.build_stmt(body, ret_layout)?; // ends the loop self.end_block(); Ok(()) } Stmt::Jump(id, arguments) => { let (target, param_storages) = self.joinpoint_label_map[id].clone(); for (arg_symbol, param_storage) in arguments.iter().zip(param_storages.iter()) { let arg_storage = self.storage.get(arg_symbol).clone(); self.storage.clone_value( &mut self.code_builder, param_storage, &arg_storage, *arg_symbol, ); } // jump let levels = self.block_depth - target; self.code_builder.br(levels); Ok(()) } x => Err(format!("statement not yet implemented: {:?}", x)), } } /********************************************************** EXPRESSIONS ***********************************************************/ fn build_expr( &mut self, sym: &Symbol, expr: &Expr<'a>, layout: &Layout<'a>, storage: &StoredValue, ) -> Result<(), String> { let wasm_layout = WasmLayout::new(layout); match expr { Expr::Literal(lit) => self.load_literal(lit, storage, *sym, layout), Expr::Call(roc_mono::ir::Call { call_type, arguments, }) => match call_type { CallType::ByName { name: func_sym, .. } => { let mut wasm_args_tmp: Vec; let (wasm_args, has_return_val) = match wasm_layout { WasmLayout::StackMemory { .. } => { wasm_args_tmp = Vec::with_capacity_in(arguments.len() + 1, self.env.arena); wasm_args_tmp.push(*sym); wasm_args_tmp.extend_from_slice(*arguments); (wasm_args_tmp.as_slice(), false) } _ => (*arguments, true), }; self.storage.load_symbols(&mut self.code_builder, wasm_args); // Index of the called function in the code section. Assumes all functions end up in the binary. // (We may decide to keep all procs even if calls are inlined, in case platform calls them) let func_index = match self.proc_symbols.iter().position(|s| s == func_sym) { Some(i) => i as u32, None => { // TODO: actually useful linking! Push a relocation for it. return Err(format!( "Not yet supported: calling foreign function {:?}", func_sym )); } }; // Index of the function's name in the symbol table // Same as the function index since those are the first symbols we add let symbol_index = func_index; self.code_builder.call( func_index, symbol_index, wasm_args.len(), has_return_val, ); Ok(()) } CallType::LowLevel { op: lowlevel, .. } => { self.build_call_low_level(lowlevel, arguments, layout) } x => Err(format!("the call type, {:?}, is not yet implemented", x)), }, Expr::Struct(fields) => self.create_struct(sym, layout, fields), x => Err(format!("Expression is not yet implemented {:?}", x)), } } fn load_literal( &mut self, lit: &Literal<'a>, storage: &StoredValue, sym: Symbol, layout: &Layout<'a>, ) -> Result<(), String> { let not_supported_error = || Err(format!("Literal value {:?} is not yet implemented", lit)); match storage { StoredValue::VirtualMachineStack { value_type, .. } => { match (lit, value_type) { (Literal::Float(x), ValueType::F64) => self.code_builder.f64_const(*x as f64), (Literal::Float(x), ValueType::F32) => self.code_builder.f32_const(*x as f32), (Literal::Int(x), ValueType::I64) => self.code_builder.i64_const(*x as i64), (Literal::Int(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), (Literal::Bool(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), (Literal::Byte(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), _ => { return not_supported_error(); } }; } StoredValue::StackMemory { location, .. } => match lit { Literal::Str(string) => { let (local_id, offset) = location.local_and_offset(self.storage.stack_frame_pointer); let len = string.len(); if len < 8 { let mut stack_mem_bytes = [0; 8]; stack_mem_bytes[0..len].clone_from_slice(string.as_bytes()); stack_mem_bytes[7] = 0x80 | (len as u8); let str_as_int = i64::from_le_bytes(stack_mem_bytes); self.code_builder.get_local(local_id); self.code_builder.i64_const(str_as_int); self.code_builder.i64_store(Align::Bytes4, offset); } else { let (linker_sym_index, elements_addr) = self.lookup_string_constant(string, sym, layout); self.code_builder.get_local(local_id); self.code_builder.insert_memory_relocation(linker_sym_index); self.code_builder.i32_const(elements_addr as i32); self.code_builder.i32_store(Align::Bytes4, offset); self.code_builder.get_local(local_id); self.code_builder.i32_const(string.len() as i32); self.code_builder.i32_store(Align::Bytes4, offset + 4); }; } _ => { return not_supported_error(); } }, _ => { return not_supported_error(); } }; Ok(()) } /// Look up a string constant in our internal data structures /// Return the data we need for code gen: linker symbol index and memory address fn lookup_string_constant( &mut self, string: &'a str, sym: Symbol, layout: &Layout<'a>, ) -> (u32, u32) { match self.constant_sym_index_map.get(string) { Some(linker_sym_index) => { // We've seen this string before. The linker metadata has a reference // to its offset in the constants data segment. let syminfo = &self.linker_symbols[*linker_sym_index]; match syminfo { SymInfo::Data(DataSymbol::Defined { segment_offset, .. }) => { let elements_addr = *segment_offset + CONST_SEGMENT_BASE_ADDR; (*linker_sym_index as u32, elements_addr) } _ => unreachable!( "Compiler bug: Invalid linker symbol info for string {:?}:\n{:?}", string, syminfo ), } } None => { let const_segment_bytes = &mut self.module.data.segments[CONST_SEGMENT_INDEX].init; // Store the string in the data section, to be loaded on module instantiation // RocStr `elements` field will point to that constant data, not the heap let segment_offset = const_segment_bytes.len() as u32; let elements_addr = segment_offset + CONST_SEGMENT_BASE_ADDR; const_segment_bytes.extend_from_slice(string.as_bytes()); // Generate linker info // Just pick the symbol name from the first usage let name = self .layout_ids .get(sym, layout) .to_symbol_string(sym, &self.env.interns); let linker_symbol = SymInfo::Data(DataSymbol::Defined { flags: 0, name, segment_index: CONST_SEGMENT_INDEX as u32, segment_offset, size: string.len() as u32, }); let linker_sym_index = self.linker_symbols.len(); self.constant_sym_index_map.insert(string, linker_sym_index); self.linker_symbols.push(linker_symbol); (linker_sym_index as u32, elements_addr) } } } fn create_struct( &mut self, sym: &Symbol, layout: &Layout<'a>, fields: &'a [Symbol], ) -> Result<(), String> { // TODO: we just calculated storage and now we're getting it out of a map // Not passing it as an argument because I'm trying to match Backend method signatures let storage = self.storage.get(sym).to_owned(); if let Layout::Struct(field_layouts) = layout { match storage { StoredValue::StackMemory { location, size, .. } => { if size > 0 { let (local_id, struct_offset) = location.local_and_offset(self.storage.stack_frame_pointer); let mut field_offset = struct_offset; for (field, _) in fields.iter().zip(field_layouts.iter()) { field_offset += self.storage.copy_value_to_memory( &mut self.code_builder, local_id, field_offset, *field, ); } } else { return Err(format!("Not supported yet: zero-size struct at {:?}", sym)); } } _ => { return Err(format!( "Cannot create struct {:?} with storage {:?}", sym, storage )); } }; } else { // Struct expression but not Struct layout => single element. Copy it. let field_storage = self.storage.get(&fields[0]).to_owned(); self.storage .clone_value(&mut self.code_builder, &storage, &field_storage, fields[0]); } Ok(()) } fn build_call_low_level( &mut self, lowlevel: &LowLevel, args: &'a [Symbol], return_layout: &Layout<'a>, ) -> Result<(), String> { self.storage.load_symbols(&mut self.code_builder, args); let wasm_layout = WasmLayout::new(return_layout); let ret_type = wasm_layout.value_type(); let panic_ret_type = || panic!("Invalid return type for {:?}: {:?}", lowlevel, ret_type); match lowlevel { LowLevel::NumAdd => match ret_type { ValueType::I32 => self.code_builder.i32_add(), ValueType::I64 => self.code_builder.i64_add(), ValueType::F32 => self.code_builder.f32_add(), ValueType::F64 => self.code_builder.f64_add(), }, LowLevel::NumSub => match ret_type { ValueType::I32 => self.code_builder.i32_sub(), ValueType::I64 => self.code_builder.i64_sub(), ValueType::F32 => self.code_builder.f32_sub(), ValueType::F64 => self.code_builder.f64_sub(), }, LowLevel::NumMul => match ret_type { ValueType::I32 => self.code_builder.i32_mul(), ValueType::I64 => self.code_builder.i64_mul(), ValueType::F32 => self.code_builder.f32_mul(), ValueType::F64 => self.code_builder.f64_mul(), }, LowLevel::NumGt => match self.get_uniform_arg_type(args) { ValueType::I32 => self.code_builder.i32_gt_s(), ValueType::I64 => self.code_builder.i64_gt_s(), ValueType::F32 => self.code_builder.f32_gt(), ValueType::F64 => self.code_builder.f64_gt(), }, LowLevel::Eq => match self.get_uniform_arg_type(args) { ValueType::I32 => self.code_builder.i32_eq(), ValueType::I64 => self.code_builder.i64_eq(), ValueType::F32 => self.code_builder.f32_eq(), ValueType::F64 => self.code_builder.f64_eq(), }, LowLevel::NumNeg => match ret_type { ValueType::I32 => { self.code_builder.i32_const(-1); self.code_builder.i32_mul(); } ValueType::I64 => { self.code_builder.i64_const(-1); self.code_builder.i64_mul(); } ValueType::F32 => self.code_builder.f32_neg(), ValueType::F64 => self.code_builder.f64_neg(), }, LowLevel::NumAtan => { let name = match ret_type { ValueType::F32 => &bitcode::NUM_ATAN[FloatWidth::F32], ValueType::F64 => &bitcode::NUM_ATAN[FloatWidth::F64], _ => panic_ret_type(), }; self.call_imported_builtin(name, &[ret_type], Some(ret_type)); } _ => { return Err(format!("unsupported low-level op {:?}", lowlevel)); } }; Ok(()) } /// Get the ValueType for a set of arguments that are required to have the same type fn get_uniform_arg_type(&self, args: &'a [Symbol]) -> ValueType { let value_type = self.storage.get(&args[0]).value_type(); for arg in args.iter().skip(1) { debug_assert!(self.storage.get(arg).value_type() == value_type); } value_type } fn call_imported_builtin( &mut self, name: &'a str, arg_types: &[ValueType], ret_type: Option, ) { let (fn_index, linker_symbol_index) = match self.builtin_sym_index_map.get(name) { Some(sym_idx) => match &self.linker_symbols[*sym_idx] { SymInfo::Function(WasmObjectSymbol::Imported { index, .. }) => { (*index, *sym_idx as u32) } x => unreachable!("Invalid linker symbol for builtin {}: {:?}", name, x), }, None => { let mut param_types = Vec::with_capacity_in(arg_types.len(), self.env.arena); param_types.extend_from_slice(arg_types); let signature_index = self.module.types.insert(Signature { param_types, ret_type, }); let import_index = self.module.import.entries.len() as u32; let import = Import { module: BUILTINS_IMPORT_MODULE_NAME, name: name.to_string(), description: ImportDesc::Func { signature_index }, }; self.module.import.entries.push(import); let sym_idx = self.linker_symbols.len() as u32; let sym_info = SymInfo::Function(WasmObjectSymbol::Imported { flags: WASM_SYM_UNDEFINED, index: import_index, }); self.linker_symbols.push(sym_info); (import_index, sym_idx) } }; self.code_builder.call( fn_index, linker_symbol_index, arg_types.len(), ret_type.is_some(), ); } }