diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 5006814a59..529715115d 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -2,7 +2,7 @@ use bumpalo::{self, collections::Vec}; use std::fmt::Write; use code_builder::Align; -use roc_builtins::bitcode::IntWidth; +use roc_builtins::bitcode::{FloatWidth, IntWidth}; use roc_collections::all::MutMap; use roc_module::ident::Ident; use roc_module::low_level::{LowLevel, LowLevelWrapperType}; @@ -223,7 +223,7 @@ impl<'a> WasmBackend<'a> { let ret_layout = WasmLayout::new(&proc.ret_layout); let ret_type = match ret_layout.return_method() { - ReturnMethod::Primitive(ty) => Some(ty), + ReturnMethod::Primitive(ty, _) => Some(ty), ReturnMethod::NoReturnValue => None, ReturnMethod::WriteToPointerArg => { self.storage.arg_types.push(PTR_TYPE); @@ -281,95 +281,133 @@ impl<'a> WasmBackend<'a> { self.module.names.append_function(wasm_fn_index, name_bytes); } - /// Procs that are called from higher-order Zig builtins currently need a wrapper, - /// because the Zig compiler has bugs in its C calling convention for Wasm. - /// Whenever Zig fixes this, we should be able to remove the wrapper entirely. - pub fn build_zigcc_wrapper(&mut self, wrapper_idx: usize, inner_idx: usize) { - let ProcLookupData { - layout: proc_layout, - .. - } = self.proc_lookup[wrapper_idx]; - let ProcLayout { - arguments: arg_layouts, - result, - } = proc_layout; + /// Build a wrapper around a Roc procedure so that it can be called from our higher-order Zig builtins. + /// + /// The generic Zig code passes *pointers* to all of the argument values (e.g. on the heap in a List). + /// Numbers up to 64 bits are passed by value, so we need to load them from the provided pointer. + /// Everything else is passed by reference, so we can just pass the pointer through. + /// + /// NOTE: If the builtins expected the return pointer first and closure data last, we could eliminate the wrapper + /// when all args are pass-by-reference and non-zero size. But currently we need it to swap those around. + pub fn build_higher_order_wrapper(&mut self, wrapper_lookup_idx: usize, inner_lookup_idx: usize) { + use Align::*; + use ValueType::*; - let ret_sym = self.create_symbol("##ret"); - let ret_layout = WasmLayout::new(&result); - let is_stack_return = match ret_layout.return_method() { - ReturnMethod::Primitive(_) => false, - ReturnMethod::NoReturnValue => false, + let ProcLookupData { + name: wrapper_name, + layout: wrapper_proc_layout, + .. + } = self.proc_lookup[wrapper_lookup_idx]; + let wrapper_arg_layouts = wrapper_proc_layout.arguments; + + // Our convention is that the last arg of the wrapper is the heap return pointer + let heap_return_ptr_id = LocalId(wrapper_arg_layouts.len() as u32 - 1); + let inner_ret_layout = match wrapper_arg_layouts.last() { + Some(Layout::Boxed(inner)) => WasmLayout::new(inner), + x => internal_error!("Higher-order wrapper: invalid return layout {:?}", x), + }; + + let mut n_inner_wasm_args = 0; + let ret_type_and_size = match inner_ret_layout.return_method() { + ReturnMethod::Primitive(ty, size) => Some((ty, size)), + ReturnMethod::NoReturnValue => None, ReturnMethod::WriteToPointerArg => { - // Return variable must be at index 0 - self.storage - .allocate(result, ret_sym, StoredValueKind::Parameter); - true + n_inner_wasm_args += 1; + self.code_builder.get_local(heap_return_ptr_id); + None } }; - let mut arg_symbols = Vec::with_capacity_in(arg_layouts.len(), self.env.arena); - let mut frame_writes = Vec::with_capacity_in(arg_layouts.len() * 2, self.env.arena); + // Load all the arguments for the inner function + for (i, wrapper_arg) in wrapper_arg_layouts.iter().enumerate() { + let is_closure_data = i == 0; // Skip closure data (first for wrapper, last for inner) + let is_return_pointer = i == wrapper_arg_layouts.len() - 1; // Skip return pointer (may not be an arg for inner. And if it is, swaps from end to start) + if is_closure_data || is_return_pointer || wrapper_arg.stack_size(TARGET_INFO) == 0 { + continue; + } + n_inner_wasm_args += 1; - for (i, arg) in arg_layouts.iter().enumerate() { - let arg_name = format!("arg{}", i); - let symbol = self.create_symbol(&arg_name); - arg_symbols.push(symbol); - self.storage - .allocate_zigcc_arg(arg, symbol, &mut frame_writes); + // Load wrapper argument. They're all pointers. + self.code_builder.get_local(LocalId(i as u32)); + + // Dereference any primitive-valued arguments + match wrapper_arg { + Layout::Boxed(inner_arg) => match inner_arg { + Layout::Builtin(Builtin::Int(IntWidth::U8 | IntWidth::I8)) => { + self.code_builder.i32_load8_u(Bytes1, 0); + } + Layout::Builtin(Builtin::Int(IntWidth::U16 | IntWidth::I16)) => { + self.code_builder.i32_load16_u(Bytes2, 0); + } + Layout::Builtin(Builtin::Int(IntWidth::U32 | IntWidth::I32)) => { + self.code_builder.i32_load(Bytes4, 0); + } + Layout::Builtin(Builtin::Int(IntWidth::U64 | IntWidth::I64)) => { + self.code_builder.i64_load(Bytes8, 0); + } + Layout::Builtin(Builtin::Float(FloatWidth::F32)) => { + self.code_builder.f32_load(Bytes4, 0); + } + Layout::Builtin(Builtin::Float(FloatWidth::F64)) => { + self.code_builder.f64_load(Bytes8, 0); + } + Layout::Builtin(Builtin::Bool) => { + self.code_builder.i32_load8_u(Bytes1, 0); + } + _ => { + // Any other layout is a pointer, which we've already loaded. Nothing to do! + } + }, + x => internal_error!("Higher-order wrapper: expected a Box layout, got {:?}", x), + } } - if !is_stack_return { - // Local variables must come *after* the arguments - self.storage - .allocate(result, ret_sym, StoredValueKind::Variable); + // If the inner function has closure data, it's the last arg of the inner fn + let closure_data_layout = wrapper_arg_layouts[0]; + if closure_data_layout.stack_size(TARGET_INFO) > 0 { + self.code_builder.get_local(LocalId(0)); } - // Write structs to the stack frame - if !frame_writes.is_empty() { - let frame_ptr = self.storage.create_anonymous_local(PTR_TYPE); - self.storage.stack_frame_pointer = Some(frame_ptr); + // If the inner function returns a primitive, load the address to store it at + if ret_type_and_size.is_some() { + self.code_builder.get_local(heap_return_ptr_id); + } - for (zig_arg, value_type, offset) in frame_writes.into_iter() { - self.code_builder.get_local(frame_ptr); - self.code_builder.get_local(zig_arg); - if value_type == ValueType::I32 { - let align = Align::from_stack_offset(Align::Bytes4, offset); - self.code_builder.i32_store(align, offset); - } else { - let align = Align::from_stack_offset(Align::Bytes8, offset); - self.code_builder.i64_store(align, offset); + // Call the wrapped inner function + let lookup = &self.proc_lookup[inner_lookup_idx]; + let inner_wasm_fn_index = self.fn_index_offset + inner_lookup_idx as u32; + let has_return_val = ret_type_and_size.is_some(); + self.code_builder.call( + inner_wasm_fn_index, + lookup.linker_index, + n_inner_wasm_args, + has_return_val, + ); + + // If the inner function returns a primitive, store it to the address we loaded earlier + if let Some((ty, size)) = ret_type_and_size { + match (ty, size) { + (I64, 8) => self.code_builder.i64_store(Bytes8, 0), + (I32, 4) => self.code_builder.i32_store(Bytes4, 0), + (I32, 2) => self.code_builder.i32_store16(Bytes2, 0), + (I32, 1) => self.code_builder.i32_store8(Bytes1, 0), + (F32, 4) => self.code_builder.f32_store(Bytes4, 0), + (F64, 8) => self.code_builder.f64_store(Bytes8, 0), + _ => { + internal_error!("Cannot store {:?} with alignment of {:?}", ty, size); } } } - // Call the wrapped inner function - let ProcLookupData { linker_index, .. } = self.proc_lookup[inner_idx]; - let (param_types, ret_type) = self.storage.load_symbols_for_call( - self.env.arena, - &mut self.code_builder, - arg_symbols.into_bump_slice(), - ret_sym, - &ret_layout, - CallConv::C, - ); - let wasm_fn_index = self.fn_index_offset + inner_idx as u32; - let num_wasm_args = param_types.len(); - let has_return_val = ret_type.is_some(); - self.code_builder - .call(wasm_fn_index, linker_index, num_wasm_args, has_return_val); - - // Setup & teardown the stack frame - self.code_builder.build_fn_header_and_footer( - &self.storage.local_types, - self.storage.stack_frame_size, - self.storage.stack_frame_pointer, - ); + // Write empty function header (local variables array with zero length) + self.code_builder.build_fn_header_and_footer(&[], 0, None); self.module.add_function_signature(Signature { - param_types: self.storage.arg_types.clone(), - ret_type, + param_types: bumpalo::vec![in self.env.arena; I32; wrapper_arg_layouts.len()], + ret_type: None, }); + self.append_proc_debug_name(wrapper_name); self.reset(); } @@ -1568,9 +1606,7 @@ impl<'a> WasmBackend<'a> { let proc_index = self .proc_lookup .iter() - .position(|lookup| { - lookup.name == proc_symbol && lookup.layout.arguments[0] == layout - }) + .position(|lookup| lookup.name == proc_symbol && lookup.layout.arguments[0] == layout) .unwrap(); let wasm_fn_index = self.fn_index_offset + proc_index as u32; diff --git a/compiler/gen_wasm/src/layout.rs b/compiler/gen_wasm/src/layout.rs index 7cd8dffdaa..8476b45fbc 100644 --- a/compiler/gen_wasm/src/layout.rs +++ b/compiler/gen_wasm/src/layout.rs @@ -10,7 +10,7 @@ pub const BUILTINS_ZIG_VERSION: ZigVersion = ZigVersion::Zig8; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ReturnMethod { /// This layout is returned from a Wasm function "normally" as a Primitive - Primitive(ValueType), + Primitive(ValueType, u32), /// This layout is returned by writing to a pointer passed as the first argument WriteToPointerArg, /// This layout is empty and requires no return value or argument (e.g. refcount helpers) @@ -125,7 +125,7 @@ impl WasmLayout { pub fn return_method(&self) -> ReturnMethod { match self { - Self::Primitive(ty, _) => ReturnMethod::Primitive(*ty), + Self::Primitive(ty, size) => ReturnMethod::Primitive(*ty, *size), Self::StackMemory { size, .. } => { if *size == 0 { ReturnMethod::NoReturnValue diff --git a/compiler/gen_wasm/src/lib.rs b/compiler/gen_wasm/src/lib.rs index 762f3ef1dc..6689680ac4 100644 --- a/compiler/gen_wasm/src/lib.rs +++ b/compiler/gen_wasm/src/lib.rs @@ -174,12 +174,8 @@ pub fn build_module_without_wrapper<'a>( use ProcSource::*; match source { Roc => { /* already generated */ } - Helper => { - if let Some(proc) = helper_iter.next() { - backend.build_proc(proc); - } - } - HigherOrderWrapper(inner_idx) => backend.build_zigcc_wrapper(idx, *inner_idx), + Helper => backend.build_proc(helper_iter.next().unwrap()), + HigherOrderWrapper(inner_idx) => backend.build_higher_order_wrapper(idx, *inner_idx), } } diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index f7b47f089b..2fc2fcefc3 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -1042,15 +1042,13 @@ pub fn call_higher_order_lowlevel<'a>( cb.i32_const(elem_old_size as i32); cb.i32_const(elem_new_size as i32); - let (proc_index, lookup) = backend - .proc_lookup - .iter() - .enumerate() - .find(|(_, lookup)| lookup.name == Symbol::LIST_MAP) - .unwrap_or_else(|| panic!("Can't find {:?}", op)); - let wasm_fn_index = backend.fn_index_offset + proc_index as u32; - - cb.call(wasm_fn_index, lookup.linker_index, 9, false); + let num_wasm_args = 9; + let has_return_val = false; + backend.call_zig_builtin_after_loading_args( + bitcode::LIST_MAP, + num_wasm_args, + has_return_val, + ); } ListMap2 { .. } diff --git a/compiler/gen_wasm/src/storage.rs b/compiler/gen_wasm/src/storage.rs index acc9ffb298..e2d2429f00 100644 --- a/compiler/gen_wasm/src/storage.rs +++ b/compiler/gen_wasm/src/storage.rs @@ -211,86 +211,6 @@ impl<'a> Storage<'a> { storage } - /// Allocate storage for an argument that will be passed from Zig code - /// (Zig *should* implement the C calling convention here, but it has bugs we need to work around) - pub fn allocate_zigcc_arg( - &mut self, - layout: &Layout<'a>, - symbol: Symbol, - frame_writes: &mut Vec<(LocalId, ValueType, u32)>, - ) { - let wasm_layout = WasmLayout::new(layout); - self.symbol_layouts.insert(symbol, *layout); - - match wasm_layout { - WasmLayout::Primitive(value_type, size) => { - self.arg_types.push(value_type); - let storage = StoredValue::Local { - local_id: self.get_next_local_id(), - value_type, - size, - }; - self.symbol_storage_map.insert(symbol, storage); - } - - WasmLayout::StackMemory { - size, - alignment_bytes, - format, - } => { - // Stack frame offset where we'll write the Zig argument value - let location = if size == 0 { - // An argument with zero size is purely conceptual, and will not exist in Wasm. - // However we need to track the symbol, so we treat it like a local variable rather than an argument. - StackMemoryLocation::FrameOffset(0) - } else if size > 16 { - // For larger structs, Zig passes a pointer to stack memory in the Zig caller. That suits us. Just pass it through. - self.arg_types.push(PTR_TYPE); - StackMemoryLocation::PointerArg(self.get_next_local_id()) - } else { - // Zig passes small structs as primitive values, but Roc expects a pointer to stack memory - - // Generate the Zig-compatible argument(s) - let types: &[ValueType] = if size <= 4 { - &[ValueType::I32] - } else if size <= 8 { - &[ValueType::I64] - } else if size <= 12 { - &[ValueType::I64, ValueType::I32] - } else { - &[ValueType::I64, ValueType::I64] - }; - - // Allocate space in the stack frame, so we can pass a pointer to Roc - let mut offset: u32 = - round_up_to_alignment!(self.stack_frame_size as u32, alignment_bytes); - let loc = StackMemoryLocation::FrameOffset(offset); - self.stack_frame_size = (offset + size) as i32; - - // Make a note of which writes we need to do - // Note: We can't do the writes until after we know how many Wasm args we have. - // We need a LocalId for the stack frame pointer and it must come after the args. - for ty in types.iter() { - let local_id = LocalId(self.arg_types.len() as u32); - frame_writes.push((local_id, *ty, offset)); - offset += if *ty == ValueType::I32 { 4 } else { 8 }; - } - - loc - }; - - let storage = StoredValue::StackMemory { - location, - size, - alignment_bytes, - format, - }; - - self.symbol_storage_map.insert(symbol, storage); - } - } - } - /// Get storage info for a given symbol pub fn get(&self, sym: &Symbol) -> &StoredValue { self.symbol_storage_map.get(sym).unwrap_or_else(|| { @@ -483,7 +403,7 @@ impl<'a> Storage<'a> { let return_method = return_layout.return_method(); let return_type = match return_method { - ReturnMethod::Primitive(ty) => Some(ty), + ReturnMethod::Primitive(ty, _) => Some(ty), ReturnMethod::NoReturnValue => None, ReturnMethod::WriteToPointerArg => { wasm_arg_types.push(PTR_TYPE); @@ -577,7 +497,7 @@ impl<'a> Storage<'a> { size ); } - }; + } size } }