diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index 7c399f1b55..eaa80ed7f8 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -1,5 +1,6 @@ const std = @import("std"); const utils = @import("utils.zig"); +const RocResult = utils.RocResult; const mem = std.mem; const Allocator = mem.Allocator; @@ -152,7 +153,48 @@ pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: output.length = kept; - // utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width); + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * element_width); + + return output; + } else { + return RocList.empty(); + } +} + +pub fn listKeepOks(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList { + return listKeepResult(list, RocResult.isOk, transform, caller, alignment, before_width, result_width, after_width); +} + +pub fn listKeepErrs(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList { + return listKeepResult(list, RocResult.isErr, transform, caller, alignment, before_width, result_width, after_width); +} + +pub fn listKeepResult(list: RocList, is_good_constructor: fn (RocResult) bool, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + var output = RocList.allocate(std.heap.c_allocator, alignment, list.len(), list.len() * after_width); + const target_ptr = output.bytes orelse unreachable; + + var temporary = @ptrCast([*]u8, std.heap.c_allocator.alloc(u8, result_width) catch unreachable); + + var kept: usize = 0; + while (i < size) : (i += 1) { + const element = source_ptr + (i * before_width); + caller(transform, element, temporary); + + const result = utils.RocResult{ .bytes = temporary }; + + if (is_good_constructor(result)) { + @memcpy(target_ptr + (kept * after_width), temporary + @sizeOf(i64), after_width); + + kept += 1; + } + } + + output.length = kept; + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * before_width); return output; } else { diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index 000b1a4f00..43eda13192 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -10,6 +10,8 @@ comptime { exportListFn(list.listKeepIf, "keep_if"); exportListFn(list.listWalk, "walk"); exportListFn(list.listWalkBackwards, "walk_backwards"); + exportListFn(list.listKeepOks, "keep_oks"); + exportListFn(list.listKeepErrs, "keep_errs"); } // Dict Module diff --git a/compiler/builtins/bitcode/src/utils.zig b/compiler/builtins/bitcode/src/utils.zig index 9479fab7ce..e95f3ab711 100644 --- a/compiler/builtins/bitcode/src/utils.zig +++ b/compiler/builtins/bitcode/src/utils.zig @@ -86,3 +86,22 @@ pub fn allocateWithRefcount( }, } } + +pub const RocResult = extern struct { + bytes: ?[*]u8, + + pub fn isOk(self: RocResult) bool { + // assumptions + // + // - the tag is the first field + // - the tag is usize bytes wide + // - Ok has tag_id 1, because Err < Ok + const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, self.bytes)); + + return usizes[0] == 1; + } + + pub fn isErr(self: RocResult) bool { + return !self.isOk(); + } +}; diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 08b33a8d48..a0d3cc6f7f 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -61,5 +61,7 @@ pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; pub const LIST_MAP: &str = "roc_builtins.list.map"; pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if"; +pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks"; +pub const LIST_KEEP_ERRS: &str = "roc_builtins.list.keep_errs"; pub const LIST_WALK: &str = "roc_builtins.list.walk"; pub const LIST_WALK_BACKWARDS: &str = "roc_builtins.list.walk_backwards"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 12dac61b6f..3dcf93f221 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -10,6 +10,30 @@ use roc_types::solved_types::SolvedType; use roc_types::subs::VarId; use std::collections::HashMap; +/// Example: +/// +/// let_tvars! { a, b, c } +/// +/// This is equivalent to: +/// +/// let a = VarId::from_u32(1); +/// let b = VarId::from_u32(2); +/// let c = VarId::from_u32(3); +/// +/// The idea is that this is less error-prone than assigning hardcoded IDs by hand. +macro_rules! let_tvars { + ($($name:ident,)+) => { let_tvars!($($name),+) }; + ($($name:ident),*) => { + let mut _current_tvar = 0; + + $( + _current_tvar += 1; + + let $name = VarId::from_u32(_current_tvar); + )* + }; +} + #[derive(Clone, Copy, Debug)] pub enum Mode { Standard, @@ -658,6 +682,38 @@ pub fn types() -> MutMap { ), ); + // keepOks : List before, (before -> Result after *) -> List after + add_type(Symbol::LIST_KEEP_OKS, { + let_tvars! { star, cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure( + vec![flex(before)], + cvar, + Box::new(result_type(flex(after), flex(star))), + ), + ], + Box::new(list_type(flex(after))), + ) + }); + + // keepOks : List before, (before -> Result * after) -> List after + add_type(Symbol::LIST_KEEP_ERRS, { + let_tvars! { star, cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure( + vec![flex(before)], + cvar, + Box::new(result_type(flex(star), flex(after))), + ), + ], + Box::new(list_type(flex(after))), + ) + }); + // map : List before, (before -> after) -> List after add_type( Symbol::LIST_MAP, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 111a495663..e2854cd708 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -79,6 +79,8 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_JOIN => list_join, LIST_MAP => list_map, LIST_KEEP_IF => list_keep_if, + LIST_KEEP_OKS => list_keep_oks, + LIST_KEEP_ERRS=> list_keep_errs, LIST_WALK => list_walk, LIST_WALK_BACKWARDS => list_walk_backwards, DICT_TEST_HASH => dict_hash_test_only, @@ -204,6 +206,8 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_JOIN => list_join, Symbol::LIST_MAP => list_map, Symbol::LIST_KEEP_IF => list_keep_if, + Symbol::LIST_KEEP_OKS => list_keep_oks, + Symbol::LIST_KEEP_ERRS=> list_keep_errs, Symbol::LIST_WALK => list_walk, Symbol::LIST_WALK_BACKWARDS => list_walk_backwards, Symbol::DICT_TEST_HASH => dict_hash_test_only, @@ -1934,50 +1938,22 @@ fn list_keep_if(symbol: Symbol, var_store: &mut VarStore) -> Def { /// List.contains : List elem, elem -> Bool fn list_contains(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let elem_var = var_store.fresh(); - let bool_var = var_store.fresh(); + lowlevel_2(symbol, LowLevel::ListContains, var_store) +} - let body = RunLowLevel { - op: LowLevel::ListContains, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (elem_var, Var(Symbol::ARG_2)), - ], - ret_var: bool_var, - }; +/// List.keepOks : List before, (before -> Result after *) -> List after +fn list_keep_oks(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListKeepOks, var_store) +} - defn( - symbol, - vec![(list_var, Symbol::ARG_1), (elem_var, Symbol::ARG_2)], - var_store, - body, - bool_var, - ) +/// List.keepErrs: List before, (before -> Result * after) -> List after +fn list_keep_errs(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListKeepErrs, var_store) } /// List.map : List before, (before -> after) -> List after fn list_map(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let func_var = var_store.fresh(); - let ret_list_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::ListMap, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (func_var, Var(Symbol::ARG_2)), - ], - ret_var: ret_list_var, - }; - - defn( - symbol, - vec![(list_var, Symbol::ARG_1), (func_var, Symbol::ARG_2)], - var_store, - body, - ret_list_var, - ) + lowlevel_2(symbol, LowLevel::ListMap, var_store) } /// Dict.hashTestOnly : k, v -> Nat diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index bc1fa2c2ae..4f91a33cc7 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -5,8 +5,9 @@ use crate::llvm::build_dict::{ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, - list_get_unsafe, list_join, list_keep_if, list_len, list_map, list_prepend, list_repeat, - list_reverse, list_set, list_single, list_sum, list_walk, list_walk_backwards, + list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, + list_prepend, list_repeat, list_reverse, list_set, list_single, list_sum, list_walk, + list_walk_backwards, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_join_with, @@ -3648,8 +3649,6 @@ fn run_low_level<'a, 'ctx, 'env>( let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); - let inplace = get_inplace_from_layout(layout); - match list_layout { Layout::Builtin(Builtin::EmptyList) => { return empty_list(env); @@ -3660,6 +3659,66 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid list layout"), } } + ListKeepOks => { + // List.keepOks : List before, (before -> Result after *) -> List after + debug_assert_eq!(args.len(), 2); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match (list_layout, layout) { + (_, Layout::Builtin(Builtin::EmptyList)) + | (Layout::Builtin(Builtin::EmptyList), _) => { + return empty_list(env); + } + ( + Layout::Builtin(Builtin::List(_, before_layout)), + Layout::Builtin(Builtin::List(_, after_layout)), + ) => list_keep_oks( + env, + layout_ids, + func, + func_layout, + list, + before_layout, + after_layout, + ), + (other1, other2) => { + unreachable!("invalid list layouts:\n{:?}\n{:?}", other1, other2) + } + } + } + ListKeepErrs => { + // List.keepErrs : List before, (before -> Result * after) -> List after + debug_assert_eq!(args.len(), 2); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match (list_layout, layout) { + (_, Layout::Builtin(Builtin::EmptyList)) + | (Layout::Builtin(Builtin::EmptyList), _) => { + return empty_list(env); + } + ( + Layout::Builtin(Builtin::List(_, before_layout)), + Layout::Builtin(Builtin::List(_, after_layout)), + ) => list_keep_errs( + env, + layout_ids, + func, + func_layout, + list, + before_layout, + after_layout, + ), + (other1, other2) => { + unreachable!("invalid list layouts:\n{:?}\n{:?}", other1, other2) + } + } + } ListContains => { // List.contains : List elem, elem -> Bool debug_assert_eq!(args.len(), 2); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 05ef4237e0..0d78065f56 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1139,6 +1139,121 @@ pub fn list_keep_if<'a, 'ctx, 'env>( ) } +/// List.keepOks : List before, (before -> Result after *) -> List after +#[allow(clippy::too_many_arguments)] +pub fn list_keep_oks<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_keep_result( + env, + layout_ids, + transform, + transform_layout, + list, + before_layout, + after_layout, + bitcode::LIST_KEEP_OKS, + ) +} + +/// List.keepErrs : List before, (before -> Result * after) -> List after +#[allow(clippy::too_many_arguments)] +pub fn list_keep_errs<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_keep_result( + env, + layout_ids, + transform, + transform_layout, + list, + before_layout, + after_layout, + bitcode::LIST_KEEP_ERRS, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn list_keep_result<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, + op: &str, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let result_layout = match transform_layout { + Layout::FunctionPointer(_, ret) => ret, + Layout::Closure(_, _, ret) => ret, + _ => unreachable!("not a callable layout"), + }; + + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); + + let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, transform); + + let stepper_caller = + build_transform_caller(env, layout_ids, transform_layout, &[before_layout.clone()]) + .as_global_value() + .as_pointer_value(); + + let before_width = env + .ptr_int() + .const_int(before_layout.stack_size(env.ptr_bytes) as u64, false); + + let after_width = env + .ptr_int() + .const_int(after_layout.stack_size(env.ptr_bytes) as u64, false); + + let result_width = env + .ptr_int() + .const_int(result_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = before_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); + + let output = call_bitcode_fn( + env, + &[ + list_i128.into(), + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + alignment_iv.into(), + before_width.into(), + result_width.into(), + after_width.into(), + ], + op, + ); + + complex_bitcast( + env.builder, + output, + collection(env.context, env.ptr_bytes).into(), + "from_i128", + ) +} + /// List.map : List before, (before -> after) -> List after #[allow(clippy::too_many_arguments)] pub fn list_map<'a, 'ctx, 'env>( diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index 88c72bdac1..efbe56bc42 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -1709,6 +1709,26 @@ mod gen_list { assert_evals_to!("List.sum [ 1.1, 2.2, 3.3 ]", 6.6, f64); } + #[test] + fn list_keep_oks() { + assert_evals_to!("List.keepOks [] (\\x -> x)", 0, i64); + assert_evals_to!("List.keepOks [1,2] (\\x -> Ok x)", &[1, 2], &[i64]); + assert_evals_to!("List.keepOks [1,2] (\\x -> x % 2)", &[1, 0], &[i64]); + assert_evals_to!("List.keepOks [Ok 1, Err 2] (\\x -> x)", &[1], &[i64]); + } + + #[test] + fn list_keep_errs() { + assert_evals_to!("List.keepErrs [] (\\x -> x)", 0, i64); + assert_evals_to!("List.keepErrs [1,2] (\\x -> Err x)", &[1, 2], &[i64]); + assert_evals_to!( + "List.keepErrs [0,1,2] (\\x -> x % 0 |> Result.mapErr (\\_ -> 32))", + &[32, 32, 32], + &[i64] + ); + assert_evals_to!("List.keepErrs [Ok 1, Err 2] (\\x -> x)", &[2], &[i64]); + } + #[test] #[should_panic(expected = r#"Roc failed with message: "integer addition overflowed!"#)] fn cleanup_because_exception() { diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 1b715b8991..0acb97d0d8 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -29,6 +29,8 @@ pub enum LowLevel { ListWalk, ListWalkBackwards, ListSum, + ListKeepOks, + ListKeepErrs, DictSize, DictEmpty, DictInsert, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index ab65b391a5..9380c1564b 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -896,6 +896,8 @@ define_builtins! { 18 LIST_SUM: "sum" 19 LIST_WALK: "walk" 20 LIST_LAST: "last" + 21 LIST_KEEP_OKS: "keepOks" + 22 LIST_KEEP_ERRS: "keepErrs" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index e089f63b2d..3783ab9a5b 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -582,7 +582,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { StrJoinWith => arena.alloc_slice_copy(&[irrelevant, irrelevant]), ListJoin => arena.alloc_slice_copy(&[irrelevant]), ListMap => arena.alloc_slice_copy(&[owned, irrelevant]), - ListKeepIf => arena.alloc_slice_copy(&[owned, irrelevant]), + ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, irrelevant]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]),