diff --git a/compiler/builtins/bitcode/README.md b/compiler/builtins/bitcode/README.md index 604f3a53d0..1585c2e8fb 100644 --- a/compiler/builtins/bitcode/README.md +++ b/compiler/builtins/bitcode/README.md @@ -7,7 +7,7 @@ To add a builtin: 2. Make sure the function is public with the `pub` keyword and uses the C calling convention. This is really easy, just add `pub` and `callconv(.C)` to the function declaration like so: `pub fn atan(num: f64) callconv(.C) f64 { ... }` 3. In `src/main.zig`, export the function. This is also organized by module. For example, for a `Num` function find the `Num` section and add: `comptime { exportNumFn(num.atan, "atan"); }`. The first argument is the function, the second is the name of it in LLVM. 4. In `compiler/builtins/src/bitcode.rs`, add a constant for the new function. This is how we use it in Rust. Once again, this is organized by module, so just find the relevant area and add your new function. -5. You can now your function in Rust using `call_bitcode_fn` in `llvm/src/build.rs`! +5. You can now use your function in Rust using `call_bitcode_fn` in `llvm/src/build.rs`! ## How it works diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index 70925554f1..012f24b141 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -1256,6 +1256,93 @@ pub fn listConcat(list_a: RocList, list_b: RocList, alignment: u32, element_widt return output; } +pub fn listReplaceInPlace( + bytes: ?[*]u8, + index: usize, + element: Opaque, + element_width: usize, +) callconv(.C) ?[*]u8 { + // INVARIANT: bounds checking happens on the roc side + // + // at the time of writing, the function is implemented roughly as + // `if inBounds then LowLevelListReplace input index item else input` + // so we don't do a bounds check here. Hence, the list is also non-empty, + // because inserting into an empty list is always out of bounds + + return listReplaceInPlaceHelp(bytes, index, element, element_width); +} + +pub fn listReplace( + bytes: ?[*]u8, + length: usize, + alignment: u32, + index: usize, + element: Opaque, + element_width: usize, +) callconv(.C) ?[*]u8 { + // INVARIANT: bounds checking happens on the roc side + // + // at the time of writing, the function is implemented roughly as + // `if inBounds then LowLevelListReplace input index item else input` + // so we don't do a bounds check here. Hence, the list is also non-empty, + // because inserting into an empty list is always out of bounds + const ptr: [*]usize = @ptrCast([*]usize, @alignCast(@alignOf(usize), bytes)); + + if ((ptr - 1)[0] == utils.REFCOUNT_ONE) { + return listReplaceInPlaceHelp(bytes, index, element, element_width); + } else { + return listReplaceImmutable(bytes, length, alignment, index, element, element_width); + } +} + +inline fn listReplaceInPlaceHelp( + bytes: ?[*]u8, + index: usize, + element: Opaque, + element_width: usize, +) ?[*]u8 { + // the element we will replace + var element_at_index = (bytes orelse undefined) + (index * element_width); + + // decrement its refcount + // dec(element_at_index); + + // copy in the new element + @memcpy(element_at_index, element orelse undefined, element_width); + + return bytes; +} + +inline fn listReplaceImmutable( + old_bytes: ?[*]u8, + length: usize, + alignment: u32, + index: usize, + element: Opaque, + element_width: usize, +) ?[*]u8 { + const data_bytes = length * element_width; + + var new_bytes = utils.allocateWithRefcount(data_bytes, alignment); + + @memcpy(new_bytes, old_bytes orelse undefined, data_bytes); + + // the element we will replace + var element_at_index = new_bytes + (index * element_width); + + // decrement its refcount + // dec(element_at_index); + + // copy in the new element + @memcpy(element_at_index, element orelse undefined, element_width); + + // consume RC token of original + utils.decref(old_bytes, data_bytes, alignment); + + //return list; + return new_bytes; +} + pub fn listSetInPlace( bytes: ?[*]u8, index: usize, diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index c047588d92..8a23f09985 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -49,6 +49,8 @@ comptime { exportListFn(list.listConcat, "concat"); exportListFn(list.listSublist, "sublist"); exportListFn(list.listDropAt, "drop_at"); + exportListFn(list.listReplace, "replace"); + exportListFn(list.listReplaceInPlace, "replace_in_place"); exportListFn(list.listSet, "set"); exportListFn(list.listSetInPlace, "set_in_place"); exportListFn(list.listSwap, "swap"); diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 1345156d0f..c4c0b12e01 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -354,6 +354,8 @@ pub const LIST_RANGE: &str = "roc_builtins.list.range"; pub const LIST_REVERSE: &str = "roc_builtins.list.reverse"; pub const LIST_SORT_WITH: &str = "roc_builtins.list.sort_with"; pub const LIST_CONCAT: &str = "roc_builtins.list.concat"; +pub const LIST_REPLACE: &str = "roc_builtins.list.replace"; +pub const LIST_REPLACE_IN_PLACE: &str = "roc_builtins.list.replace_in_place"; pub const LIST_SET: &str = "roc_builtins.list.set"; pub const LIST_SET_IN_PLACE: &str = "roc_builtins.list.set_in_place"; pub const LIST_ANY: &str = "roc_builtins.list.any"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 97394b969b..35ab535a6a 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -4,9 +4,9 @@ use roc_module::symbol::Symbol; use roc_region::all::Region; use roc_types::builtin_aliases::{ bool_type, dec_type, dict_type, f32_type, f64_type, float_type, i128_type, i16_type, i32_type, - i64_type, i8_type, int_type, list_type, nat_type, num_type, ordering_type, result_type, - set_type, str_type, str_utf8_byte_problem_type, u128_type, u16_type, u32_type, u64_type, - u8_type, + i64_type, i8_type, int_type, list_type, nat_type, num_type, ordering_type, pair_type, + result_type, set_type, str_type, str_utf8_byte_problem_type, u128_type, u16_type, u32_type, + u64_type, u8_type, }; use roc_types::solved_types::SolvedType; use roc_types::subs::VarId; @@ -1034,7 +1034,7 @@ pub fn types() -> MutMap { add_top_level_function_type!( Symbol::LIST_GET, vec![list_type(flex(TVAR1)), nat_type()], - Box::new(result_type(flex(TVAR1), index_out_of_bounds)), + Box::new(result_type(flex(TVAR1), index_out_of_bounds.clone())), ); // first : List elem -> Result elem [ ListWasEmpty ]* @@ -1056,6 +1056,16 @@ pub fn types() -> MutMap { Box::new(result_type(flex(TVAR1), list_was_empty.clone())), ); + // replace : List elem, Nat, elem -> Result (Pair (List elem) elem) [ OutOfBounds ]* + add_top_level_function_type!( + Symbol::LIST_REPLACE, + vec![list_type(flex(TVAR1)), nat_type(), flex(TVAR1)], + Box::new(result_type( + pair_type(list_type(flex(TVAR1)), flex(TVAR1)), + index_out_of_bounds + )), + ); + // set : List elem, Nat, elem -> List elem add_top_level_function_type!( Symbol::LIST_SET, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index a8228c6ecb..ae6375ada7 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -102,6 +102,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option STR_TO_I8 => str_to_num, LIST_LEN => list_len, LIST_GET => list_get, + LIST_REPLACE => list_replace, LIST_SET => list_set, LIST_APPEND => list_append, LIST_FIRST => list_first, @@ -2303,6 +2304,97 @@ fn list_get(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } +/// List.replace : List elem, Nat, elem -> Result (Pair (List elem) elem) [ OutOfBounds ]* +/// +/// List.replace : +/// Attr (w | u | v) (List (Attr u a)), +/// Attr * Int, +/// Attr (u | v) a +/// -> Attr * (List (Attr u a)) +/// -> Attr * (Result (Pair (List (Attr u a)) Attr u a)) (Attr * [ OutOfBounds ]*)) +fn list_replace(symbol: Symbol, var_store: &mut VarStore) -> Def { + let arg_list = Symbol::ARG_1; + let arg_index = Symbol::ARG_2; + let arg_elem = Symbol::ARG_3; + let bool_var = var_store.fresh(); + let len_var = var_store.fresh(); + let elem_var = var_store.fresh(); + let list_arg_var = var_store.fresh(); + let ret_pair_var = var_store.fresh(); + + // Perform a bounds check. If it passes, run LowLevel::ListReplace. + // Otherwise, return the list unmodified. + let body = If { + cond_var: bool_var, + branch_var: ret_pair_var, + branches: vec![( + // if-condition + no_region( + // index < List.len list + RunLowLevel { + op: LowLevel::NumLt, + args: vec![ + (len_var, Var(arg_index)), + ( + len_var, + RunLowLevel { + op: LowLevel::ListLen, + args: vec![(list_arg_var, Var(arg_list))], + ret_var: len_var, + }, + ), + ], + ret_var: bool_var, + }, + ), + // then-branch + no_region( + // Ok + tag( + "Ok", + vec![ + // TODO: This should probably call get and then build the pair + // List.replaceUnsafe list index elem + RunLowLevel { + op: LowLevel::ListReplace, + args: vec![ + (list_arg_var, Var(arg_list)), + (len_var, Var(arg_index)), + (elem_var, Var(arg_elem)), + ], + ret_var: ret_pair_var, + }, + ], + var_store, + ), + ), + )], + final_else: Box::new( + // else-branch + no_region( + // Err + tag( + "Err", + vec![tag("OutOfBounds", Vec::new(), var_store)], + var_store, + ), + ), + ), + }; + + defn( + symbol, + vec![ + (list_arg_var, Symbol::ARG_1), + (len_var, Symbol::ARG_2), + (elem_var, Symbol::ARG_3), + ], + var_store, + body, + ret_pair_var, + ) +} + /// List.set : List elem, Nat, elem -> List elem /// /// List.set : @@ -2347,7 +2439,7 @@ fn list_set(symbol: Symbol, var_store: &mut VarStore) -> Def { ), // then-branch no_region( - // List.setUnsafe list index + // List.setUnsafe list index elem RunLowLevel { op: LowLevel::ListSet, args: vec![ diff --git a/compiler/gen_llvm/src/llvm/build.rs b/compiler/gen_llvm/src/llvm/build.rs index af42090989..b44f84767b 100644 --- a/compiler/gen_llvm/src/llvm/build.rs +++ b/compiler/gen_llvm/src/llvm/build.rs @@ -13,8 +13,8 @@ use crate::llvm::build_list::{ self, allocate_list, empty_polymorphic_list, list_all, list_any, list_append, list_concat, list_contains, list_drop_at, list_find_unsafe, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, list_map2, list_map3, list_map4, - list_map_with_index, list_prepend, list_range, list_repeat, list_reverse, list_set, - list_single, list_sort_with, list_sublist, list_swap, + list_map_with_index, list_prepend, list_range, list_repeat, list_replace, list_reverse, + list_set, list_single, list_sort_with, list_sublist, list_swap, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, @@ -5653,6 +5653,21 @@ fn run_low_level<'a, 'ctx, 'env>( wrapper_struct, ) } + ListReplace => { + let list = load_symbol(scope, &args[0]); + let index = load_symbol(scope, &args[1]); + let (element, element_layout) = load_symbol_and_layout(scope, &args[2]); + + list_replace( + env, + layout_ids, + list, + index.into_int_value(), + element, + element_layout, + update_mode, + ) + } ListSet => { let list = load_symbol(scope, &args[0]); let index = load_symbol(scope, &args[1]); diff --git a/compiler/gen_llvm/src/llvm/build_list.rs b/compiler/gen_llvm/src/llvm/build_list.rs index 3f865b77aa..b27316a443 100644 --- a/compiler/gen_llvm/src/llvm/build_list.rs +++ b/compiler/gen_llvm/src/llvm/build_list.rs @@ -291,6 +291,50 @@ pub fn list_drop_at<'a, 'ctx, 'env>( ) } +/// List.replace : List elem, Nat, elem -> List elem +pub fn list_replace<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + list: BasicValueEnum<'ctx>, + index: IntValue<'ctx>, + element: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, + update_mode: UpdateMode, +) -> BasicValueEnum<'ctx> { + let (length, bytes) = load_list( + env.builder, + list.into_struct_value(), + env.context.i8_type().ptr_type(AddressSpace::Generic), + ); + + let new_bytes = match update_mode { + UpdateMode::InPlace => call_bitcode_fn( + env, + &[ + bytes.into(), + index.into(), + pass_element_as_opaque(env, element, *element_layout), + layout_width(env, element_layout), + ], + bitcode::LIST_REPLACE_IN_PLACE, + ), + UpdateMode::Immutable => call_bitcode_fn( + env, + &[ + bytes.into(), + length.into(), + env.alignment_intvalue(element_layout), + index.into(), + pass_element_as_opaque(env, element, *element_layout), + layout_width(env, element_layout), + ], + bitcode::LIST_REPLACE, + ), + }; + + store_list(env, new_bytes.into_pointer_value(), length) +} + /// List.set : List elem, Nat, elem -> List elem pub fn list_set<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, diff --git a/compiler/gen_wasm/src/low_level.rs b/compiler/gen_wasm/src/low_level.rs index a5875e1813..8d71d3f484 100644 --- a/compiler/gen_wasm/src/low_level.rs +++ b/compiler/gen_wasm/src/low_level.rs @@ -260,14 +260,14 @@ impl<'a> LowLevelCall<'a> { _ => internal_error!("invalid storage for List"), }, - ListGetUnsafe | ListSet | ListSingle | ListRepeat | ListReverse | ListConcat - | ListContains | ListAppend | ListPrepend | ListJoin | ListRange | ListMap - | ListMap2 | ListMap3 | ListMap4 | ListMapWithIndex | ListKeepIf | ListWalk - | ListWalkUntil | ListWalkBackwards | ListKeepOks | ListKeepErrs | ListSortWith - | ListSublist | ListDropAt | ListSwap | ListAny | ListAll | ListFindUnsafe - | DictSize | DictEmpty | DictInsert | DictRemove | DictContains | DictGetUnsafe - | DictKeys | DictValues | DictUnion | DictIntersection | DictDifference | DictWalk - | SetFromList => { + ListGetUnsafe | ListReplace | ListSet | ListSingle | ListRepeat | ListReverse + | ListConcat | ListContains | ListAppend | ListPrepend | ListJoin | ListRange + | ListMap | ListMap2 | ListMap3 | ListMap4 | ListMapWithIndex | ListKeepIf + | ListWalk | ListWalkUntil | ListWalkBackwards | ListKeepOks | ListKeepErrs + | ListSortWith | ListSublist | ListDropAt | ListSwap | ListAny | ListAll + | ListFindUnsafe | DictSize | DictEmpty | DictInsert | DictRemove | DictContains + | DictGetUnsafe | DictKeys | DictValues | DictUnion | DictIntersection + | DictDifference | DictWalk | SetFromList => { todo!("{:?}", self.lowlevel); } diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index be5a8394d3..720382ce1a 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -28,6 +28,7 @@ pub enum LowLevel { ListSet, ListSingle, ListRepeat, + ListReplace, ListReverse, ListConcat, ListContains, @@ -229,6 +230,7 @@ impl LowLevelWrapperType { Symbol::LIST_LEN => CanBeReplacedBy(ListLen), Symbol::LIST_GET => WrapperIsRequired, Symbol::LIST_SET => WrapperIsRequired, + Symbol::LIST_REPLACE => WrapperIsRequired, Symbol::LIST_SINGLE => CanBeReplacedBy(ListSingle), Symbol::LIST_REPEAT => CanBeReplacedBy(ListRepeat), Symbol::LIST_REVERSE => CanBeReplacedBy(ListReverse), diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 82778fa6d0..80d4bd235c 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -1141,6 +1141,7 @@ define_builtins! { 55 LIST_SORT_ASC: "sortAsc" 56 LIST_SORT_DESC: "sortDesc" 57 LIST_SORT_DESC_COMPARE: "#sortDescCompare" + 58 LIST_REPLACE: "replace" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 3f6a0bbc4e..d8ed8fa610 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -935,6 +935,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { match op { ListLen | StrIsEmpty | StrCountGraphemes => arena.alloc_slice_copy(&[borrowed]), ListSet => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), + ListReplace => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), ListGetUnsafe => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListConcat => arena.alloc_slice_copy(&[owned, owned]), StrConcat => arena.alloc_slice_copy(&[owned, borrowed]), diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index 333bdb4a63..ba8270808e 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -1702,6 +1702,40 @@ fn get_int_list_oob() { ); } +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn replace_unique_int_list() { + assert_evals_to!( + indoc!( + r#" + result = List.replace [ 12, 9, 7, 1, 5 ] 2 33 + when result is + Ok (Pair newList _) -> newList + Err _ -> [] + "# + ), + RocList::from_slice(&[12, 9, 33, 1, 5]), + RocList + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm"))] +fn replace_unique_int_list_get_old_value() { + assert_evals_to!( + indoc!( + r#" + result = List.replace [ 12, 9, 7, 1, 5 ] 2 33 + when result is + Ok (Pair _ oldValue) -> oldValue + Err _ -> -1 + "# + ), + 7, + i64 + ); +} + #[test] #[cfg(any(feature = "gen-llvm"))] fn get_set_unique_int_list() { diff --git a/compiler/types/src/builtin_aliases.rs b/compiler/types/src/builtin_aliases.rs index f8df53a256..07efa07c9e 100644 --- a/compiler/types/src/builtin_aliases.rs +++ b/compiler/types/src/builtin_aliases.rs @@ -914,6 +914,15 @@ pub fn ordering_type() -> SolvedType { ) } +#[inline(always)] +pub fn pair_type(t1: SolvedType, t2: SolvedType) -> SolvedType { + // [ Pair t1 t2 ] + SolvedType::TagUnion( + vec![(TagName::Global("Pair".into()), vec![t1, t2])], + Box::new(SolvedType::EmptyTagUnion), + ) +} + #[inline(always)] pub fn result_type(a: SolvedType, e: SolvedType) -> SolvedType { SolvedType::Alias(