diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index ffdbfcba2e..906ffa4456 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -1361,3 +1361,39 @@ inline fn listSetImmutable( //return list; return new_bytes; } + +pub fn listFindUnsafe( + list: RocList, + caller: Caller1, + data: Opaque, + inc_n_data: IncN, + data_is_owned: bool, + alignment: u32, + element_width: usize, + inc: Inc, + dec: Dec, +) callconv(.C) extern struct { value: Opaque, found: bool } { + if (list.bytes) |source_ptr| { + const size = list.len(); + if (data_is_owned) { + inc_n_data(data, size); + } + + var i: usize = 0; + while (i < size) : (i += 1) { + var theOne = false; + const element = source_ptr + (i * element_width); + inc(element); + caller(data, element, @ptrCast(?[*]u8, &theOne)); + + if (theOne) { + return .{ .value = element, .found = true }; + } else { + dec(element); + } + } + return .{ .value = null, .found = false }; + } else { + return .{ .value = null, .found = false }; + } +} diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index ba063b0702..b8205c7cfd 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -53,6 +53,7 @@ comptime { exportListFn(list.listSetInPlace, "set_in_place"); exportListFn(list.listSwap, "swap"); exportListFn(list.listAny, "any"); + exportListFn(list.listFindUnsafe, "find_unsafe"); } // Dict Module diff --git a/compiler/builtins/docs/List.roc b/compiler/builtins/docs/List.roc index 17530aba37..be4f8891c1 100644 --- a/compiler/builtins/docs/List.roc +++ b/compiler/builtins/docs/List.roc @@ -690,3 +690,7 @@ all : List elem, (elem -> Bool) -> Bool ## Run the given predicate on each element of the list, returning `True` if ## any of the elements satisfy it. any : List elem, (elem -> Bool) -> Bool + +## Returns the first element of the list satisfying a predicate function. +## If no satisfying element is found, an `Err NotFound` is returned. +find : List elem, (elem -> Bool) -> Result elem [ NotFound ]* diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 1ee6c97133..45b297143f 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -191,6 +191,7 @@ pub const LIST_CONCAT: &str = "roc_builtins.list.concat"; pub const LIST_SET: &str = "roc_builtins.list.set"; pub const LIST_SET_IN_PLACE: &str = "roc_builtins.list.set_in_place"; pub const LIST_ANY: &str = "roc_builtins.list.any"; +pub const LIST_FIND_UNSAFE: &str = "roc_builtins.list.find_unsafe"; pub const DEC_FROM_F64: &str = "roc_builtins.dec.from_f64"; pub const DEC_EQ: &str = "roc_builtins.dec.eq"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 700609d936..e58b83639d 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -1093,6 +1093,23 @@ pub fn types() -> MutMap { Box::new(list_type(flex(TVAR1))), ); + // find : List elem, (elem -> Bool) -> Result elem [ NotFound ]* + { + let not_found = SolvedType::TagUnion( + vec![(TagName::Global("NotFound".into()), vec![])], + Box::new(SolvedType::Wildcard), + ); + let (elem, cvar) = (TVAR1, TVAR2); + add_top_level_function_type!( + Symbol::LIST_FIND, + vec![ + list_type(flex(elem)), + closure(vec![flex(elem)], cvar, Box::new(bool_type())), + ], + Box::new(result_type(flex(elem), not_found)), + ) + } + // Dict module // len : Dict * * -> Nat diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 2a3cebae1e..1894f09284 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -108,6 +108,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_WALK_UNTIL => list_walk_until, LIST_SORT_WITH => list_sort_with, LIST_ANY => list_any, + LIST_FIND => list_find, DICT_LEN => dict_len, DICT_EMPTY => dict_empty, DICT_SINGLE => dict_single, @@ -2748,6 +2749,87 @@ fn list_any(symbol: Symbol, var_store: &mut VarStore) -> Def { lowlevel_2(symbol, LowLevel::ListAny, var_store) } +/// List.find : List elem, (elem -> Bool) -> Result elem [ NotFound ]* +fn list_find(symbol: Symbol, var_store: &mut VarStore) -> Def { + let list = Symbol::ARG_1; + let find_predicate = Symbol::ARG_2; + + let find_result = Symbol::LIST_FIND_RESULT; + + let t_list = var_store.fresh(); + let t_pred_fn = var_store.fresh(); + let t_bool = var_store.fresh(); + let t_found = var_store.fresh(); + let t_value = var_store.fresh(); + let t_ret = var_store.fresh(); + let t_find_result = var_store.fresh(); + let t_ext_var1 = var_store.fresh(); + let t_ext_var2 = var_store.fresh(); + + // ListFindUnsafe returns { value: elem, found: Bool }. + // When `found` is true, the value was found. Otherwise `List.find` should return `Err ...` + let find_result_def = Def { + annotation: None, + expr_var: t_find_result, + loc_expr: no_region(RunLowLevel { + op: LowLevel::ListFindUnsafe, + args: vec![(t_list, Var(list)), (t_pred_fn, Var(find_predicate))], + ret_var: t_find_result, + }), + loc_pattern: no_region(Pattern::Identifier(find_result)), + pattern_vars: Default::default(), + }; + + let get_value = Access { + record_var: t_find_result, + ext_var: t_ext_var1, + field_var: t_value, + loc_expr: Box::new(no_region(Var(find_result))), + field: "value".into(), + }; + + let get_found = Access { + record_var: t_find_result, + ext_var: t_ext_var2, + field_var: t_found, + loc_expr: Box::new(no_region(Var(find_result))), + field: "found".into(), + }; + + let make_ok = tag("Ok", vec![get_value], var_store); + + let make_err = tag( + "Err", + vec![tag("NotFound", Vec::new(), var_store)], + var_store, + ); + + let inspect = If { + cond_var: t_bool, + branch_var: t_ret, + branches: vec![( + // if-condition + no_region(get_found), + no_region(make_ok), + )], + final_else: Box::new(no_region(make_err)), + }; + + let body = LetNonRec( + Box::new(find_result_def), + Box::new(no_region(inspect)), + t_ret, + ); + + defn( + symbol, + vec![(t_list, Symbol::ARG_1), (t_pred_fn, Symbol::ARG_2)], + var_store, + body, + t_ret, + ) +} + /// Dict.len : Dict * * -> Nat fn dict_len(symbol: Symbol, var_store: &mut VarStore) -> Def { let arg1_var = var_store.fresh(); diff --git a/compiler/gen_llvm/src/llvm/build.rs b/compiler/gen_llvm/src/llvm/build.rs index c5221f2890..5e30070c72 100644 --- a/compiler/gen_llvm/src/llvm/build.rs +++ b/compiler/gen_llvm/src/llvm/build.rs @@ -9,10 +9,11 @@ use crate::llvm::build_dict::{ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ self, allocate_list, empty_list, empty_polymorphic_list, list_any, list_append, list_concat, - list_contains, list_drop, list_drop_at, list_get_unsafe, list_join, list_keep_errs, - list_keep_if, list_keep_oks, list_len, list_map, list_map2, list_map3, list_map4, - list_map_with_index, list_prepend, list_range, list_repeat, list_reverse, list_set, - list_single, list_sort_with, list_swap, list_take_first, list_take_last, + list_contains, list_drop, list_drop_at, list_find_trivial_not_found, list_find_unsafe, + list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, + list_map2, list_map3, list_map4, list_map_with_index, list_prepend, list_range, list_repeat, + list_reverse, list_set, list_single, list_sort_with, list_swap, list_take_first, + list_take_last, }; use crate::llvm::build_str::{ empty_str, str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, @@ -4887,6 +4888,37 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid list layout"), } } + ListFindUnsafe { xs } => { + let (list, list_layout) = load_symbol_and_layout(scope, &xs); + + let (function, closure, closure_layout) = function_details!(); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => { + // Returns { found: False, elem: \empty }, where the `elem` field is zero-sized. + // NB: currently we never hit this case, since the only caller of this + // lowlevel, namely List.find, will fail during monomorphization when there is no + // concrete list element type. This is because List.find returns a + // `Result elem [ NotFound ]*`, and we can't figure out the size of that if + // `elem` is not concrete. + list_find_trivial_not_found(env) + } + Layout::Builtin(Builtin::List(element_layout)) => { + let argument_layouts = &[**element_layout]; + let roc_function_call = roc_function_call( + env, + layout_ids, + function, + closure, + closure_layout, + function_owns_closure_data, + argument_layouts, + ); + list_find_unsafe(env, layout_ids, roc_function_call, list, element_layout) + } + _ => unreachable!("invalid list layout"), + } + } DictWalk { xs, state } => { let (dict, dict_layout) = load_symbol_and_layout(scope, &xs); let (default, default_layout) = load_symbol_and_layout(scope, &state); @@ -5778,7 +5810,9 @@ fn run_low_level<'a, 'ctx, 'env>( ListMap | ListMap2 | ListMap3 | ListMap4 | ListMapWithIndex | ListKeepIf | ListWalk | ListWalkUntil | ListWalkBackwards | ListKeepOks | ListKeepErrs | ListSortWith - | ListAny | DictWalk => unreachable!("these are higher order, and are handled elsewhere"), + | ListAny | ListFindUnsafe | DictWalk => { + unreachable!("these are higher order, and are handled elsewhere") + } } } diff --git a/compiler/gen_llvm/src/llvm/build_list.rs b/compiler/gen_llvm/src/llvm/build_list.rs index e3f0b00ff4..7112c4b19a 100644 --- a/compiler/gen_llvm/src/llvm/build_list.rs +++ b/compiler/gen_llvm/src/llvm/build_list.rs @@ -958,6 +958,123 @@ pub fn list_any<'a, 'ctx, 'env>( ) } +/// List.findUnsafe : List elem, (elem -> Bool) -> { value: elem, found: bool } +pub fn list_find_unsafe<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + roc_function_call: RocFunctionCall<'ctx>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout); + let dec_element_fn = build_dec_wrapper(env, layout_ids, element_layout); + + // { value: *const u8, found: bool } + let result = call_bitcode_fn( + env, + &[ + pass_list_cc(env, list), + roc_function_call.caller.into(), + pass_as_opaque(env, roc_function_call.data), + roc_function_call.inc_n_data.into(), + roc_function_call.data_is_owned.into(), + env.alignment_intvalue(element_layout), + layout_width(env, element_layout), + inc_element_fn.as_global_value().as_pointer_value().into(), + dec_element_fn.as_global_value().as_pointer_value().into(), + ], + bitcode::LIST_FIND_UNSAFE, + ) + .into_struct_value(); + + // We promised the caller we'd give them back a struct containing the element + // loaded on the stack, so we do that now. The element can't be loaded directly + // in the Zig definition called above, because we don't know the size of the + // element until user compile time, which is later than the compile time of bitcode defs. + + let value_u8_ptr = env + .builder + .build_extract_value(result, 0, "get_value_ptr") + .unwrap() + .into_pointer_value(); + + let found = env + .builder + .build_extract_value(result, 1, "get_found") + .unwrap() + .into_int_value(); + + let start_block = env.builder.get_insert_block().unwrap(); + let parent = start_block.get_parent().unwrap(); + + let if_not_null = env.context.append_basic_block(parent, "if_not_null"); + let done_block = env.context.append_basic_block(parent, "done"); + + let value_bt = basic_type_from_layout(env, element_layout); + let default = value_bt.const_zero(); + + env.builder + .build_conditional_branch(found, if_not_null, done_block); + + env.builder.position_at_end(if_not_null); + let value_ptr = env + .builder + .build_bitcast( + value_u8_ptr, + value_bt.ptr_type(AddressSpace::Generic), + "from_opaque", + ) + .into_pointer_value(); + let loaded = env.builder.build_load(value_ptr, "load_value"); + env.builder.build_unconditional_branch(done_block); + + env.builder.position_at_end(done_block); + let result_phi = env.builder.build_phi(value_bt, "result"); + + result_phi.add_incoming(&[(&default, start_block), (&loaded, if_not_null)]); + + let value = result_phi.as_basic_value(); + + let result = env + .context + .struct_type(&[value_bt, env.context.bool_type().into()], false) + .const_zero(); + + let result = env + .builder + .build_insert_value(result, value, 0, "insert_value") + .unwrap(); + + env.builder + .build_insert_value(result, found, 1, "insert_found") + .unwrap() + .into_struct_value() + .into() +} + +/// Returns { value: \empty, found: False }, representing that no element was found in a call +/// to List.find when the layout of the element is also unknown. +pub fn list_find_trivial_not_found<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, +) -> BasicValueEnum<'ctx> { + let empty_type = env.context.custom_width_int_type(0); + let result = env + .context + .struct_type(&[empty_type.into(), env.context.bool_type().into()], false) + .const_zero(); + + env.builder + .build_insert_value( + result, + env.context.bool_type().const_zero(), + 1, + "insert_found", + ) + .unwrap() + .into_struct_value() + .into() +} + pub fn decrementing_elem_loop<'ctx, LoopFn>( builder: &Builder<'ctx>, ctx: &'ctx Context, diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 0c0be62ed9..682ff9c7e2 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -48,6 +48,7 @@ pub enum LowLevel { ListDropAt, ListSwap, ListAny, + ListFindUnsafe, DictSize, DictEmpty, DictInsert, @@ -227,6 +228,7 @@ macro_rules! higher_order { | ListKeepErrs | ListSortWith | ListAny + | ListFindUnsafe | DictWalk }; } @@ -261,6 +263,7 @@ impl LowLevel { ListKeepErrs => 1, ListSortWith => 1, ListAny => 1, + ListFindUnsafe => 1, DictWalk => 2, } } diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 5306e59b93..1cc1cd15a4 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -1067,6 +1067,8 @@ define_builtins! { 44 LIST_ANY: "any" 45 LIST_TAKE_FIRST: "takeFirst" 46 LIST_TAKE_LAST: "takeLast" + 47 LIST_FIND: "find" + 48 LIST_FIND_RESULT: "#find_result" // symbol used in the definition of List.find } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/alias_analysis.rs b/compiler/mono/src/alias_analysis.rs index 51cae3886f..bd90890392 100644 --- a/compiler/mono/src/alias_analysis.rs +++ b/compiler/mono/src/alias_analysis.rs @@ -1093,6 +1093,41 @@ fn call_spec( add_loop(builder, block, state_type, init_state, loop_body) } + ListFindUnsafe { xs } => { + let list = env.symbols[xs]; + + // ListFindUnsafe returns { value: v, found: Bool=Int1 } + let output_layouts = vec![arg_layouts[0], Layout::Builtin(Builtin::Int1)]; + let output_layout = Layout::Struct(&output_layouts); + let output_type = layout_spec(builder, &output_layout)?; + + let loop_body = |builder: &mut FuncDefBuilder, block, output| { + let bag = builder.add_get_tuple_field(block, list, LIST_BAG_INDEX)?; + let element = builder.add_bag_get(block, bag)?; + let _is_found = call_function!(builder, block, [element]); + + // We may or may not use the element we got from the list in the output struct, + // depending on whether we found the element to satisfy the "find" predicate. + // If we did find the element, our output "changes" to be a record including that element. + let found_branch = builder.add_block(); + let new_output = + builder.add_unknown_with(block, &[element], output_type)?; + + let not_found_branch = builder.add_block(); + + builder.add_choice( + block, + &[ + BlockExpr(found_branch, new_output), + BlockExpr(not_found_branch, output), + ], + ) + }; + + // Assume the output is initially { found: False, value: \empty } + let output_state = builder.add_unknown_with(block, &[], output_type)?; + add_loop(builder, block, output_type, output_state, loop_body) + } } } } diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index ab6fccb595..a9465d54eb 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -618,7 +618,8 @@ impl<'a> BorrowInfState<'a> { | ListKeepIf { xs } | ListKeepOks { xs } | ListKeepErrs { xs } - | ListAny { xs } => { + | ListAny { xs } + | ListFindUnsafe { xs } => { // own the list if the function wants to own the element if !function_ps[0].borrow { self.own_var(*xs); @@ -959,6 +960,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { arena.alloc_slice_copy(&[owned, owned, function, closure_data]) } ListSortWith => arena.alloc_slice_copy(&[owned, function, closure_data]), + ListFindUnsafe => arena.alloc_slice_copy(&[owned, function, closure_data]), // TODO when we have lists with capacity (if ever) // List.append should own its first argument diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index f33c38f3ac..d377795253 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -531,7 +531,8 @@ impl<'a> Context<'a> { | ListKeepIf { xs } | ListKeepOks { xs } | ListKeepErrs { xs } - | ListAny { xs } => { + | ListAny { xs } + | ListFindUnsafe { xs } => { let borrows = [function_ps[0].borrow, FUNCTION, CLOSURE_DATA]; let b = self.add_dec_after_lowlevel(arguments, &borrows, b, b_live_vars); diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index ff43de921d..bb84793eac 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -4164,6 +4164,11 @@ pub fn with_hole<'a>( match_on_closure_argument!(ListMap4, [xs, ys, zs, ws]) } + ListFindUnsafe => { + debug_assert_eq!(arg_symbols.len(), 2); + let xs = arg_symbols[0]; + match_on_closure_argument!(ListFindUnsafe, [xs]) + } _ => { let call = self::Call { call_type: CallType::LowLevel { diff --git a/compiler/mono/src/low_level.rs b/compiler/mono/src/low_level.rs index 8546ec2c4e..2a31addd74 100644 --- a/compiler/mono/src/low_level.rs +++ b/compiler/mono/src/low_level.rs @@ -50,6 +50,9 @@ pub enum HigherOrder { ListAny { xs: Symbol, }, + ListFindUnsafe { + xs: Symbol, + }, DictWalk { xs: Symbol, state: Symbol, @@ -71,6 +74,7 @@ impl HigherOrder { HigherOrder::ListKeepOks { .. } => 1, HigherOrder::ListKeepErrs { .. } => 1, HigherOrder::ListSortWith { .. } => 2, + HigherOrder::ListFindUnsafe { .. } => 1, HigherOrder::DictWalk { .. } => 2, HigherOrder::ListAny { .. } => 1, } diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index bef46421dc..baf17d2b41 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -2394,7 +2394,7 @@ fn list_join_map() { RocStr::from_slice("cyrus".as_bytes()), ]), RocList - ); + ) } #[test] @@ -2408,5 +2408,68 @@ fn list_join_map_empty() { ), RocList::from_slice(&[]), RocList + ) +} + +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn list_find() { + assert_evals_to!( + indoc!( + r#" + when List.find ["a", "bc", "def"] (\s -> Str.countGraphemes s > 1) is + Ok v -> v + Err _ -> "not found" + "# + ), + RocStr::from_slice(b"bc"), + RocStr + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn list_find_not_found() { + assert_evals_to!( + indoc!( + r#" + when List.find ["a", "bc", "def"] (\s -> Str.countGraphemes s > 5) is + Ok v -> v + Err _ -> "not found" + "# + ), + RocStr::from_slice(b"not found"), + RocStr + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn list_find_empty_typed_list() { + assert_evals_to!( + indoc!( + r#" + when List.find [] (\s -> Str.countGraphemes s > 5) is + Ok v -> v + Err _ -> "not found" + "# + ), + RocStr::from_slice(b"not found"), + RocStr + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm"))] +#[ignore = "Fails because monomorphization can't be done if we don't have a concrete element type!"] +fn list_find_empty_layout() { + assert_evals_to!( + indoc!( + r#" + List.find [] (\_ -> True) + "# + ), + 0, + i64 ); }