diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index 0d850ea3b3..e6d16b9785 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -130,6 +130,25 @@ pub fn listMap(list: RocList, transform: Opaque, caller: Caller1, alignment: usi } } +pub fn listMapWithIndex(list: RocList, transform: Opaque, caller: Caller2, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width); + const target_ptr = output.bytes orelse unreachable; + + while (i < size) : (i += 1) { + caller(transform, @ptrCast(?[*]u8, &i), source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + } + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width); + + return output; + } else { + return RocList.empty(); + } +} + pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize) callconv(.C) RocList { if (list.bytes) |source_ptr| { const size = list.len(); diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index 2cce26a092..1ecb482406 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -7,6 +7,7 @@ const list = @import("list.zig"); comptime { exportListFn(list.listMap, "map"); + exportListFn(list.listMapWithIndex, "map_with_index"); exportListFn(list.listKeepIf, "keep_if"); exportListFn(list.listWalk, "walk"); exportListFn(list.listWalkBackwards, "walk_backwards"); diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 528467ce55..a220151bc9 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -60,6 +60,7 @@ pub const DICT_WALK: &str = "roc_builtins.dict.walk"; pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; pub const LIST_MAP: &str = "roc_builtins.list.map"; +pub const LIST_MAP_WITH_INDEX: &str = "roc_builtins.list.map_with_index"; pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if"; pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks"; pub const LIST_KEEP_ERRS: &str = "roc_builtins.list.keep_errs"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 3dcf93f221..42eb034581 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -726,6 +726,18 @@ pub fn types() -> MutMap { ), ); + // mapWithIndex : List before, (Nat, before -> after) -> List after + add_type(Symbol::LIST_MAP_WITH_INDEX, { + let_tvars! { cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure(vec![nat_type(), flex(before)], cvar, Box::new(flex(after))), + ], + Box::new(list_type(flex(after))), + ) + }); + // append : List elem, elem -> List elem add_type( Symbol::LIST_APPEND, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index e2854cd708..2865f31dff 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -78,6 +78,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_PREPEND => list_prepend, LIST_JOIN => list_join, LIST_MAP => list_map, + LIST_MAP_WITH_INDEX => list_map_with_index, LIST_KEEP_IF => list_keep_if, LIST_KEEP_OKS => list_keep_oks, LIST_KEEP_ERRS=> list_keep_errs, @@ -205,6 +206,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_PREPEND => list_prepend, Symbol::LIST_JOIN => list_join, Symbol::LIST_MAP => list_map, + Symbol::LIST_MAP_WITH_INDEX => list_map_with_index, Symbol::LIST_KEEP_IF => list_keep_if, Symbol::LIST_KEEP_OKS => list_keep_oks, Symbol::LIST_KEEP_ERRS=> list_keep_errs, @@ -1956,6 +1958,11 @@ fn list_map(symbol: Symbol, var_store: &mut VarStore) -> Def { lowlevel_2(symbol, LowLevel::ListMap, var_store) } +/// List.mapWithIndex : List before, (Nat, before -> after) -> List after +fn list_map_with_index(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListMapWithIndex, var_store) +} + /// Dict.hashTestOnly : k, v -> Nat pub fn dict_hash_test_only(symbol: Symbol, var_store: &mut VarStore) -> Def { lowlevel_2(symbol, LowLevel::Hash, var_store) diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 6f0bc4d3e9..ba1bdda4e7 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, - list_prepend, list_repeat, list_reverse, list_set, list_single, list_sum, list_walk, - list_walk_backwards, + list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, list_single, list_sum, + list_walk, list_walk_backwards, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_join_with, @@ -3642,6 +3642,24 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid list layout"), } } + ListMapWithIndex => { + // List.map : List before, (before -> after) -> List after + debug_assert_eq!(args.len(), 2); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => { + return empty_list(env); + } + Layout::Builtin(Builtin::List(_, element_layout)) => { + list_map_with_index(env, layout_ids, func, func_layout, list, element_layout) + } + _ => unreachable!("invalid list layout"), + } + } ListKeepIf => { // List.keepIf : List elem, (elem -> Bool) -> List elem debug_assert_eq!(args.len(), 2); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 2df8115e71..27de085980 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1169,6 +1169,49 @@ pub fn list_map<'a, 'ctx, 'env>( transform_layout: &Layout<'a>, list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_map_generic( + env, + layout_ids, + transform, + transform_layout, + list, + element_layout, + bitcode::LIST_MAP, + &[element_layout.clone()], + ) +} + +/// List.mapWithIndex : List before, (Nat, before -> after) -> List after +pub fn list_map_with_index<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_map_generic( + env, + layout_ids, + transform, + transform_layout, + list, + element_layout, + bitcode::LIST_MAP_WITH_INDEX, + &[Layout::Builtin(Builtin::Usize), element_layout.clone()], + ) +} + +fn list_map_generic<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, + op: &str, + argument_layouts: &[Layout<'a>], ) -> BasicValueEnum<'ctx> { let builder = env.builder; @@ -1186,7 +1229,7 @@ pub fn list_map<'a, 'ctx, 'env>( env.builder.build_store(transform_ptr, transform); let stepper_caller = - build_transform_caller(env, layout_ids, transform_layout, &[element_layout.clone()]) + build_transform_caller(env, layout_ids, transform_layout, argument_layouts) .as_global_value() .as_pointer_value(); @@ -1212,7 +1255,7 @@ pub fn list_map<'a, 'ctx, 'env>( old_element_width.into(), new_element_width.into(), ], - &bitcode::LIST_MAP, + op, ); complex_bitcast( diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index efbe56bc42..60adc5cf02 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -1729,6 +1729,15 @@ mod gen_list { assert_evals_to!("List.keepErrs [Ok 1, Err 2] (\\x -> x)", &[2], &[i64]); } + #[test] + fn list_map_with_index() { + assert_evals_to!( + "List.mapWithIndex [0,0,0] (\\index, x -> index + x)", + &[0, 1, 2], + &[i64] + ); + } + #[test] #[should_panic(expected = r#"Roc failed with message: "integer addition overflowed!"#)] fn cleanup_because_exception() { diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 0acb97d0d8..e69fa0dd02 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -25,6 +25,7 @@ pub enum LowLevel { ListPrepend, ListJoin, ListMap, + ListMapWithIndex, ListKeepIf, ListWalk, ListWalkBackwards, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 9380c1564b..25fb2cb0fb 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -898,6 +898,7 @@ define_builtins! { 20 LIST_LAST: "last" 21 LIST_KEEP_OKS: "keepOks" 22 LIST_KEEP_ERRS: "keepErrs" + 23 LIST_MAP_WITH_INDEX: "mapWithIndex" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 3783ab9a5b..47fd1c8c95 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -581,7 +581,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListPrepend => arena.alloc_slice_copy(&[owned, owned]), StrJoinWith => arena.alloc_slice_copy(&[irrelevant, irrelevant]), ListJoin => arena.alloc_slice_copy(&[irrelevant]), - ListMap => arena.alloc_slice_copy(&[owned, irrelevant]), + ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]), ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, irrelevant]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]),