From 7b96e953ba3d7f381bc874e2cca2532400abff13 Mon Sep 17 00:00:00 2001 From: Brian Carroll Date: Thu, 7 Apr 2022 12:05:53 +0100 Subject: [PATCH] wasm: Get List.map2 working --- compiler/gen_wasm/src/backend.rs | 10 +-- compiler/gen_wasm/src/low_level.rs | 79 ++++++++++++++++++--- compiler/mono/src/code_gen_help/mod.rs | 8 +-- compiler/mono/src/code_gen_help/refcount.rs | 4 +- compiler/test_gen/src/gen_list.rs | 6 +- 5 files changed, 86 insertions(+), 21 deletions(-) diff --git a/compiler/gen_wasm/src/backend.rs b/compiler/gen_wasm/src/backend.rs index 19b52cef54..0ad99742cb 100644 --- a/compiler/gen_wasm/src/backend.rs +++ b/compiler/gen_wasm/src/backend.rs @@ -1616,8 +1616,9 @@ impl<'a> WasmBackend<'a> { ); } - /// Generate a refcount increment procedure and return its Wasm function index - pub fn gen_refcount_inc_for_zig(&mut self, layout: Layout<'a>) -> u32 { + /// Generate a refcount helper procedure and return a pointer (table index) to it + /// This allows it to be indirectly called from Zig code + pub fn get_refcount_fn_ptr(&mut self, layout: Layout<'a>, op: HelperOp) -> i32 { let ident_ids = self .interns .all_ident_ids @@ -1626,7 +1627,7 @@ impl<'a> WasmBackend<'a> { let (proc_symbol, new_specializations) = self .helper_proc_gen - .gen_refcount_inc_proc(ident_ids, layout); + .gen_refcount_proc(ident_ids, layout, op); // If any new specializations were created, register their symbol data for (spec_sym, spec_layout) in new_specializations.into_iter() { @@ -1639,6 +1640,7 @@ impl<'a> WasmBackend<'a> { .position(|lookup| lookup.name == proc_symbol && lookup.layout.arguments[0] == layout) .unwrap(); - self.fn_index_offset + proc_index as u32 + let wasm_fn_index = self.fn_index_offset + proc_index as u32; + self.get_fn_table_index(wasm_fn_index) } } diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index b376f613b5..b396a83adc 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -3,6 +3,7 @@ use roc_builtins::bitcode::{self, FloatWidth, IntWidth}; use roc_error_macros::internal_error; use roc_module::low_level::LowLevel; use roc_module::symbol::Symbol; +use roc_mono::code_gen_help::HelperOp; use roc_mono::ir::{HigherOrderLowLevel, PassedFunction, ProcLayout}; use roc_mono::layout::{Builtin, Layout, UnionLayout}; use roc_mono::low_level::HigherOrder; @@ -1014,22 +1015,20 @@ pub fn call_higher_order_lowlevel<'a>( }; let wrapper_fn_idx = backend.register_helper_proc(wrapper_sym, wrapper_layout, source); - let inc_fn_idx = backend.gen_refcount_inc_for_zig(closure_data_layout); - let wrapper_fn_ptr = backend.get_fn_table_index(wrapper_fn_idx); - let inc_fn_ptr = backend.get_fn_table_index(inc_fn_idx); + let inc_fn_ptr = backend.get_refcount_fn_ptr(closure_data_layout, HelperOp::Inc); match op { // List.map : List elem_x, (elem_x -> elem_ret) -> List elem_ret ListMap { xs } => { - let list_layout_in = backend.storage.symbol_layouts[xs]; + let list_x = backend.storage.symbol_layouts[xs]; - let (elem_x, elem_ret) = match (list_layout_in, return_layout) { + let (elem_x, elem_ret) = match (list_x, return_layout) { ( Layout::Builtin(Builtin::List(elem_x)), Layout::Builtin(Builtin::List(elem_ret)), ) => (elem_x, elem_ret), - _ => unreachable!("invalid layout for List.map arguments"), + _ => unreachable!("invalid arguments layout for {:?}", op), }; let elem_x_size = elem_x.stack_size(TARGET_INFO); let (elem_ret_size, elem_ret_align) = elem_ret.stack_size_and_alignment(TARGET_INFO); @@ -1039,7 +1038,7 @@ pub fn call_higher_order_lowlevel<'a>( // Load return pointer & argument values // Wasm signature: (i32, i64, i64, i32, i32, i32, i32, i32, i32, i32) -> nil backend.storage.load_symbols(cb, &[return_sym]); - backend.storage.load_symbol_zig(cb, *xs); // list with capacity = 2 x i64 args + backend.storage.load_symbol_zig(cb, *xs); // 2 x i64 cb.i32_const(wrapper_fn_ptr); if closure_data_exists { backend.storage.load_symbols(cb, &[*captured_environment]); @@ -1062,8 +1061,70 @@ pub fn call_higher_order_lowlevel<'a>( ); } - ListMap2 { .. } - | ListMap3 { .. } + ListMap2 { xs, ys } => { + let list_x = backend.storage.symbol_layouts[xs]; + let list_y = backend.storage.symbol_layouts[ys]; + + let (elem_x, elem_y, elem_ret) = match (list_x, list_y, return_layout) { + ( + Layout::Builtin(Builtin::List(x)), + Layout::Builtin(Builtin::List(y)), + Layout::Builtin(Builtin::List(ret)), + ) => (x, y, ret), + _ => unreachable!("invalid arguments layout for {:?}", op), + }; + let elem_x_size = elem_x.stack_size(TARGET_INFO); + let elem_y_size = elem_y.stack_size(TARGET_INFO); + let (elem_ret_size, elem_ret_align) = elem_ret.stack_size_and_alignment(TARGET_INFO); + + let dec_x_fn_ptr = backend.get_refcount_fn_ptr(*elem_x, HelperOp::Dec); + let dec_y_fn_ptr = backend.get_refcount_fn_ptr(*elem_y, HelperOp::Dec); + + let cb = &mut backend.code_builder; + + /* Load Wasm arguments + return ptr: RocList, // i32 + list1: RocList, // i64, i64 + list2: RocList, // i64, i64 + caller: Caller2, // i32 + data: Opaque, // i32 + inc_n_data: IncN, // i32 + data_is_owned: bool, // i32 + alignment: u32, // i32 + a_width: usize, // i32 + b_width: usize, // i32 + c_width: usize, // i32 + dec_a: Dec, // i32 + dec_b: Dec, // i32 + */ + backend.storage.load_symbols(cb, &[return_sym]); + backend.storage.load_symbol_zig(cb, *xs); + backend.storage.load_symbol_zig(cb, *ys); + cb.i32_const(wrapper_fn_ptr); + if closure_data_exists { + backend.storage.load_symbols(cb, &[*captured_environment]); + } else { + cb.i32_const(0); // null pointer + } + cb.i32_const(inc_fn_ptr); + cb.i32_const(*owns_captured_environment as i32); + cb.i32_const(elem_ret_align as i32); + cb.i32_const(elem_x_size as i32); + cb.i32_const(elem_y_size as i32); + cb.i32_const(elem_ret_size as i32); + cb.i32_const(dec_x_fn_ptr); + cb.i32_const(dec_y_fn_ptr); + + let num_wasm_args = 15; + let has_return_val = false; + backend.call_zig_builtin_after_loading_args( + bitcode::LIST_MAP2, + num_wasm_args, + has_return_val, + ); + } + + ListMap3 { .. } | ListMap4 { .. } | ListMapWithIndex { .. } | ListKeepIf { .. } diff --git a/compiler/mono/src/code_gen_help/mod.rs b/compiler/mono/src/code_gen_help/mod.rs index 2ffdd2494d..3244e04fbb 100644 --- a/compiler/mono/src/code_gen_help/mod.rs +++ b/compiler/mono/src/code_gen_help/mod.rs @@ -25,7 +25,7 @@ const ARG_2: Symbol = Symbol::ARG_2; pub const REFCOUNT_MAX: usize = 0; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum HelperOp { +pub enum HelperOp { Inc, Dec, DecRef(JoinPointId), @@ -185,16 +185,16 @@ impl<'a> CodeGenHelp<'a> { /// Generate a refcount increment procedure, *without* a Call expression. /// *This method should be rarely used* - only when the proc is to be called from Zig. /// Otherwise you want to generate the Proc and the Call together, using another method. - /// This only supports the 'inc' operation, as it's the only real use case we have. - pub fn gen_refcount_inc_proc( + pub fn gen_refcount_proc( &mut self, ident_ids: &mut IdentIds, layout: Layout<'a>, + op: HelperOp, ) -> (Symbol, Vec<'a, (Symbol, ProcLayout<'a>)>) { let mut ctx = Context { new_linker_data: Vec::new_in(self.arena), recursive_union: None, - op: HelperOp::Inc, + op, }; let proc_name = self.find_or_create_proc(ident_ids, &mut ctx, layout); diff --git a/compiler/mono/src/code_gen_help/refcount.rs b/compiler/mono/src/code_gen_help/refcount.rs index ef46ad8564..176f875d41 100644 --- a/compiler/mono/src/code_gen_help/refcount.rs +++ b/compiler/mono/src/code_gen_help/refcount.rs @@ -107,7 +107,9 @@ pub fn refcount_generic<'a>( match layout { Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => { - unreachable!("Not refcounted: {:?}", layout) + // Generate a dummy function that immediately returns Unit + // Some higher-order Zig builtins *always* call an RC function on List elements. + rc_return_stmt(root, ident_ids, ctx) } Layout::Builtin(Builtin::Str) => refcount_str(root, ident_ids, ctx), Layout::Builtin(Builtin::List(elem_layout)) => { diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index cb46f2de49..92ab071fa2 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -1169,7 +1169,7 @@ fn list_map3_different_length() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn list_map2_pair() { assert_evals_to!( indoc!( @@ -1184,13 +1184,13 @@ fn list_map2_pair() { } #[test] -#[cfg(any(feature = "gen-llvm"))] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn list_map2_different_lengths() { assert_evals_to!( indoc!( r#" List.map2 - ["a", "b", "lllllllllllllongnggg" ] + ["a", "b", "lllllllllllllooooooooongnggg" ] ["b"] (\a, b -> Str.concat a b) "#