diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index fcd5f42cf8..95137d25bb 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -114,6 +114,7 @@ pub const RocList = extern struct { const Caller1 = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; const Caller2 = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; +const Caller3 = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; pub fn listMap(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList { if (list.bytes) |source_ptr| { @@ -213,6 +214,126 @@ pub fn listMap2(list1: RocList, list2: RocList, transform: Opaque, caller: Calle } } +pub fn listMap3(list1: RocList, list2: RocList, list3: RocList, transform: Opaque, caller: Caller3, alignment: usize, a_width: usize, b_width: usize, c_width: usize, d_width: usize, dec_a: Dec, dec_b: Dec, dec_c: Dec) callconv(.C) RocList { + const smaller_length = std.math.min(list1.len(), list2.len()); + const output_length = std.math.min(smaller_length, list3.len()); + + if (list1.bytes) |source_a| { + if (list2.bytes) |source_b| { + if (list3.bytes) |source_c| { + const output = RocList.allocate(std.heap.c_allocator, alignment, output_length, d_width); + const target_ptr = output.bytes orelse unreachable; + + var i: usize = 0; + while (i < output_length) : (i += 1) { + const element_a = source_a + i * a_width; + const element_b = source_b + i * b_width; + const element_c = source_c + i * c_width; + const target = target_ptr + i * d_width; + + caller(transform, element_a, element_b, element_c, target); + } + + // if the lists don't have equal length, we must consume the remaining elements + // In this case we consume by (recursively) decrementing the elements + if (list1.len() > output_length) { + while (i < list1.len()) : (i += 1) { + const element_a = source_a + i * a_width; + dec_a(element_a); + } + } + + if (list2.len() > output_length) { + while (i < list2.len()) : (i += 1) { + const element_b = source_b + i * b_width; + dec_b(element_b); + } + } + + if (list3.len() > output_length) { + while (i < list3.len()) : (i += 1) { + const element_c = source_c + i * c_width; + dec_b(element_c); + } + } + + utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width); + utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width); + utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width); + + return output; + } else { + // consume list1 elements (we know there is at least one because the list1.bytes pointer is non-null + var i: usize = 0; + while (i < list1.len()) : (i += 1) { + const element_a = source_a + i * a_width; + dec_a(element_a); + } + utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width); + + // consume list2 elements (we know there is at least one because the list1.bytes pointer is non-null + i = 0; + while (i < list2.len()) : (i += 1) { + const element_b = source_b + i * b_width; + dec_b(element_b); + } + utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width); + + return RocList.empty(); + } + } else { + // consume list1 elements (we know there is at least one because the list1.bytes pointer is non-null + var i: usize = 0; + while (i < list1.len()) : (i += 1) { + const element_a = source_a + i * a_width; + dec_a(element_a); + } + + utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width); + + // consume list3 elements (if any) + if (list3.bytes) |source_c| { + i = 0; + + while (i < list2.len()) : (i += 1) { + const element_c = source_c + i * c_width; + dec_c(element_c); + } + + utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width); + } + + return RocList.empty(); + } + } else { + // consume list2 elements (if any) + if (list2.bytes) |source_b| { + var i: usize = 0; + + while (i < list2.len()) : (i += 1) { + const element_b = source_b + i * b_width; + dec_b(element_b); + } + + utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width); + } + + // consume list3 elements (if any) + if (list3.bytes) |source_c| { + var i: usize = 0; + + while (i < list2.len()) : (i += 1) { + const element_c = source_c + i * c_width; + dec_c(element_c); + } + + utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width); + } + + return RocList.empty(); + } +} + pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize, inc: Inc, dec: Dec) callconv(.C) RocList { if (list.bytes) |source_ptr| { const size = list.len(); diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 2905ec52df..0fd6cb1f2f 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -64,6 +64,7 @@ pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; pub const LIST_MAP: &str = "roc_builtins.list.map"; pub const LIST_MAP2: &str = "roc_builtins.list.map2"; +pub const LIST_MAP3: &str = "roc_builtins.list.map3"; pub const LIST_MAP_WITH_INDEX: &str = "roc_builtins.list.map_with_index"; pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if"; pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 094fbac48f..a743afde77 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -817,6 +817,21 @@ pub fn types() -> MutMap { ) }); + // map3 : List a, List b, List c, (a, b, c -> d) -> List d + add_type(Symbol::LIST_MAP3, { + let_tvars! {a, b, c, d, cvar}; + + top_level_function( + vec![ + list_type(flex(a)), + list_type(flex(b)), + list_type(flex(c)), + closure(vec![flex(a), flex(b), flex(c)], cvar, Box::new(flex(d))), + ], + Box::new(list_type(flex(d))), + ) + }); + // append : List elem, elem -> List elem add_type( Symbol::LIST_APPEND, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 6977eaa607..e31d1597f0 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -81,6 +81,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_JOIN => list_join, LIST_MAP => list_map, LIST_MAP2 => list_map2, + LIST_MAP3 => list_map3, LIST_MAP_WITH_INDEX => list_map_with_index, LIST_KEEP_IF => list_keep_if, LIST_KEEP_OKS => list_keep_oks, @@ -218,6 +219,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_JOIN => list_join, Symbol::LIST_MAP => list_map, Symbol::LIST_MAP2 => list_map2, + Symbol::LIST_MAP3 => list_map3, Symbol::LIST_MAP_WITH_INDEX => list_map_with_index, Symbol::LIST_KEEP_IF => list_keep_if, Symbol::LIST_KEEP_OKS => list_keep_oks, @@ -370,6 +372,38 @@ fn lowlevel_3(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def { ) } +fn lowlevel_4(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def { + let arg1_var = var_store.fresh(); + let arg2_var = var_store.fresh(); + let arg3_var = var_store.fresh(); + let arg4_var = var_store.fresh(); + let ret_var = var_store.fresh(); + + let body = RunLowLevel { + op, + args: vec![ + (arg1_var, Var(Symbol::ARG_1)), + (arg2_var, Var(Symbol::ARG_2)), + (arg3_var, Var(Symbol::ARG_3)), + (arg4_var, Var(Symbol::ARG_4)), + ], + ret_var, + }; + + defn( + symbol, + vec![ + (arg1_var, Symbol::ARG_1), + (arg2_var, Symbol::ARG_2), + (arg3_var, Symbol::ARG_3), + (arg4_var, Symbol::ARG_4), + ], + var_store, + body, + ret_var, + ) +} + /// Num.maxInt : Int fn num_max_int(symbol: Symbol, var_store: &mut VarStore) -> Def { let int_var = var_store.fresh(); @@ -2122,6 +2156,11 @@ fn list_map2(symbol: Symbol, var_store: &mut VarStore) -> Def { lowlevel_3(symbol, LowLevel::ListMap2, var_store) } +/// List.map3 : List a, List b, (a, b -> c) -> List c +fn list_map3(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_4(symbol, LowLevel::ListMap3, var_store) +} + /// Dict.hashTestOnly : k, v -> Nat pub fn dict_hash_test_only(symbol: Symbol, var_store: &mut VarStore) -> Def { lowlevel_2(symbol, LowLevel::Hash, var_store) diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 98386c2270..903ec6acef 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, - list_map2, list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, list_single, - list_sum, list_walk, list_walk_backwards, + list_map2, list_map3, list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, + list_single, list_sum, list_walk, list_walk_backwards, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, @@ -3746,6 +3746,38 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid list layout"), } } + ListMap3 => { + debug_assert_eq!(args.len(), 4); + + let (list1, list1_layout) = load_symbol_and_layout(scope, &args[0]); + let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]); + let (list3, list3_layout) = load_symbol_and_layout(scope, &args[2]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[3]); + + match (list1_layout, list2_layout, list3_layout) { + ( + Layout::Builtin(Builtin::List(_, element1_layout)), + Layout::Builtin(Builtin::List(_, element2_layout)), + Layout::Builtin(Builtin::List(_, element3_layout)), + ) => list_map3( + env, + layout_ids, + func, + func_layout, + list1, + list2, + list3, + element1_layout, + element2_layout, + element3_layout, + ), + (Layout::Builtin(Builtin::EmptyList), _, _) + | (_, Layout::Builtin(Builtin::EmptyList), _) + | (_, _, Layout::Builtin(Builtin::EmptyList)) => empty_list(env), + _ => unreachable!("invalid list layout"), + } + } ListMapWithIndex => { // List.map : List before, (before -> after) -> List after debug_assert_eq!(args.len(), 2); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 4124891cf0..fb3232e0b8 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1305,6 +1305,114 @@ pub fn list_map2<'a, 'ctx, 'env>( ) } +pub fn list_map3<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list1: BasicValueEnum<'ctx>, + list2: BasicValueEnum<'ctx>, + list3: BasicValueEnum<'ctx>, + element1_layout: &Layout<'a>, + element2_layout: &Layout<'a>, + element3_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let return_layout = match transform_layout { + Layout::FunctionPointer(_, ret) => ret, + Layout::Closure(_, _, ret) => ret, + _ => unreachable!("not a callable layout"), + }; + + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let list1_i128 = complex_bitcast( + env.builder, + list1, + env.context.i128_type().into(), + "to_i128", + ); + + let list2_i128 = complex_bitcast( + env.builder, + list2, + env.context.i128_type().into(), + "to_i128", + ); + + let list3_i128 = complex_bitcast( + env.builder, + list3, + env.context.i128_type().into(), + "to_i128", + ); + + let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, transform); + + let argument_layouts = [ + element1_layout.clone(), + element2_layout.clone(), + element3_layout.clone(), + ]; + let stepper_caller = + build_transform_caller(env, layout_ids, transform_layout, &argument_layouts) + .as_global_value() + .as_pointer_value(); + + let a_width = env + .ptr_int() + .const_int(element1_layout.stack_size(env.ptr_bytes) as u64, false); + + let b_width = env + .ptr_int() + .const_int(element2_layout.stack_size(env.ptr_bytes) as u64, false); + + let c_width = env + .ptr_int() + .const_int(element3_layout.stack_size(env.ptr_bytes) as u64, false); + + let d_width = env + .ptr_int() + .const_int(return_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = return_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); + + let dec_a = build_dec_wrapper(env, layout_ids, element1_layout); + let dec_b = build_dec_wrapper(env, layout_ids, element2_layout); + let dec_c = build_dec_wrapper(env, layout_ids, element3_layout); + + let output = call_bitcode_fn( + env, + &[ + list1_i128, + list2_i128, + list3_i128, + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + alignment_iv.into(), + a_width.into(), + b_width.into(), + c_width.into(), + d_width.into(), + dec_a.as_global_value().as_pointer_value().into(), + dec_b.as_global_value().as_pointer_value().into(), + dec_c.as_global_value().as_pointer_value().into(), + ], + bitcode::LIST_MAP3, + ); + + complex_bitcast( + env.builder, + output, + collection(env.context, env.ptr_bytes).into(), + "from_i128", + ) +} + /// List.concat : List elem, List elem -> List elem pub fn list_concat<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index c597c42076..b8a2ed31c6 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -28,6 +28,7 @@ pub enum LowLevel { ListJoin, ListMap, ListMap2, + ListMap3, ListMapWithIndex, ListKeepIf, ListWalk, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 9ecb241be0..652c8493e2 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -910,6 +910,7 @@ define_builtins! { 22 LIST_KEEP_ERRS: "keepErrs" 23 LIST_MAP_WITH_INDEX: "mapWithIndex" 24 LIST_MAP2: "map2" + 25 LIST_MAP3: "map3" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 486205fcb3..b3785ddaa9 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -652,6 +652,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListJoin => arena.alloc_slice_copy(&[irrelevant]), ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]), ListMap2 => arena.alloc_slice_copy(&[owned, owned, irrelevant]), + ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, irrelevant]), ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index dbf974988b..28986f682c 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -568,6 +568,36 @@ fn list_map_closure() { ); } +#[test] +fn list_map3_group() { + assert_evals_to!( + indoc!( + r#" + List.map3 [1,2,3] [3,2,1] [2,1,3] (\a, b, c -> Group a b c) + "# + ), + RocList::from_slice(&[(1, 3, 2), (2, 2, 1), (3, 1, 3)]), + RocList<(i64, i64, i64)> + ); +} + +#[test] +fn list_map3_different_length() { + assert_evals_to!( + indoc!( + r#" + List.map3 + ["a", "b", "d" ] + ["b"], + ["c"], + Str.concat + "# + ), + RocList::from_slice(&[RocStr::from_slice("abc".as_bytes()),]), + RocList + ); +} + #[test] fn list_map2_pair() { assert_evals_to!(