diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 4d00a21d79..c68d14529c 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -8,7 +8,8 @@ use roc_module::low_level::{LowLevel, LowLevelWrapperType}; use roc_module::symbol::{Interns, Symbol}; use roc_mono::code_gen_help::{CodeGenHelp, REFCOUNT_MAX}; use roc_mono::ir::{ - CallType, Expr, JoinPointId, ListLiteralElement, Literal, Proc, ProcLayout, Stmt, + BranchInfo, CallType, Expr, JoinPointId, ListLiteralElement, Literal, ModifyRc, Param, Proc, + ProcLayout, Stmt, }; use roc_mono::layout::{Builtin, Layout, LayoutIds, TagIdIntType, UnionLayout}; @@ -176,7 +177,7 @@ impl<'a> WasmBackend<'a> { self.start_proc(proc); - self.build_stmt(&proc.body); + self.stmt(&proc.body); self.finalize_proc(); self.reset(); @@ -240,6 +241,35 @@ impl<'a> WasmBackend<'a> { ***********************************************************/ + fn stmt(&mut self, stmt: &Stmt<'a>) { + match stmt { + Stmt::Let(_, _, _, _) => self.stmt_let(stmt), + + Stmt::Ret(sym) => self.stmt_ret(*sym), + + Stmt::Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout: _, + } => self.stmt_switch(*cond_symbol, cond_layout, branches, default_branch), + + Stmt::Join { + id, + parameters, + body, + remainder, + } => self.stmt_join(*id, parameters, body, remainder), + + Stmt::Jump(id, arguments) => self.stmt_jump(*id, arguments), + + Stmt::Refcounting(modify, following) => self.stmt_refcounting(modify, following), + + Stmt::RuntimeError(msg) => self.stmt_runtime_error(msg), + } + } + fn start_block(&mut self) { // Wasm blocks can have result types, but we don't use them. // You need the right type on the stack when you jump from an inner block to an outer one. @@ -259,7 +289,27 @@ impl<'a> WasmBackend<'a> { self.code_builder.end(); } - fn store_expr_value( + fn stmt_let(&mut self, stmt: &Stmt<'a>) { + let mut current_stmt = stmt; + while let Stmt::Let(sym, expr, layout, following) = current_stmt { + if DEBUG_LOG_SETTINGS.let_stmt_ir { + println!("let {:?} = {}", sym, expr.to_pretty(200)); // ignore `following`! Too confusing otherwise. + } + + let kind = match following { + Stmt::Ret(ret_sym) if *sym == *ret_sym => StoredValueKind::ReturnValue, + _ => StoredValueKind::Variable, + }; + + self.stmt_let_store_expr(*sym, layout, expr, kind); + + current_stmt = *following; + } + + self.stmt(current_stmt); + } + + fn stmt_let_store_expr( &mut self, sym: Symbol, layout: &Layout<'a>, @@ -268,7 +318,7 @@ impl<'a> WasmBackend<'a> { ) { let sym_storage = self.storage.allocate(*layout, sym, kind); - self.build_expr(&sym, expr, layout, &sym_storage); + self.expr(sym, expr, layout, &sym_storage); // If this value is stored in the VM stack, we need code_builder to track it // (since every instruction can change the VM stack) @@ -279,229 +329,207 @@ impl<'a> WasmBackend<'a> { } } - fn build_stmt(&mut self, stmt: &Stmt<'a>) { - match stmt { - Stmt::Let(_, _, _, _) => { - let mut current_stmt = stmt; - while let Stmt::Let(sym, expr, layout, following) = current_stmt { - if DEBUG_LOG_SETTINGS.let_stmt_ir { - println!("let {:?} = {}", sym, expr.to_pretty(200)); // ignore `following`! Too confusing otherwise. 
- } + fn stmt_ret(&mut self, sym: Symbol) { + use crate::storage::StoredValue::*; - let kind = match following { - Stmt::Ret(ret_sym) if *sym == *ret_sym => StoredValueKind::ReturnValue, - _ => StoredValueKind::Variable, - }; + let storage = self.storage.symbol_storage_map.get(&sym).unwrap(); - self.store_expr_value(*sym, layout, expr, kind); - - current_stmt = *following; - } - - self.build_stmt(current_stmt); - } - - Stmt::Ret(sym) => { - use crate::storage::StoredValue::*; - - let storage = self.storage.symbol_storage_map.get(sym).unwrap(); - - match storage { - StackMemory { - location, - size, - alignment_bytes, - .. - } => { - let (from_ptr, from_offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - copy_memory( - &mut self.code_builder, - CopyMemoryConfig { - from_ptr, - from_offset, - to_ptr: LocalId(0), - to_offset: 0, - size: *size, - alignment_bytes: *alignment_bytes, - }, - ); - } - - _ => { - self.storage.load_symbols(&mut self.code_builder, &[*sym]); - - // If we have a return value, store it to the return variable - // This avoids complications with block result types when returning from nested blocks - if let Some(ret_var) = self.storage.return_var { - self.code_builder.set_local(ret_var); - } - } - } - // jump to the "stack frame pop" code at the end of the function - self.code_builder.br(self.block_depth - 1); - } - - Stmt::Switch { - cond_symbol, - cond_layout, - branches, - default_branch, - ret_layout: _, + match storage { + StackMemory { + location, + size, + alignment_bytes, + .. } => { - // NOTE currently implemented as a series of conditional jumps - // We may be able to improve this in the future with `Select` - // or `BrTable` - - // Ensure the condition value is not stored only in the VM stack - // Otherwise we can't reach it from inside the block - let cond_storage = self.storage.get(cond_symbol).to_owned(); - self.storage.ensure_value_has_local( + let (from_ptr, from_offset) = + location.local_and_offset(self.storage.stack_frame_pointer); + copy_memory( &mut self.code_builder, - *cond_symbol, - cond_storage, + CopyMemoryConfig { + from_ptr, + from_offset, + to_ptr: LocalId(0), + to_offset: 0, + size: *size, + alignment_bytes: *alignment_bytes, + }, ); - - // create a block for each branch except the default - for _ in 0..branches.len() { - self.start_block() - } - - let is_bool = matches!(cond_layout, Layout::Builtin(Builtin::Bool)); - let cond_type = WasmLayout::new(cond_layout).arg_types(CallConv::C)[0]; - - // then, we jump whenever the value under scrutiny is equal to the value of a branch - for (i, (value, _, _)) in branches.iter().enumerate() { - // put the cond_symbol on the top of the stack - self.storage - .load_symbols(&mut self.code_builder, &[*cond_symbol]); - - if is_bool { - // We already have a bool, don't need to compare against a const to get one - if *value == 0 { - self.code_builder.i32_eqz(); - } - } else { - match cond_type { - ValueType::I32 => { - self.code_builder.i32_const(*value as i32); - self.code_builder.i32_eq(); - } - ValueType::I64 => { - self.code_builder.i64_const(*value as i64); - self.code_builder.i64_eq(); - } - ValueType::F32 => { - self.code_builder.f32_const(f32::from_bits(*value as u32)); - self.code_builder.f32_eq(); - } - ValueType::F64 => { - self.code_builder.f64_const(f64::from_bits(*value as u64)); - self.code_builder.f64_eq(); - } - } - } - - // "break" out of `i` surrounding blocks - self.code_builder.br_if(i as u32); - } - - // if we never jumped because a value matched, we're in the 
default case - self.build_stmt(default_branch.1); - - // now put in the actual body of each branch in order - // (the first branch would have broken out of 1 block, - // hence we must generate its code first) - for (_, _, branch) in branches.iter() { - self.end_block(); - - self.build_stmt(branch); - } - } - Stmt::Join { - id, - parameters, - body, - remainder, - } => { - // make locals for join pointer parameters - let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena); - for parameter in parameters.iter() { - let mut param_storage = self.storage.allocate( - parameter.layout, - parameter.symbol, - StoredValueKind::Variable, - ); - param_storage = self.storage.ensure_value_has_local( - &mut self.code_builder, - parameter.symbol, - param_storage, - ); - jp_param_storages.push(param_storage); - } - - self.start_block(); - - self.joinpoint_label_map - .insert(*id, (self.block_depth, jp_param_storages)); - - self.build_stmt(remainder); - - self.end_block(); - self.start_loop(); - - self.build_stmt(body); - - // ends the loop - self.end_block(); - } - Stmt::Jump(id, arguments) => { - let (target, param_storages) = self.joinpoint_label_map[id].clone(); - - for (arg_symbol, param_storage) in arguments.iter().zip(param_storages.iter()) { - let arg_storage = self.storage.get(arg_symbol).clone(); - self.storage.clone_value( - &mut self.code_builder, - param_storage, - &arg_storage, - *arg_symbol, - ); - } - - // jump - let levels = self.block_depth - target; - self.code_builder.br(levels); } - Stmt::Refcounting(modify, following) => { - let value = modify.get_symbol(); - let layout = self.storage.symbol_layouts[&value]; + _ => { + self.storage.load_symbols(&mut self.code_builder, &[sym]); - let ident_ids = self - .interns - .all_ident_ids - .get_mut(&self.env.module_id) - .unwrap(); - - let (rc_stmt, new_specializations) = self - .helper_proc_gen - .expand_refcount_stmt(ident_ids, layout, modify, *following); - - if false { - self.register_symbol_debug_names(); - println!("## rc_stmt:\n{}\n{:?}", rc_stmt.to_pretty(200), rc_stmt); + // If we have a return value, store it to the return variable + // This avoids complications with block result types when returning from nested blocks + if let Some(ret_var) = self.storage.return_var { + self.code_builder.set_local(ret_var); } - - // If any new specializations were created, register their symbol data - for spec in new_specializations.into_iter() { - self.register_helper_proc(spec); - } - - self.build_stmt(rc_stmt); } - - Stmt::RuntimeError(msg) => todo!("RuntimeError {:?}", msg), } + // jump to the "stack frame pop" code at the end of the function + self.code_builder.br(self.block_depth - 1); + } + + fn stmt_switch( + &mut self, + cond_symbol: Symbol, + cond_layout: &Layout<'a>, + branches: &'a [(u64, BranchInfo<'a>, Stmt<'a>)], + default_branch: &(BranchInfo<'a>, &'a Stmt<'a>), + ) { + // NOTE currently implemented as a series of conditional jumps + // We may be able to improve this in the future with `Select` + // or `BrTable` + + // Ensure the condition value is not stored only in the VM stack + // Otherwise we can't reach it from inside the block + let cond_storage = self.storage.get(&cond_symbol).to_owned(); + self.storage + .ensure_value_has_local(&mut self.code_builder, cond_symbol, cond_storage); + + // create a block for each branch except the default + for _ in 0..branches.len() { + self.start_block() + } + + let is_bool = matches!(cond_layout, Layout::Builtin(Builtin::Bool)); + let cond_type = 
WasmLayout::new(cond_layout).arg_types(CallConv::C)[0];
+
+        // then, we jump whenever the value under scrutiny is equal to the value of a branch
+        for (i, (value, _, _)) in branches.iter().enumerate() {
+            // put the cond_symbol on the top of the stack
+            self.storage
+                .load_symbols(&mut self.code_builder, &[cond_symbol]);
+
+            if is_bool {
+                // We already have a bool, don't need to compare against a const to get one
+                if *value == 0 {
+                    self.code_builder.i32_eqz();
+                }
+            } else {
+                match cond_type {
+                    ValueType::I32 => {
+                        self.code_builder.i32_const(*value as i32);
+                        self.code_builder.i32_eq();
+                    }
+                    ValueType::I64 => {
+                        self.code_builder.i64_const(*value as i64);
+                        self.code_builder.i64_eq();
+                    }
+                    ValueType::F32 => {
+                        self.code_builder.f32_const(f32::from_bits(*value as u32));
+                        self.code_builder.f32_eq();
+                    }
+                    ValueType::F64 => {
+                        self.code_builder.f64_const(f64::from_bits(*value as u64));
+                        self.code_builder.f64_eq();
+                    }
+                }
+            }
+
+            // "break" out of `i` surrounding blocks
+            self.code_builder.br_if(i as u32);
+        }
+
+        // if we never jumped because a value matched, we're in the default case
+        self.stmt(default_branch.1);
+
+        // now put in the actual body of each branch in order
+        // (the first branch would have broken out of 1 block,
+        // hence we must generate its code first)
+        for (_, _, branch) in branches.iter() {
+            self.end_block();
+
+            self.stmt(branch);
+        }
+    }
+
+    fn stmt_join(
+        &mut self,
+        id: JoinPointId,
+        parameters: &'a [Param<'a>],
+        body: &'a Stmt<'a>,
+        remainder: &'a Stmt<'a>,
+    ) {
+        // make locals for join point parameters
+        let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena);
+        for parameter in parameters.iter() {
+            let mut param_storage = self.storage.allocate(
+                parameter.layout,
+                parameter.symbol,
+                StoredValueKind::Variable,
+            );
+            param_storage = self.storage.ensure_value_has_local(
+                &mut self.code_builder,
+                parameter.symbol,
+                param_storage,
+            );
+            jp_param_storages.push(param_storage);
+        }
+
+        self.start_block();
+
+        self.joinpoint_label_map
+            .insert(id, (self.block_depth, jp_param_storages));
+
+        self.stmt(remainder);
+
+        self.end_block();
+        self.start_loop();
+
+        self.stmt(body);
+
+        // ends the loop
+        self.end_block();
+    }
+
+    fn stmt_jump(&mut self, id: JoinPointId, arguments: &'a [Symbol]) {
+        let (target, param_storages) = self.joinpoint_label_map[&id].clone();
+
+        for (arg_symbol, param_storage) in arguments.iter().zip(param_storages.iter()) {
+            let arg_storage = self.storage.get(arg_symbol).clone();
+            self.storage.clone_value(
+                &mut self.code_builder,
+                param_storage,
+                &arg_storage,
+                *arg_symbol,
+            );
+        }
+
+        // jump (see the control-flow sketch below)
+        let levels = self.block_depth - target;
+        self.code_builder.br(levels);
+    }
+
+    fn stmt_refcounting(&mut self, modify: &ModifyRc, following: &'a Stmt<'a>) {
+        let value = modify.get_symbol();
+        let layout = self.storage.symbol_layouts[&value];
+
+        let ident_ids = self
+            .interns
+            .all_ident_ids
+            .get_mut(&self.env.module_id)
+            .unwrap();
+
+        let (rc_stmt, new_specializations) = self
+            .helper_proc_gen
+            .expand_refcount_stmt(ident_ids, layout, modify, following);
+
+        if false {
+            self.register_symbol_debug_names();
+            println!("## rc_stmt:\n{}\n{:?}", rc_stmt.to_pretty(200), rc_stmt);
+        }
+
+        // If any new specializations were created, register their symbol data
+        for spec in new_specializations.into_iter() {
+            self.register_helper_proc(spec);
+        }
+
+        self.stmt(rc_stmt);
+    }
+
+    fn stmt_runtime_error(&mut self, msg: &'a str) {
+        todo!("RuntimeError {:?}", msg)
+    }
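+
+    // Rough sketch of the Wasm control flow emitted by stmt_join / stmt_jump above,
+    // in .wat-style pseudocode:
+    //
+    //   block            ;; a Jump inside `remainder` branches to the end of this
+    //     <remainder>    ;; block and falls through into the loop...
+    //   end
+    //   loop             ;; ...while a Jump inside `body` branches back to the start
+    //     <body>         ;; of the loop: both use `br (block_depth - target)`, since
+    //   end              ;; the loop sits at the depth recorded for the join point.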
/********************************************************** @@ -510,212 +538,567 @@ impl<'a> WasmBackend<'a> { ***********************************************************/ - fn build_expr( - &mut self, - sym: &Symbol, - expr: &Expr<'a>, - layout: &Layout<'a>, - storage: &StoredValue, - ) { - let wasm_layout = WasmLayout::new(layout); + fn expr(&mut self, sym: Symbol, expr: &Expr<'a>, layout: &Layout<'a>, storage: &StoredValue) { match expr { - Expr::Literal(lit) => self.load_literal(lit, storage, *sym, layout), + Expr::Literal(lit) => self.expr_literal(lit, storage, sym, layout), Expr::Call(roc_mono::ir::Call { call_type, arguments, - }) => match call_type { - CallType::ByName { name: func_sym, .. } => { - // If this function is just a lowlevel wrapper, then inline it - if let LowLevelWrapperType::CanBeReplacedBy(lowlevel) = - LowLevelWrapperType::from_symbol(*func_sym) - { - return self.build_low_level( - lowlevel, - arguments, - *sym, - wasm_layout, - layout, - storage, - ); - } + }) => self.expr_call(call_type, arguments, sym, layout, storage), - let (param_types, ret_type) = self.storage.load_symbols_for_call( - self.env.arena, - &mut self.code_builder, - arguments, - *sym, - &wasm_layout, - CallConv::C, - ); - - for (roc_proc_index, (ir_sym, linker_sym_index)) in - self.proc_symbols.iter().enumerate() - { - let wasm_fn_index = self.fn_index_offset + roc_proc_index as u32; - if ir_sym == func_sym { - let num_wasm_args = param_types.len(); - let has_return_val = ret_type.is_some(); - self.code_builder.call( - wasm_fn_index, - *linker_sym_index, - num_wasm_args, - has_return_val, - ); - return; - } - } - - internal_error!( - "Could not find procedure {:?}\nKnown procedures: {:?}", - func_sym, - self.proc_symbols - ); - } - - CallType::LowLevel { op: lowlevel, .. } => { - self.build_low_level(*lowlevel, arguments, *sym, wasm_layout, layout, storage) - } - - x => todo!("call type {:?}", x), - }, - - Expr::Struct(fields) => self.create_struct(sym, layout, storage, fields), + Expr::Struct(fields) => self.expr_struct(sym, layout, storage, fields), Expr::StructAtIndex { index, field_layouts, structure, - } => { - self.storage.ensure_value_has_local( - &mut self.code_builder, - *sym, - storage.to_owned(), - ); - let (local_id, mut offset) = match self.storage.get(structure) { - StoredValue::StackMemory { location, .. } => { - location.local_and_offset(self.storage.stack_frame_pointer) - } + } => self.expr_struct_at_index(sym, storage, *index, field_layouts, *structure), - StoredValue::Local { - value_type, - local_id, - .. - } => { - debug_assert!(matches!(value_type, ValueType::I32)); - (*local_id, 0) - } + Expr::Array { elems, elem_layout } => self.expr_array(sym, storage, elem_layout, elems), - StoredValue::VirtualMachineStack { .. } => { - internal_error!("ensure_value_has_local didn't work") - } - }; - for field in field_layouts.iter().take(*index as usize) { - offset += field.stack_size(PTR_SIZE); - } - self.storage - .copy_value_from_memory(&mut self.code_builder, *sym, local_id, offset); - } - - Expr::Array { elems, elem_layout } => { - if let StoredValue::StackMemory { location, .. 
} = storage { - let size = elem_layout.stack_size(PTR_SIZE) * (elems.len() as u32); - - // Allocate heap space and store its address in a local variable - let heap_local_id = self.storage.create_anonymous_local(PTR_TYPE); - let heap_alignment = elem_layout.alignment_bytes(PTR_SIZE); - self.allocate_with_refcount(Some(size), heap_alignment, 1); - self.code_builder.set_local(heap_local_id); - - let (stack_local_id, stack_offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - - // elements pointer - self.code_builder.get_local(stack_local_id); - self.code_builder.get_local(heap_local_id); - self.code_builder.i32_store(Align::Bytes4, stack_offset); - - // length of the list - self.code_builder.get_local(stack_local_id); - self.code_builder.i32_const(elems.len() as i32); - self.code_builder.i32_store(Align::Bytes4, stack_offset + 4); - - let mut elem_offset = 0; - - for (i, elem) in elems.iter().enumerate() { - let elem_sym = match elem { - ListLiteralElement::Literal(lit) => { - // This has no Symbol but our storage methods expect one. - // Let's just pretend it was defined in a `Let`. - let debug_name = format!("{:?}_{}", sym, i); - let elem_sym = self.create_symbol(&debug_name); - let expr = Expr::Literal(*lit); - - self.store_expr_value( - elem_sym, - elem_layout, - &expr, - StoredValueKind::Variable, - ); - - elem_sym - } - - ListLiteralElement::Symbol(elem_sym) => *elem_sym, - }; - - elem_offset += self.storage.copy_value_to_memory( - &mut self.code_builder, - heap_local_id, - elem_offset, - elem_sym, - ); - } - } else { - internal_error!("Unexpected storage for Array {:?}: {:?}", sym, storage) - } - } - - Expr::EmptyArray => { - if let StoredValue::StackMemory { location, .. } = storage { - let (local_id, offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - - // This is a minor cheat. - // What we want to write to stack memory is { elements: null, length: 0 } - // But instead of two 32-bit stores, we can do a single 64-bit store. - self.code_builder.get_local(local_id); - self.code_builder.i64_const(0); - self.code_builder.i64_store(Align::Bytes4, offset); - } else { - internal_error!("Unexpected storage for {:?}", sym) - } - } + Expr::EmptyArray => self.expr_empty_array(sym, storage), Expr::Tag { tag_layout: union_layout, tag_id, arguments, .. - } => self.build_tag(union_layout, *tag_id, arguments, *sym, storage), + } => self.expr_tag(union_layout, *tag_id, arguments, sym, storage), Expr::GetTagId { structure, union_layout, - } => self.build_get_tag_id(*structure, union_layout, *sym, storage), + } => self.expr_get_tag_id(*structure, union_layout, sym, storage), Expr::UnionAtIndex { structure, tag_id, union_layout, index, - } => self.build_union_at_index(*structure, *tag_id, union_layout, *index, *sym), + } => self.expr_union_at_index(*structure, *tag_id, union_layout, *index, sym), _ => todo!("Expression `{}`", expr.to_pretty(100)), } } - fn build_tag( + /******************************************************************* + * Literals + *******************************************************************/ + + fn expr_literal( + &mut self, + lit: &Literal<'a>, + storage: &StoredValue, + sym: Symbol, + layout: &Layout<'a>, + ) { + let invalid_error = + || internal_error!("Literal value {:?} has invalid storage {:?}", lit, storage); + + match storage { + StoredValue::VirtualMachineStack { value_type, .. 
} => { + match (lit, value_type) { + (Literal::Float(x), ValueType::F64) => self.code_builder.f64_const(*x as f64), + (Literal::Float(x), ValueType::F32) => self.code_builder.f32_const(*x as f32), + (Literal::Int(x), ValueType::I64) => self.code_builder.i64_const(*x as i64), + (Literal::Int(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), + (Literal::Bool(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), + (Literal::Byte(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), + _ => invalid_error(), + }; + } + + StoredValue::StackMemory { location, .. } => { + let mut write128 = |lower_bits, upper_bits| { + let (local_id, offset) = + location.local_and_offset(self.storage.stack_frame_pointer); + + self.code_builder.get_local(local_id); + self.code_builder.i64_const(lower_bits); + self.code_builder.i64_store(Align::Bytes8, offset); + + self.code_builder.get_local(local_id); + self.code_builder.i64_const(upper_bits); + self.code_builder.i64_store(Align::Bytes8, offset + 8); + }; + + match lit { + Literal::Decimal(decimal) => { + let lower_bits = (decimal.0 & 0xffff_ffff_ffff_ffff) as i64; + let upper_bits = (decimal.0 >> 64) as i64; + write128(lower_bits, upper_bits); + } + Literal::Int(x) => { + let lower_bits = (*x & 0xffff_ffff_ffff_ffff) as i64; + let upper_bits = (*x >> 64) as i64; + write128(lower_bits, upper_bits); + } + Literal::Float(_) => { + // Also not implemented in LLVM backend (nor in Rust!) + todo!("f128 type"); + } + Literal::Str(string) => { + let (local_id, offset) = + location.local_and_offset(self.storage.stack_frame_pointer); + + let len = string.len(); + if len < 8 { + let mut stack_mem_bytes = [0; 8]; + stack_mem_bytes[0..len].clone_from_slice(string.as_bytes()); + stack_mem_bytes[7] = 0x80 | (len as u8); + let str_as_int = i64::from_le_bytes(stack_mem_bytes); + + // Write all 8 bytes at once using an i64 + // Str is normally two i32's, but in this special case, we can get away with fewer instructions + self.code_builder.get_local(local_id); + self.code_builder.i64_const(str_as_int); + self.code_builder.i64_store(Align::Bytes4, offset); + } else { + let (linker_sym_index, elements_addr) = + self.expr_literal_big_str(string, sym, layout); + + self.code_builder.get_local(local_id); + self.code_builder + .i32_const_mem_addr(elements_addr, linker_sym_index); + self.code_builder.i32_store(Align::Bytes4, offset); + + self.code_builder.get_local(local_id); + self.code_builder.i32_const(string.len() as i32); + self.code_builder.i32_store(Align::Bytes4, offset + 4); + }; + } + _ => invalid_error(), + } + } + + _ => invalid_error(), + }; + } + + /// Create a string constant in the module data section + /// Return the data we need for code gen: linker symbol index and memory address + fn expr_literal_big_str( + &mut self, + string: &'a str, + sym: Symbol, + layout: &Layout<'a>, + ) -> (u32, u32) { + // Place the segment at a 4-byte aligned offset + let segment_addr = round_up_to_alignment!(self.next_constant_addr, PTR_SIZE); + let elements_addr = segment_addr + PTR_SIZE; + let length_with_refcount = 4 + string.len(); + self.next_constant_addr = segment_addr + length_with_refcount as u32; + + let mut segment = DataSegment { + mode: DataMode::active_at(segment_addr), + init: Vec::with_capacity_in(length_with_refcount, self.env.arena), + }; + + // Prefix the string bytes with "infinite" refcount + let refcount_max_bytes: [u8; 4] = (REFCOUNT_MAX as i32).to_le_bytes(); + segment.init.extend_from_slice(&refcount_max_bytes); + 
segment.init.extend_from_slice(string.as_bytes()); + + let segment_index = self.module.data.append_segment(segment); + + // Generate linker symbol + let name = self + .layout_ids + .get(sym, layout) + .to_symbol_string(sym, self.interns); + + let linker_symbol = SymInfo::Data(DataSymbol::Defined { + flags: 0, + name: name.clone(), + segment_index, + segment_offset: 4, + size: string.len() as u32, + }); + + // Ensure the linker keeps the segment aligned when relocating it + self.module.linking.segment_info.push(LinkingSegment { + name, + alignment: Align::Bytes4, + flags: 0, + }); + + let linker_sym_index = self.module.linking.symbol_table.len(); + self.module.linking.symbol_table.push(linker_symbol); + + (linker_sym_index as u32, elements_addr) + } + + /******************************************************************* + * Call expressions + *******************************************************************/ + + fn expr_call( + &mut self, + call_type: &CallType<'a>, + arguments: &'a [Symbol], + ret_sym: Symbol, + ret_layout: &Layout<'a>, + ret_storage: &StoredValue, + ) { + match call_type { + CallType::ByName { name: func_sym, .. } => { + self.expr_call_by_name(*func_sym, arguments, ret_sym, ret_layout, ret_storage) + } + CallType::LowLevel { op: lowlevel, .. } => { + self.expr_call_low_level(*lowlevel, arguments, ret_sym, ret_layout, ret_storage) + } + + x => todo!("call type {:?}", x), + } + } + + fn expr_call_by_name( + &mut self, + func_sym: Symbol, + arguments: &'a [Symbol], + ret_sym: Symbol, + ret_layout: &Layout<'a>, + ret_storage: &StoredValue, + ) { + let wasm_layout = WasmLayout::new(ret_layout); + + // If this function is just a lowlevel wrapper, then inline it + if let LowLevelWrapperType::CanBeReplacedBy(lowlevel) = + LowLevelWrapperType::from_symbol(func_sym) + { + return self.expr_call_low_level(lowlevel, arguments, ret_sym, ret_layout, ret_storage); + } + + let (param_types, ret_type) = self.storage.load_symbols_for_call( + self.env.arena, + &mut self.code_builder, + arguments, + ret_sym, + &wasm_layout, + CallConv::C, + ); + + for (roc_proc_index, (ir_sym, linker_sym_index)) in self.proc_symbols.iter().enumerate() { + let wasm_fn_index = self.fn_index_offset + roc_proc_index as u32; + if *ir_sym == func_sym { + let num_wasm_args = param_types.len(); + let has_return_val = ret_type.is_some(); + self.code_builder.call( + wasm_fn_index, + *linker_sym_index, + num_wasm_args, + has_return_val, + ); + return; + } + } + + internal_error!( + "Could not find procedure {:?}\nKnown procedures: {:?}", + func_sym, + self.proc_symbols + ); + } + + fn expr_call_low_level( + &mut self, + lowlevel: LowLevel, + arguments: &'a [Symbol], + return_sym: Symbol, + mono_layout: &Layout<'a>, + storage: &StoredValue, + ) { + use LowLevel::*; + let return_layout = WasmLayout::new(mono_layout); + + match lowlevel { + Eq | NotEq => self.build_eq_or_neq( + lowlevel, + arguments, + return_sym, + return_layout, + mono_layout, + storage, + ), + PtrCast => { + // Don't want Zig calling convention when casting pointers. 
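+                // The cast itself is a no-op at the Wasm level: a pointer is just
+                // an i32 value, so loading the argument is all that's needed.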
+                self.storage.load_symbols(&mut self.code_builder, arguments);
+            }
+            Hash => todo!("Generic hash function generation"),
+
+            // Almost all lowlevels take this branch, except for the special cases above
+            _ => {
+                // Load the arguments using Zig calling convention
+                let (param_types, ret_type) = self.storage.load_symbols_for_call(
+                    self.env.arena,
+                    &mut self.code_builder,
+                    arguments,
+                    return_sym,
+                    &return_layout,
+                    CallConv::Zig,
+                );
+
+                // Generate instructions OR decide which Zig function to call
+                let build_result = dispatch_low_level(
+                    &mut self.code_builder,
+                    &mut self.storage,
+                    lowlevel,
+                    arguments,
+                    &return_layout,
+                    mono_layout,
+                );
+
+                // Handle the result
+                use LowlevelBuildResult::*;
+                match build_result {
+                    Done => {}
+                    BuiltinCall(name) => {
+                        self.expr_call_zig_builtin(name, param_types, ret_type);
+                    }
+                    NotImplemented => {
+                        todo!("Low level operation {:?}", lowlevel)
+                    }
+                }
+            }
+        }
+    }
+
+    /// Generate a call instruction to a Zig builtin function.
+    /// The builtins are preloaded into the module, so we just look up the function index by name and record the call.
+    /// Zig calls use LLVM's "fast" calling convention rather than our usual C ABI.
+    fn expr_call_zig_builtin(
+        &mut self,
+        name: &'a str,
+        param_types: Vec<'a, ValueType>,
+        ret_type: Option<ValueType>,
+    ) {
+        let num_wasm_args = param_types.len();
+        let has_return_val = ret_type.is_some();
+        let fn_index = self.module.names.functions[name.as_bytes()];
+        self.called_preload_fns.push(fn_index);
+        let linker_symbol_index = u32::MAX;
+
+        self.code_builder
+            .call(fn_index, linker_symbol_index, num_wasm_args, has_return_val);
+    }
+
+    /*******************************************************************
+     * Structs
+     *******************************************************************/
+
+    fn expr_struct(
+        &mut self,
+        sym: Symbol,
+        layout: &Layout<'a>,
+        storage: &StoredValue,
+        fields: &'a [Symbol],
+    ) {
+        if matches!(layout, Layout::Struct(_)) {
+            match storage {
+                StoredValue::StackMemory { location, size, .. } => {
+                    if *size > 0 {
+                        let (local_id, struct_offset) =
+                            location.local_and_offset(self.storage.stack_frame_pointer);
+                        let mut field_offset = struct_offset;
+                        for field in fields.iter() {
+                            field_offset += self.storage.copy_value_to_memory(
+                                &mut self.code_builder,
+                                local_id,
+                                field_offset,
+                                *field,
+                            );
+                        }
+                    } else {
+                        // Zero-size struct. No code to emit.
+                        // These values are purely conceptual; they only exist internally in the compiler
+                    }
+                }
+                _ => internal_error!("Cannot create struct {:?} with storage {:?}", sym, storage),
+            };
+        } else {
+            // Struct expression but not Struct layout => single element. Copy it.
+            let field_storage = self.storage.get(&fields[0]).to_owned();
+            self.storage
+                .clone_value(&mut self.code_builder, storage, &field_storage, fields[0]);
+        }
+    }
+
+    fn expr_struct_at_index(
+        &mut self,
+        sym: Symbol,
+        storage: &StoredValue,
+        index: u64,
+        field_layouts: &'a [Layout<'a>],
+        structure: Symbol,
+    ) {
+        self.storage
+            .ensure_value_has_local(&mut self.code_builder, sym, storage.to_owned());
+        let (local_id, mut offset) = match self.storage.get(&structure) {
+            StoredValue::StackMemory { location, .. } => {
+                location.local_and_offset(self.storage.stack_frame_pointer)
+            }
+
+            StoredValue::Local {
+                value_type,
+                local_id,
+                ..
+            } => {
+                debug_assert!(matches!(value_type, ValueType::I32));
+                (*local_id, 0)
+            }
+
+            StoredValue::VirtualMachineStack { .. } => {
+                internal_error!("ensure_value_has_local didn't work")
+            }
+        };
+        for field in field_layouts.iter().take(index as usize) {
+            offset += field.stack_size(PTR_SIZE);
+        }
+        self.storage
+            .copy_value_from_memory(&mut self.code_builder, sym, local_id, offset);
+    }
+
+    /*******************************************************************
+     * Heap allocation
+     *******************************************************************/
+
+    /// Allocate heap space and write an initial refcount
+    /// If the data size is known at compile time, pass it in comptime_data_size.
+    /// If size is only known at runtime, push *data* size to the VM stack first.
+    /// Leaves the *data* address on the VM stack
+    fn allocate_with_refcount(
+        &mut self,
+        comptime_data_size: Option<u32>,
+        alignment_bytes: u32,
+        initial_refcount: u32,
+    ) {
+        // Add extra bytes for the refcount
+        let extra_bytes = alignment_bytes.max(PTR_SIZE);
+
+        if let Some(data_size) = comptime_data_size {
+            // Data size known at compile time and passed as an argument
+            self.code_builder
+                .i32_const((data_size + extra_bytes) as i32);
+        } else {
+            // Data size known only at runtime and is on top of VM stack
+            self.code_builder.i32_const(extra_bytes as i32);
+            self.code_builder.i32_add();
+        }
+
+        // Provide a constant for the alignment argument
+        self.code_builder.i32_const(alignment_bytes as i32);
+
+        // Call the foreign function. (Zig and C calling conventions are the same for this signature)
+        let param_types = bumpalo::vec![in self.env.arena; ValueType::I32, ValueType::I32];
+        let ret_type = Some(ValueType::I32);
+        self.expr_call_zig_builtin("roc_alloc", param_types, ret_type);
+
+        // Save the allocation address to a temporary local variable
+        let local_id = self.storage.create_anonymous_local(ValueType::I32);
+        self.code_builder.tee_local(local_id);
+
+        // Write the initial refcount
+        let refcount_offset = extra_bytes - PTR_SIZE;
+        let encoded_refcount = (initial_refcount as i32) - 1 + i32::MIN;
+        self.code_builder.i32_const(encoded_refcount);
+        self.code_builder.i32_store(Align::Bytes4, refcount_offset);
+
+        // Put the data address on the VM stack
+        self.code_builder.get_local(local_id);
+        self.code_builder.i32_const(extra_bytes as i32);
+        self.code_builder.i32_add();
+    }
+
+    /*******************************************************************
+     * Arrays
+     *******************************************************************/
+
+    fn expr_array(
+        &mut self,
+        sym: Symbol,
+        storage: &StoredValue,
+        elem_layout: &Layout<'a>,
+        elems: &'a [ListLiteralElement<'a>],
+    ) {
+        if let StoredValue::StackMemory { location, ..
} = storage { + let size = elem_layout.stack_size(PTR_SIZE) * (elems.len() as u32); + + // Allocate heap space and store its address in a local variable + let heap_local_id = self.storage.create_anonymous_local(PTR_TYPE); + let heap_alignment = elem_layout.alignment_bytes(PTR_SIZE); + self.allocate_with_refcount(Some(size), heap_alignment, 1); + self.code_builder.set_local(heap_local_id); + + let (stack_local_id, stack_offset) = + location.local_and_offset(self.storage.stack_frame_pointer); + + // elements pointer + self.code_builder.get_local(stack_local_id); + self.code_builder.get_local(heap_local_id); + self.code_builder.i32_store(Align::Bytes4, stack_offset); + + // length of the list + self.code_builder.get_local(stack_local_id); + self.code_builder.i32_const(elems.len() as i32); + self.code_builder.i32_store(Align::Bytes4, stack_offset + 4); + + let mut elem_offset = 0; + + for (i, elem) in elems.iter().enumerate() { + let elem_sym = match elem { + ListLiteralElement::Literal(lit) => { + // This has no Symbol but our storage methods expect one. + // Let's just pretend it was defined in a `Let`. + let debug_name = format!("{:?}_{}", sym, i); + let elem_sym = self.create_symbol(&debug_name); + let expr = Expr::Literal(*lit); + + self.stmt_let_store_expr( + elem_sym, + elem_layout, + &expr, + StoredValueKind::Variable, + ); + + elem_sym + } + + ListLiteralElement::Symbol(elem_sym) => *elem_sym, + }; + + elem_offset += self.storage.copy_value_to_memory( + &mut self.code_builder, + heap_local_id, + elem_offset, + elem_sym, + ); + } + } else { + internal_error!("Unexpected storage for Array {:?}: {:?}", sym, storage) + } + } + + fn expr_empty_array(&mut self, sym: Symbol, storage: &StoredValue) { + if let StoredValue::StackMemory { location, .. } = storage { + let (local_id, offset) = location.local_and_offset(self.storage.stack_frame_pointer); + + // This is a minor cheat. + // What we want to write to stack memory is { elements: null, length: 0 } + // But instead of two 32-bit stores, we can do a single 64-bit store. + self.code_builder.get_local(local_id); + self.code_builder.i64_const(0); + self.code_builder.i64_store(Align::Bytes4, offset); + } else { + internal_error!("Unexpected storage for {:?}", sym) + } + } + + /******************************************************************* + * Tag Unions + *******************************************************************/ + + fn expr_tag( &mut self, union_layout: &UnionLayout<'a>, tag_id: TagIdIntType, @@ -798,7 +1181,7 @@ impl<'a> WasmBackend<'a> { } } - fn build_get_tag_id( + fn expr_get_tag_id( &mut self, structure: Symbol, union_layout: &UnionLayout<'a>, @@ -879,7 +1262,7 @@ impl<'a> WasmBackend<'a> { } } - fn build_union_at_index( + fn expr_union_at_index( &mut self, structure: Symbol, tag_id: TagIdIntType, @@ -951,115 +1334,9 @@ impl<'a> WasmBackend<'a> { .copy_value_from_memory(&mut self.code_builder, symbol, from_ptr, from_offset); } - /// Allocate heap space and write an initial refcount - /// If the data size is known at compile time, pass it in comptime_data_size. - /// If size is only known at runtime, push *data* size to the VM stack first. 
-    /// Leaves the *data* address on the VM stack
-    fn allocate_with_refcount(
-        &mut self,
-        comptime_data_size: Option<u32>,
-        alignment_bytes: u32,
-        initial_refcount: u32,
-    ) {
-        // Add extra bytes for the refcount
-        let extra_bytes = alignment_bytes.max(PTR_SIZE);
-
-        if let Some(data_size) = comptime_data_size {
-            // Data size known at compile time and passed as an argument
-            self.code_builder
-                .i32_const((data_size + extra_bytes) as i32);
-        } else {
-            // Data size known only at runtime and is on top of VM stack
-            self.code_builder.i32_const(extra_bytes as i32);
-            self.code_builder.i32_add();
-        }
-
-        // Provide a constant for the alignment argument
-        self.code_builder.i32_const(alignment_bytes as i32);
-
-        // Call the foreign function. (Zig and C calling conventions are the same for this signature)
-        let param_types = bumpalo::vec![in self.env.arena; ValueType::I32, ValueType::I32];
-        let ret_type = Some(ValueType::I32);
-        self.call_zig_builtin("roc_alloc", param_types, ret_type);
-
-        // Save the allocation address to a temporary local variable
-        let local_id = self.storage.create_anonymous_local(ValueType::I32);
-        self.code_builder.tee_local(local_id);
-
-        // Write the initial refcount
-        let refcount_offset = extra_bytes - PTR_SIZE;
-        let encoded_refcount = (initial_refcount as i32) - 1 + i32::MIN;
-        self.code_builder.i32_const(encoded_refcount);
-        self.code_builder.i32_store(Align::Bytes4, refcount_offset);
-
-        // Put the data address on the VM stack
-        self.code_builder.get_local(local_id);
-        self.code_builder.i32_const(extra_bytes as i32);
-        self.code_builder.i32_add();
-    }
-
-    fn build_low_level(
-        &mut self,
-        lowlevel: LowLevel,
-        arguments: &'a [Symbol],
-        return_sym: Symbol,
-        return_layout: WasmLayout,
-        mono_layout: &Layout<'a>,
-        storage: &StoredValue,
-    ) {
-        use LowLevel::*;
-
-        match lowlevel {
-            Eq | NotEq => self.build_eq_or_neq(
-                lowlevel,
-                arguments,
-                return_sym,
-                return_layout,
-                mono_layout,
-                storage,
-            ),
-            PtrCast => {
-                // Don't want Zig calling convention when casting pointers.
- self.storage.load_symbols(&mut self.code_builder, arguments); - } - Hash => todo!("Generic hash function generation"), - - // Almost all lowlevels take this branch, except for the special cases above - _ => { - // Load the arguments using Zig calling convention - let (param_types, ret_type) = self.storage.load_symbols_for_call( - self.env.arena, - &mut self.code_builder, - arguments, - return_sym, - &return_layout, - CallConv::Zig, - ); - - // Generate instructions OR decide which Zig function to call - let build_result = dispatch_low_level( - &mut self.code_builder, - &mut self.storage, - lowlevel, - arguments, - &return_layout, - mono_layout, - ); - - // Handle the result - use LowlevelBuildResult::*; - match build_result { - Done => {} - BuiltinCall(name) => { - self.call_zig_builtin(name, param_types, ret_type); - } - NotImplemented => { - todo!("Low level operation {:?}", lowlevel) - } - } - } - } - } + /******************************************************************* + * Equality + *******************************************************************/ fn build_eq_or_neq( &mut self, @@ -1091,7 +1368,7 @@ impl<'a> WasmBackend<'a> { &return_layout, CallConv::Zig, ); - self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); + self.expr_call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); if matches!(lowlevel, LowLevel::NotEq) { self.code_builder.i32_eqz(); } @@ -1274,210 +1551,11 @@ impl<'a> WasmBackend<'a> { // Generate Wasm code for the IR call expression let bool_layout = Layout::Builtin(Builtin::Bool); - self.build_expr( - &return_sym, + self.expr( + return_sym, self.env.arena.alloc(specialized_call_expr), &bool_layout, storage, ); } - - fn load_literal( - &mut self, - lit: &Literal<'a>, - storage: &StoredValue, - sym: Symbol, - layout: &Layout<'a>, - ) { - let invalid_error = - || internal_error!("Literal value {:?} has invalid storage {:?}", lit, storage); - - match storage { - StoredValue::VirtualMachineStack { value_type, .. } => { - match (lit, value_type) { - (Literal::Float(x), ValueType::F64) => self.code_builder.f64_const(*x as f64), - (Literal::Float(x), ValueType::F32) => self.code_builder.f32_const(*x as f32), - (Literal::Int(x), ValueType::I64) => self.code_builder.i64_const(*x as i64), - (Literal::Int(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), - (Literal::Bool(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), - (Literal::Byte(x), ValueType::I32) => self.code_builder.i32_const(*x as i32), - _ => invalid_error(), - }; - } - - StoredValue::StackMemory { location, .. } => { - let mut write128 = |lower_bits, upper_bits| { - let (local_id, offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - - self.code_builder.get_local(local_id); - self.code_builder.i64_const(lower_bits); - self.code_builder.i64_store(Align::Bytes8, offset); - - self.code_builder.get_local(local_id); - self.code_builder.i64_const(upper_bits); - self.code_builder.i64_store(Align::Bytes8, offset + 8); - }; - - match lit { - Literal::Decimal(decimal) => { - let lower_bits = (decimal.0 & 0xffff_ffff_ffff_ffff) as i64; - let upper_bits = (decimal.0 >> 64) as i64; - write128(lower_bits, upper_bits); - } - Literal::Int(x) => { - let lower_bits = (*x & 0xffff_ffff_ffff_ffff) as i64; - let upper_bits = (*x >> 64) as i64; - write128(lower_bits, upper_bits); - } - Literal::Float(_) => { - // Also not implemented in LLVM backend (nor in Rust!) 
- todo!("f128 type"); - } - Literal::Str(string) => { - let (local_id, offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - - let len = string.len(); - if len < 8 { - let mut stack_mem_bytes = [0; 8]; - stack_mem_bytes[0..len].clone_from_slice(string.as_bytes()); - stack_mem_bytes[7] = 0x80 | (len as u8); - let str_as_int = i64::from_le_bytes(stack_mem_bytes); - - // Write all 8 bytes at once using an i64 - // Str is normally two i32's, but in this special case, we can get away with fewer instructions - self.code_builder.get_local(local_id); - self.code_builder.i64_const(str_as_int); - self.code_builder.i64_store(Align::Bytes4, offset); - } else { - let (linker_sym_index, elements_addr) = - self.create_string_constant(string, sym, layout); - - self.code_builder.get_local(local_id); - self.code_builder - .i32_const_mem_addr(elements_addr, linker_sym_index); - self.code_builder.i32_store(Align::Bytes4, offset); - - self.code_builder.get_local(local_id); - self.code_builder.i32_const(string.len() as i32); - self.code_builder.i32_store(Align::Bytes4, offset + 4); - }; - } - _ => invalid_error(), - } - } - - _ => invalid_error(), - }; - } - - /// Create a string constant in the module data section - /// Return the data we need for code gen: linker symbol index and memory address - fn create_string_constant( - &mut self, - string: &'a str, - sym: Symbol, - layout: &Layout<'a>, - ) -> (u32, u32) { - // Place the segment at a 4-byte aligned offset - let segment_addr = round_up_to_alignment!(self.next_constant_addr, PTR_SIZE); - let elements_addr = segment_addr + PTR_SIZE; - let length_with_refcount = 4 + string.len(); - self.next_constant_addr = segment_addr + length_with_refcount as u32; - - let mut segment = DataSegment { - mode: DataMode::active_at(segment_addr), - init: Vec::with_capacity_in(length_with_refcount, self.env.arena), - }; - - // Prefix the string bytes with "infinite" refcount - let refcount_max_bytes: [u8; 4] = (REFCOUNT_MAX as i32).to_le_bytes(); - segment.init.extend_from_slice(&refcount_max_bytes); - segment.init.extend_from_slice(string.as_bytes()); - - let segment_index = self.module.data.append_segment(segment); - - // Generate linker symbol - let name = self - .layout_ids - .get(sym, layout) - .to_symbol_string(sym, self.interns); - - let linker_symbol = SymInfo::Data(DataSymbol::Defined { - flags: 0, - name: name.clone(), - segment_index, - segment_offset: 4, - size: string.len() as u32, - }); - - // Ensure the linker keeps the segment aligned when relocating it - self.module.linking.segment_info.push(LinkingSegment { - name, - alignment: Align::Bytes4, - flags: 0, - }); - - let linker_sym_index = self.module.linking.symbol_table.len(); - self.module.linking.symbol_table.push(linker_symbol); - - (linker_sym_index as u32, elements_addr) - } - - fn create_struct( - &mut self, - sym: &Symbol, - layout: &Layout<'a>, - storage: &StoredValue, - fields: &'a [Symbol], - ) { - if matches!(layout, Layout::Struct(_)) { - match storage { - StoredValue::StackMemory { location, size, .. } => { - if *size > 0 { - let (local_id, struct_offset) = - location.local_and_offset(self.storage.stack_frame_pointer); - let mut field_offset = struct_offset; - for field in fields.iter() { - field_offset += self.storage.copy_value_to_memory( - &mut self.code_builder, - local_id, - field_offset, - *field, - ); - } - } else { - // Zero-size struct. No code to emit. 
-                        // These values are purely conceptual, they only exist internally in the compiler
-                    }
-                }
-                _ => internal_error!("Cannot create struct {:?} with storage {:?}", sym, storage),
-            };
-        } else {
-            // Struct expression but not Struct layout => single element. Copy it.
-            let field_storage = self.storage.get(&fields[0]).to_owned();
-            self.storage
-                .clone_value(&mut self.code_builder, storage, &field_storage, fields[0]);
-        }
-    }
-
-    /// Generate a call instruction to a Zig builtin function.
-    /// And if we haven't seen it before, add an Import and linker data for it.
-    /// Zig calls use LLVM's "fast" calling convention rather than our usual C ABI.
-    fn call_zig_builtin(
-        &mut self,
-        name: &'a str,
-        param_types: Vec<'a, ValueType>,
-        ret_type: Option<ValueType>,
-    ) {
-        let num_wasm_args = param_types.len();
-        let has_return_val = ret_type.is_some();
-        let fn_index = self.module.names.functions[name.as_bytes()];
-        self.called_preload_fns.push(fn_index);
-        let linker_symbol_index = u32::MAX;
-
-        self.code_builder
-            .call(fn_index, linker_symbol_index, num_wasm_args, has_return_val);
-    }
 }