diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index 139163af1d..285482e293 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -502,6 +502,42 @@ pub fn listWalkBackwards(list: RocList, stepper: Opaque, stepper_caller: Caller2 utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); } +pub fn listWalkUntil(list: RocList, stepper: Opaque, stepper_caller: Caller2, accum: Opaque, alignment: usize, element_width: usize, accum_width: usize, output: Opaque) callconv(.C) void { + if (accum_width == 0) { + return; + } + + if (list.isEmpty()) { + @memcpy(output orelse unreachable, accum orelse unreachable, accum_width); + return; + } + + const alloc: [*]u8 = @ptrCast([*]u8, std.heap.c_allocator.alloc(u8, accum_width) catch unreachable); + var b1 = output orelse unreachable; + var b2 = alloc; + + @memcpy(b2, accum orelse unreachable, accum_width); + + if (list.bytes) |source_ptr| { + var i: usize = 0; + const size = list.len(); + while (i < size) : (i += 1) { + const element = source_ptr + i * element_width; + stepper_caller(stepper, element, b2, b1); + + const temp = b1; + b2 = b1; + b1 = temp; + } + } + + @memcpy(output orelse unreachable, b2, accum_width); + std.heap.c_allocator.free(alloc[0..accum_width]); + + const data_bytes = list.len() * element_width; + utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); +} + // List.contains : List k, k -> Bool pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) callconv(.C) bool { if (list.bytes) |source_ptr| { diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 4c4b13d99a..a98fcd4dc9 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -771,6 +771,34 @@ pub fn types() -> MutMap { ), ); + fn until_type(content: SolvedType) -> SolvedType { + // [ LT, EQ, GT ] + SolvedType::TagUnion( + vec![ + (TagName::Global("Continue".into()), vec![content.clone()]), + (TagName::Global("Stop".into()), vec![content]), + ], + Box::new(SolvedType::EmptyTagUnion), + ) + } + + // walkUntil : List elem, (elem -> accum -> [ Continue accum, Stop accum ]), accum -> accum + add_type( + Symbol::LIST_WALK_UNTIL, + top_level_function( + vec![ + list_type(flex(TVAR1)), + closure( + vec![flex(TVAR1), flex(TVAR2)], + TVAR3, + Box::new(until_type(flex(TVAR2))), + ), + flex(TVAR2), + ], + Box::new(flex(TVAR2)), + ), + ); + // keepIf : List elem, (elem -> Bool) -> List elem add_type( Symbol::LIST_KEEP_IF, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 334fee9cda..50ac3a1100 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -89,6 +89,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_KEEP_ERRS=> list_keep_errs, LIST_WALK => list_walk, LIST_WALK_BACKWARDS => list_walk_backwards, + LIST_WALK_UNTIL => list_walk_until, DICT_TEST_HASH => dict_hash_test_only, DICT_LEN => dict_len, DICT_EMPTY => dict_empty, @@ -231,6 +232,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_KEEP_ERRS=> list_keep_errs, Symbol::LIST_WALK => list_walk, Symbol::LIST_WALK_BACKWARDS => list_walk_backwards, + Symbol::LIST_WALK_UNTIL => list_walk_until, Symbol::DICT_TEST_HASH => dict_hash_test_only, Symbol::DICT_LEN => dict_len, Symbol::DICT_EMPTY => dict_empty, @@ -2094,60 +2096,17 @@ fn list_join(symbol: Symbol, var_store: &mut VarStore) -> Def { /// List.walk : List elem, (elem -> accum -> accum), accum -> accum fn list_walk(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let func_var = var_store.fresh(); - let accum_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::ListWalk, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (func_var, Var(Symbol::ARG_2)), - (accum_var, Var(Symbol::ARG_3)), - ], - ret_var: accum_var, - }; - - defn( - symbol, - vec![ - (list_var, Symbol::ARG_1), - (func_var, Symbol::ARG_2), - (accum_var, Symbol::ARG_3), - ], - var_store, - body, - accum_var, - ) + lowlevel_3(symbol, LowLevel::ListWalk, var_store) } /// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum fn list_walk_backwards(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let func_var = var_store.fresh(); - let accum_var = var_store.fresh(); + lowlevel_3(symbol, LowLevel::ListWalkBackwards, var_store) +} - let body = RunLowLevel { - op: LowLevel::ListWalkBackwards, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (func_var, Var(Symbol::ARG_2)), - (accum_var, Var(Symbol::ARG_3)), - ], - ret_var: accum_var, - }; - - defn( - symbol, - vec![ - (list_var, Symbol::ARG_1), - (func_var, Symbol::ARG_2), - (accum_var, Symbol::ARG_3), - ], - var_store, - body, - accum_var, - ) +/// List.walkUntil : List elem, (elem, accum -> [ Continue accum, Stop accum ]), accum -> accum +fn list_walk_until(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_3(symbol, LowLevel::ListWalkUntil, var_store) } /// List.sum : List (Num a) -> Num a diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 4dfda5e9fa..65e7ddb604 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -8,7 +8,7 @@ use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, list_map2, list_map3, list_map_with_index, list_prepend, list_product, list_repeat, - list_reverse, list_set, list_single, list_sum, list_walk, list_walk_backwards, + list_reverse, list_set, list_single, list_sum, list_walk_help, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, @@ -3879,57 +3879,30 @@ fn run_low_level<'a, 'ctx, 'env>( list_contains(env, layout_ids, elem, elem_layout, list) } - ListWalk => { - debug_assert_eq!(args.len(), 3); - - let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); - - let (default, default_layout) = load_symbol_and_layout(scope, &args[2]); - - match list_layout { - Layout::Builtin(Builtin::EmptyList) => default, - Layout::Builtin(Builtin::List(_, element_layout)) => list_walk( - env, - layout_ids, - parent, - list, - element_layout, - func, - func_layout, - default, - default_layout, - ), - _ => unreachable!("invalid list layout"), - } - } - ListWalkBackwards => { - // List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum - debug_assert_eq!(args.len(), 3); - - let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - - let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); - - let (default, default_layout) = load_symbol_and_layout(scope, &args[2]); - - match list_layout { - Layout::Builtin(Builtin::EmptyList) => default, - Layout::Builtin(Builtin::List(_, element_layout)) => list_walk_backwards( - env, - layout_ids, - parent, - list, - element_layout, - func, - func_layout, - default, - default_layout, - ), - _ => unreachable!("invalid list layout"), - } - } + ListWalk => list_walk_help( + env, + layout_ids, + scope, + parent, + args, + crate::llvm::build_list::ListWalk::Walk, + ), + ListWalkUntil => list_walk_help( + env, + layout_ids, + scope, + parent, + args, + crate::llvm::build_list::ListWalk::WalkUntil, + ), + ListWalkBackwards => list_walk_help( + env, + layout_ids, + scope, + parent, + args, + crate::llvm::build_list::ListWalk::WalkBackwards, + ), ListSum => { debug_assert_eq!(args.len(), 1); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 40321129db..7ef207e816 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -863,56 +863,54 @@ pub fn list_product<'a, 'ctx, 'env>( builder.build_load(accum_alloca, "load_final_acum") } -/// List.walk : List elem, (elem -> accum -> accum), accum -> accum -pub fn list_walk<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - parent: FunctionValue<'ctx>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, - func: BasicValueEnum<'ctx>, - func_layout: &Layout<'a>, - default: BasicValueEnum<'ctx>, - default_layout: &Layout<'a>, -) -> BasicValueEnum<'ctx> { - list_walk_generic( - env, - layout_ids, - parent, - list, - element_layout, - func, - func_layout, - default, - default_layout, - &bitcode::LIST_WALK, - ) +pub enum ListWalk { + Walk, + WalkBackwards, + WalkUntil, + WalkBackwardsUntil, } -/// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum -pub fn list_walk_backwards<'a, 'ctx, 'env>( +pub fn list_walk_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + scope: &crate::llvm::build::Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, - list: BasicValueEnum<'ctx>, - element_layout: &Layout<'a>, - func: BasicValueEnum<'ctx>, - func_layout: &Layout<'a>, - default: BasicValueEnum<'ctx>, - default_layout: &Layout<'a>, + args: &[roc_module::symbol::Symbol], + variant: ListWalk, ) -> BasicValueEnum<'ctx> { - list_walk_generic( - env, - layout_ids, - parent, - list, - element_layout, - func, - func_layout, - default, - default_layout, - &bitcode::LIST_WALK_BACKWARDS, - ) + use crate::llvm::build::load_symbol_and_layout; + + debug_assert_eq!(args.len(), 3); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + let (default, default_layout) = load_symbol_and_layout(scope, &args[2]); + + let bitcode_fn = match variant { + ListWalk::Walk => bitcode::LIST_WALK, + ListWalk::WalkBackwards => bitcode::LIST_WALK_BACKWARDS, + ListWalk::WalkUntil => todo!(), + ListWalk::WalkBackwardsUntil => todo!(), + }; + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => default, + Layout::Builtin(Builtin::List(_, element_layout)) => list_walk_generic( + env, + layout_ids, + parent, + list, + element_layout, + func, + func_layout, + default, + default_layout, + &bitcode_fn, + ), + _ => unreachable!("invalid list layout"), + } } fn list_walk_generic<'a, 'ctx, 'env>( diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 3922bb8e6d..24f9f815e2 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -32,6 +32,7 @@ pub enum LowLevel { ListMapWithIndex, ListKeepIf, ListWalk, + ListWalkUntil, ListWalkBackwards, ListSum, ListProduct, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 85806a99c5..1332e0487f 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -915,6 +915,7 @@ define_builtins! { 24 LIST_MAP2: "map2" 25 LIST_MAP3: "map3" 26 LIST_PRODUCT: "product" + 27 LIST_WALK_UNTIL: "walkUntil" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index f48c787194..6545158512 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -655,8 +655,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, irrelevant]), ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), - ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), - ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]), + ListWalk | ListWalkUntil | ListWalkBackwards => { + arena.alloc_slice_copy(&[owned, irrelevant, owned]) + } ListSum | ListProduct => arena.alloc_slice_copy(&[borrowed]), // TODO when we have lists with capacity (if ever)