mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 06:44:46 +00:00
wasm: code gen for higher order wrapper function
This commit is contained in:
parent
51789f38c2
commit
5db3ae0227
5 changed files with 124 additions and 174 deletions
|
@ -2,7 +2,7 @@ use bumpalo::{self, collections::Vec};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
||||||
use code_builder::Align;
|
use code_builder::Align;
|
||||||
use roc_builtins::bitcode::IntWidth;
|
use roc_builtins::bitcode::{FloatWidth, IntWidth};
|
||||||
use roc_collections::all::MutMap;
|
use roc_collections::all::MutMap;
|
||||||
use roc_module::ident::Ident;
|
use roc_module::ident::Ident;
|
||||||
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
|
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
|
||||||
|
@ -223,7 +223,7 @@ impl<'a> WasmBackend<'a> {
|
||||||
let ret_layout = WasmLayout::new(&proc.ret_layout);
|
let ret_layout = WasmLayout::new(&proc.ret_layout);
|
||||||
|
|
||||||
let ret_type = match ret_layout.return_method() {
|
let ret_type = match ret_layout.return_method() {
|
||||||
ReturnMethod::Primitive(ty) => Some(ty),
|
ReturnMethod::Primitive(ty, _) => Some(ty),
|
||||||
ReturnMethod::NoReturnValue => None,
|
ReturnMethod::NoReturnValue => None,
|
||||||
ReturnMethod::WriteToPointerArg => {
|
ReturnMethod::WriteToPointerArg => {
|
||||||
self.storage.arg_types.push(PTR_TYPE);
|
self.storage.arg_types.push(PTR_TYPE);
|
||||||
|
@ -281,95 +281,133 @@ impl<'a> WasmBackend<'a> {
|
||||||
self.module.names.append_function(wasm_fn_index, name_bytes);
|
self.module.names.append_function(wasm_fn_index, name_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Procs that are called from higher-order Zig builtins currently need a wrapper,
|
/// Build a wrapper around a Roc procedure so that it can be called from our higher-order Zig builtins.
|
||||||
/// because the Zig compiler has bugs in its C calling convention for Wasm.
|
///
|
||||||
/// Whenever Zig fixes this, we should be able to remove the wrapper entirely.
|
/// The generic Zig code passes *pointers* to all of the argument values (e.g. on the heap in a List).
|
||||||
pub fn build_zigcc_wrapper(&mut self, wrapper_idx: usize, inner_idx: usize) {
|
/// Numbers up to 64 bits are passed by value, so we need to load them from the provided pointer.
|
||||||
let ProcLookupData {
|
/// Everything else is passed by reference, so we can just pass the pointer through.
|
||||||
layout: proc_layout,
|
///
|
||||||
..
|
/// NOTE: If the builtins expected the return pointer first and closure data last, we could eliminate the wrapper
|
||||||
} = self.proc_lookup[wrapper_idx];
|
/// when all args are pass-by-reference and non-zero size. But currently we need it to swap those around.
|
||||||
let ProcLayout {
|
pub fn build_higher_order_wrapper(&mut self, wrapper_lookup_idx: usize, inner_lookup_idx: usize) {
|
||||||
arguments: arg_layouts,
|
use Align::*;
|
||||||
result,
|
use ValueType::*;
|
||||||
} = proc_layout;
|
|
||||||
|
|
||||||
let ret_sym = self.create_symbol("##ret");
|
let ProcLookupData {
|
||||||
let ret_layout = WasmLayout::new(&result);
|
name: wrapper_name,
|
||||||
let is_stack_return = match ret_layout.return_method() {
|
layout: wrapper_proc_layout,
|
||||||
ReturnMethod::Primitive(_) => false,
|
..
|
||||||
ReturnMethod::NoReturnValue => false,
|
} = self.proc_lookup[wrapper_lookup_idx];
|
||||||
|
let wrapper_arg_layouts = wrapper_proc_layout.arguments;
|
||||||
|
|
||||||
|
// Our convention is that the last arg of the wrapper is the heap return pointer
|
||||||
|
let heap_return_ptr_id = LocalId(wrapper_arg_layouts.len() as u32 - 1);
|
||||||
|
let inner_ret_layout = match wrapper_arg_layouts.last() {
|
||||||
|
Some(Layout::Boxed(inner)) => WasmLayout::new(inner),
|
||||||
|
x => internal_error!("Higher-order wrapper: invalid return layout {:?}", x),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut n_inner_wasm_args = 0;
|
||||||
|
let ret_type_and_size = match inner_ret_layout.return_method() {
|
||||||
|
ReturnMethod::Primitive(ty, size) => Some((ty, size)),
|
||||||
|
ReturnMethod::NoReturnValue => None,
|
||||||
ReturnMethod::WriteToPointerArg => {
|
ReturnMethod::WriteToPointerArg => {
|
||||||
// Return variable must be at index 0
|
n_inner_wasm_args += 1;
|
||||||
self.storage
|
self.code_builder.get_local(heap_return_ptr_id);
|
||||||
.allocate(result, ret_sym, StoredValueKind::Parameter);
|
None
|
||||||
true
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut arg_symbols = Vec::with_capacity_in(arg_layouts.len(), self.env.arena);
|
// Load all the arguments for the inner function
|
||||||
let mut frame_writes = Vec::with_capacity_in(arg_layouts.len() * 2, self.env.arena);
|
for (i, wrapper_arg) in wrapper_arg_layouts.iter().enumerate() {
|
||||||
|
let is_closure_data = i == 0; // Skip closure data (first for wrapper, last for inner)
|
||||||
for (i, arg) in arg_layouts.iter().enumerate() {
|
let is_return_pointer = i == wrapper_arg_layouts.len() - 1; // Skip return pointer (may not be an arg for inner. And if it is, swaps from end to start)
|
||||||
let arg_name = format!("arg{}", i);
|
if is_closure_data || is_return_pointer || wrapper_arg.stack_size(TARGET_INFO) == 0 {
|
||||||
let symbol = self.create_symbol(&arg_name);
|
continue;
|
||||||
arg_symbols.push(symbol);
|
|
||||||
self.storage
|
|
||||||
.allocate_zigcc_arg(arg, symbol, &mut frame_writes);
|
|
||||||
}
|
}
|
||||||
|
n_inner_wasm_args += 1;
|
||||||
|
|
||||||
if !is_stack_return {
|
// Load wrapper argument. They're all pointers.
|
||||||
// Local variables must come *after* the arguments
|
self.code_builder.get_local(LocalId(i as u32));
|
||||||
self.storage
|
|
||||||
.allocate(result, ret_sym, StoredValueKind::Variable);
|
// Dereference any primitive-valued arguments
|
||||||
|
match wrapper_arg {
|
||||||
|
Layout::Boxed(inner_arg) => match inner_arg {
|
||||||
|
Layout::Builtin(Builtin::Int(IntWidth::U8 | IntWidth::I8)) => {
|
||||||
|
self.code_builder.i32_load8_u(Bytes1, 0);
|
||||||
}
|
}
|
||||||
|
Layout::Builtin(Builtin::Int(IntWidth::U16 | IntWidth::I16)) => {
|
||||||
// Write structs to the stack frame
|
self.code_builder.i32_load16_u(Bytes2, 0);
|
||||||
if !frame_writes.is_empty() {
|
}
|
||||||
let frame_ptr = self.storage.create_anonymous_local(PTR_TYPE);
|
Layout::Builtin(Builtin::Int(IntWidth::U32 | IntWidth::I32)) => {
|
||||||
self.storage.stack_frame_pointer = Some(frame_ptr);
|
self.code_builder.i32_load(Bytes4, 0);
|
||||||
|
}
|
||||||
for (zig_arg, value_type, offset) in frame_writes.into_iter() {
|
Layout::Builtin(Builtin::Int(IntWidth::U64 | IntWidth::I64)) => {
|
||||||
self.code_builder.get_local(frame_ptr);
|
self.code_builder.i64_load(Bytes8, 0);
|
||||||
self.code_builder.get_local(zig_arg);
|
}
|
||||||
if value_type == ValueType::I32 {
|
Layout::Builtin(Builtin::Float(FloatWidth::F32)) => {
|
||||||
let align = Align::from_stack_offset(Align::Bytes4, offset);
|
self.code_builder.f32_load(Bytes4, 0);
|
||||||
self.code_builder.i32_store(align, offset);
|
}
|
||||||
} else {
|
Layout::Builtin(Builtin::Float(FloatWidth::F64)) => {
|
||||||
let align = Align::from_stack_offset(Align::Bytes8, offset);
|
self.code_builder.f64_load(Bytes8, 0);
|
||||||
self.code_builder.i64_store(align, offset);
|
}
|
||||||
|
Layout::Builtin(Builtin::Bool) => {
|
||||||
|
self.code_builder.i32_load8_u(Bytes1, 0);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Any other layout is a pointer, which we've already loaded. Nothing to do!
|
||||||
|
}
|
||||||
|
},
|
||||||
|
x => internal_error!("Higher-order wrapper: expected a Box layout, got {:?}", x),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the inner function has closure data, it's the last arg of the inner fn
|
||||||
|
let closure_data_layout = wrapper_arg_layouts[0];
|
||||||
|
if closure_data_layout.stack_size(TARGET_INFO) > 0 {
|
||||||
|
self.code_builder.get_local(LocalId(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the inner function returns a primitive, load the address to store it at
|
||||||
|
if ret_type_and_size.is_some() {
|
||||||
|
self.code_builder.get_local(heap_return_ptr_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the wrapped inner function
|
// Call the wrapped inner function
|
||||||
let ProcLookupData { linker_index, .. } = self.proc_lookup[inner_idx];
|
let lookup = &self.proc_lookup[inner_lookup_idx];
|
||||||
let (param_types, ret_type) = self.storage.load_symbols_for_call(
|
let inner_wasm_fn_index = self.fn_index_offset + inner_lookup_idx as u32;
|
||||||
self.env.arena,
|
let has_return_val = ret_type_and_size.is_some();
|
||||||
&mut self.code_builder,
|
self.code_builder.call(
|
||||||
arg_symbols.into_bump_slice(),
|
inner_wasm_fn_index,
|
||||||
ret_sym,
|
lookup.linker_index,
|
||||||
&ret_layout,
|
n_inner_wasm_args,
|
||||||
CallConv::C,
|
has_return_val,
|
||||||
);
|
);
|
||||||
let wasm_fn_index = self.fn_index_offset + inner_idx as u32;
|
|
||||||
let num_wasm_args = param_types.len();
|
|
||||||
let has_return_val = ret_type.is_some();
|
|
||||||
self.code_builder
|
|
||||||
.call(wasm_fn_index, linker_index, num_wasm_args, has_return_val);
|
|
||||||
|
|
||||||
// Setup & teardown the stack frame
|
// If the inner function returns a primitive, store it to the address we loaded earlier
|
||||||
self.code_builder.build_fn_header_and_footer(
|
if let Some((ty, size)) = ret_type_and_size {
|
||||||
&self.storage.local_types,
|
match (ty, size) {
|
||||||
self.storage.stack_frame_size,
|
(I64, 8) => self.code_builder.i64_store(Bytes8, 0),
|
||||||
self.storage.stack_frame_pointer,
|
(I32, 4) => self.code_builder.i32_store(Bytes4, 0),
|
||||||
);
|
(I32, 2) => self.code_builder.i32_store16(Bytes2, 0),
|
||||||
|
(I32, 1) => self.code_builder.i32_store8(Bytes1, 0),
|
||||||
|
(F32, 4) => self.code_builder.f32_store(Bytes4, 0),
|
||||||
|
(F64, 8) => self.code_builder.f64_store(Bytes8, 0),
|
||||||
|
_ => {
|
||||||
|
internal_error!("Cannot store {:?} with alignment of {:?}", ty, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write empty function header (local variables array with zero length)
|
||||||
|
self.code_builder.build_fn_header_and_footer(&[], 0, None);
|
||||||
|
|
||||||
self.module.add_function_signature(Signature {
|
self.module.add_function_signature(Signature {
|
||||||
param_types: self.storage.arg_types.clone(),
|
param_types: bumpalo::vec![in self.env.arena; I32; wrapper_arg_layouts.len()],
|
||||||
ret_type,
|
ret_type: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
self.append_proc_debug_name(wrapper_name);
|
||||||
self.reset();
|
self.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1568,9 +1606,7 @@ impl<'a> WasmBackend<'a> {
|
||||||
let proc_index = self
|
let proc_index = self
|
||||||
.proc_lookup
|
.proc_lookup
|
||||||
.iter()
|
.iter()
|
||||||
.position(|lookup| {
|
.position(|lookup| lookup.name == proc_symbol && lookup.layout.arguments[0] == layout)
|
||||||
lookup.name == proc_symbol && lookup.layout.arguments[0] == layout
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let wasm_fn_index = self.fn_index_offset + proc_index as u32;
|
let wasm_fn_index = self.fn_index_offset + proc_index as u32;
|
||||||
|
|
|
@ -10,7 +10,7 @@ pub const BUILTINS_ZIG_VERSION: ZigVersion = ZigVersion::Zig8;
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub enum ReturnMethod {
|
pub enum ReturnMethod {
|
||||||
/// This layout is returned from a Wasm function "normally" as a Primitive
|
/// This layout is returned from a Wasm function "normally" as a Primitive
|
||||||
Primitive(ValueType),
|
Primitive(ValueType, u32),
|
||||||
/// This layout is returned by writing to a pointer passed as the first argument
|
/// This layout is returned by writing to a pointer passed as the first argument
|
||||||
WriteToPointerArg,
|
WriteToPointerArg,
|
||||||
/// This layout is empty and requires no return value or argument (e.g. refcount helpers)
|
/// This layout is empty and requires no return value or argument (e.g. refcount helpers)
|
||||||
|
@ -125,7 +125,7 @@ impl WasmLayout {
|
||||||
|
|
||||||
pub fn return_method(&self) -> ReturnMethod {
|
pub fn return_method(&self) -> ReturnMethod {
|
||||||
match self {
|
match self {
|
||||||
Self::Primitive(ty, _) => ReturnMethod::Primitive(*ty),
|
Self::Primitive(ty, size) => ReturnMethod::Primitive(*ty, *size),
|
||||||
Self::StackMemory { size, .. } => {
|
Self::StackMemory { size, .. } => {
|
||||||
if *size == 0 {
|
if *size == 0 {
|
||||||
ReturnMethod::NoReturnValue
|
ReturnMethod::NoReturnValue
|
||||||
|
|
|
@ -174,12 +174,8 @@ pub fn build_module_without_wrapper<'a>(
|
||||||
use ProcSource::*;
|
use ProcSource::*;
|
||||||
match source {
|
match source {
|
||||||
Roc => { /* already generated */ }
|
Roc => { /* already generated */ }
|
||||||
Helper => {
|
Helper => backend.build_proc(helper_iter.next().unwrap()),
|
||||||
if let Some(proc) = helper_iter.next() {
|
HigherOrderWrapper(inner_idx) => backend.build_higher_order_wrapper(idx, *inner_idx),
|
||||||
backend.build_proc(proc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
HigherOrderWrapper(inner_idx) => backend.build_zigcc_wrapper(idx, *inner_idx),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1042,15 +1042,13 @@ pub fn call_higher_order_lowlevel<'a>(
|
||||||
cb.i32_const(elem_old_size as i32);
|
cb.i32_const(elem_old_size as i32);
|
||||||
cb.i32_const(elem_new_size as i32);
|
cb.i32_const(elem_new_size as i32);
|
||||||
|
|
||||||
let (proc_index, lookup) = backend
|
let num_wasm_args = 9;
|
||||||
.proc_lookup
|
let has_return_val = false;
|
||||||
.iter()
|
backend.call_zig_builtin_after_loading_args(
|
||||||
.enumerate()
|
bitcode::LIST_MAP,
|
||||||
.find(|(_, lookup)| lookup.name == Symbol::LIST_MAP)
|
num_wasm_args,
|
||||||
.unwrap_or_else(|| panic!("Can't find {:?}", op));
|
has_return_val,
|
||||||
let wasm_fn_index = backend.fn_index_offset + proc_index as u32;
|
);
|
||||||
|
|
||||||
cb.call(wasm_fn_index, lookup.linker_index, 9, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ListMap2 { .. }
|
ListMap2 { .. }
|
||||||
|
|
|
@ -211,86 +211,6 @@ impl<'a> Storage<'a> {
|
||||||
storage
|
storage
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocate storage for an argument that will be passed from Zig code
|
|
||||||
/// (Zig *should* implement the C calling convention here, but it has bugs we need to work around)
|
|
||||||
pub fn allocate_zigcc_arg(
|
|
||||||
&mut self,
|
|
||||||
layout: &Layout<'a>,
|
|
||||||
symbol: Symbol,
|
|
||||||
frame_writes: &mut Vec<(LocalId, ValueType, u32)>,
|
|
||||||
) {
|
|
||||||
let wasm_layout = WasmLayout::new(layout);
|
|
||||||
self.symbol_layouts.insert(symbol, *layout);
|
|
||||||
|
|
||||||
match wasm_layout {
|
|
||||||
WasmLayout::Primitive(value_type, size) => {
|
|
||||||
self.arg_types.push(value_type);
|
|
||||||
let storage = StoredValue::Local {
|
|
||||||
local_id: self.get_next_local_id(),
|
|
||||||
value_type,
|
|
||||||
size,
|
|
||||||
};
|
|
||||||
self.symbol_storage_map.insert(symbol, storage);
|
|
||||||
}
|
|
||||||
|
|
||||||
WasmLayout::StackMemory {
|
|
||||||
size,
|
|
||||||
alignment_bytes,
|
|
||||||
format,
|
|
||||||
} => {
|
|
||||||
// Stack frame offset where we'll write the Zig argument value
|
|
||||||
let location = if size == 0 {
|
|
||||||
// An argument with zero size is purely conceptual, and will not exist in Wasm.
|
|
||||||
// However we need to track the symbol, so we treat it like a local variable rather than an argument.
|
|
||||||
StackMemoryLocation::FrameOffset(0)
|
|
||||||
} else if size > 16 {
|
|
||||||
// For larger structs, Zig passes a pointer to stack memory in the Zig caller. That suits us. Just pass it through.
|
|
||||||
self.arg_types.push(PTR_TYPE);
|
|
||||||
StackMemoryLocation::PointerArg(self.get_next_local_id())
|
|
||||||
} else {
|
|
||||||
// Zig passes small structs as primitive values, but Roc expects a pointer to stack memory
|
|
||||||
|
|
||||||
// Generate the Zig-compatible argument(s)
|
|
||||||
let types: &[ValueType] = if size <= 4 {
|
|
||||||
&[ValueType::I32]
|
|
||||||
} else if size <= 8 {
|
|
||||||
&[ValueType::I64]
|
|
||||||
} else if size <= 12 {
|
|
||||||
&[ValueType::I64, ValueType::I32]
|
|
||||||
} else {
|
|
||||||
&[ValueType::I64, ValueType::I64]
|
|
||||||
};
|
|
||||||
|
|
||||||
// Allocate space in the stack frame, so we can pass a pointer to Roc
|
|
||||||
let mut offset: u32 =
|
|
||||||
round_up_to_alignment!(self.stack_frame_size as u32, alignment_bytes);
|
|
||||||
let loc = StackMemoryLocation::FrameOffset(offset);
|
|
||||||
self.stack_frame_size = (offset + size) as i32;
|
|
||||||
|
|
||||||
// Make a note of which writes we need to do
|
|
||||||
// Note: We can't do the writes until after we know how many Wasm args we have.
|
|
||||||
// We need a LocalId for the stack frame pointer and it must come after the args.
|
|
||||||
for ty in types.iter() {
|
|
||||||
let local_id = LocalId(self.arg_types.len() as u32);
|
|
||||||
frame_writes.push((local_id, *ty, offset));
|
|
||||||
offset += if *ty == ValueType::I32 { 4 } else { 8 };
|
|
||||||
}
|
|
||||||
|
|
||||||
loc
|
|
||||||
};
|
|
||||||
|
|
||||||
let storage = StoredValue::StackMemory {
|
|
||||||
location,
|
|
||||||
size,
|
|
||||||
alignment_bytes,
|
|
||||||
format,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.symbol_storage_map.insert(symbol, storage);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get storage info for a given symbol
|
/// Get storage info for a given symbol
|
||||||
pub fn get(&self, sym: &Symbol) -> &StoredValue {
|
pub fn get(&self, sym: &Symbol) -> &StoredValue {
|
||||||
self.symbol_storage_map.get(sym).unwrap_or_else(|| {
|
self.symbol_storage_map.get(sym).unwrap_or_else(|| {
|
||||||
|
@ -483,7 +403,7 @@ impl<'a> Storage<'a> {
|
||||||
|
|
||||||
let return_method = return_layout.return_method();
|
let return_method = return_layout.return_method();
|
||||||
let return_type = match return_method {
|
let return_type = match return_method {
|
||||||
ReturnMethod::Primitive(ty) => Some(ty),
|
ReturnMethod::Primitive(ty, _) => Some(ty),
|
||||||
ReturnMethod::NoReturnValue => None,
|
ReturnMethod::NoReturnValue => None,
|
||||||
ReturnMethod::WriteToPointerArg => {
|
ReturnMethod::WriteToPointerArg => {
|
||||||
wasm_arg_types.push(PTR_TYPE);
|
wasm_arg_types.push(PTR_TYPE);
|
||||||
|
@ -577,7 +497,7 @@ impl<'a> Storage<'a> {
|
||||||
size
|
size
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
size
|
size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue