diff --git a/compiler/builtins/bitcode/src/dict.zig b/compiler/builtins/bitcode/src/dict.zig index 50f938724a..173638fbb7 100644 --- a/compiler/builtins/bitcode/src/dict.zig +++ b/compiler/builtins/bitcode/src/dict.zig @@ -712,6 +712,26 @@ pub fn dictDifference(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_ } } +pub fn setFromList(list: RocList, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Dec, output: *RocDict) callconv(.C) void { + output.* = RocDict.empty(); + + var ptr = @ptrCast([*]u8, list.bytes); + + const dec_value = doNothing; + const value = null; + + const size = list.length; + var i: usize = 0; + while (i < size) : (i += 1) { + const key = ptr + i * key_width; + dictInsert(output.*, alignment, key, key_width, value, value_width, hash_fn, is_eq, dec_key, dec_value, output); + } + + // NOTE: decref checks for the empty case + const data_bytes = size * key_width; + decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); +} + const StepperCaller = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; pub fn dictWalk(dict: RocDict, stepper: Opaque, stepper_caller: StepperCaller, accum: Opaque, alignment: Alignment, key_width: usize, value_width: usize, accum_width: usize, output: Opaque) callconv(.C) void { @memcpy(output orelse unreachable, accum orelse unreachable, accum_width); @@ -740,6 +760,10 @@ fn decref( bytes_or_null: ?[*]u8, data_bytes: usize, ) void { + if (data_bytes == 0) { + return; + } + var bytes = bytes_or_null orelse return; const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, bytes)); diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index e0605c3c5a..fbb4892194 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -21,6 +21,8 @@ comptime { exportDictFn(dict.dictDifference, "difference"); exportDictFn(dict.dictWalk, "walk"); + exportDictFn(dict.setFromList, "set_from_list"); + exportDictFn(hash.wyhash, "hash"); exportDictFn(hash.wyhash_rocstr, "hash_str"); } diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 837359abd9..e320475912 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -51,3 +51,5 @@ pub const DICT_UNION: &str = "roc_builtins.dict.union"; pub const DICT_DIFFERENCE: &str = "roc_builtins.dict.difference"; pub const DICT_INTERSECTION: &str = "roc_builtins.dict.intersection"; pub const DICT_WALK: &str = "roc_builtins.dict.walk"; + +pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 53d271bcd8..111a495663 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -2240,8 +2240,8 @@ fn set_to_list(symbol: Symbol, var_store: &mut VarStore) -> Def { } /// Set.fromList : List k -> Set k -fn set_from_list(_symbol: Symbol, _var_store: &mut VarStore) -> Def { - todo!() +fn set_from_list(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_1(symbol, LowLevel::SetFromList, var_store) } /// Set.insert : Set k, k -> Set k diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index a7681bfd4f..d49a6caf21 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -1,6 +1,6 @@ use crate::llvm::build_dict::{ dict_contains, dict_difference, dict_empty, dict_get, dict_insert, dict_intersection, - dict_keys, dict_len, dict_remove, dict_union, dict_values, dict_walk, + dict_keys, dict_len, dict_remove, dict_union, dict_values, dict_walk, set_from_list, }; use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ @@ -4095,7 +4095,7 @@ fn run_low_level<'a, 'ctx, 'env>( match dict_layout { Layout::Builtin(Builtin::EmptyDict) => { // no elements, so `key` is not in here - panic!("key type unknown") + empty_list(env) } Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { dict_keys(env, layout_ids, dict, key_layout, value_layout) @@ -4111,7 +4111,7 @@ fn run_low_level<'a, 'ctx, 'env>( match dict_layout { Layout::Builtin(Builtin::EmptyDict) => { // no elements, so `key` is not in here - panic!("key type unknown") + empty_list(env) } Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { dict_values(env, layout_ids, dict, key_layout, value_layout) @@ -4196,6 +4196,19 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid dict layout"), } } + SetFromList => { + debug_assert_eq!(args.len(), 1); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => dict_empty(env, scope), + Layout::Builtin(Builtin::List(_, key_layout)) => { + set_from_list(env, layout_ids, list, key_layout) + } + _ => unreachable!("invalid dict layout"), + } + } } } diff --git a/compiler/gen/src/llvm/build_dict.rs b/compiler/gen/src/llvm/build_dict.rs index 68925db5b3..df80f34aec 100644 --- a/compiler/gen/src/llvm/build_dict.rs +++ b/compiler/gen/src/llvm/build_dict.rs @@ -811,6 +811,68 @@ pub fn dict_values<'a, 'ctx, 'env>( env.builder.build_load(list_ptr, "load_keys_list") } +#[allow(clippy::too_many_arguments)] +pub fn set_from_list<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + list: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + + let list_alloca = builder.build_alloca(list.get_type(), "list_alloca"); + let list_ptr = env.builder.build_bitcast( + list_alloca, + env.context.i128_type().ptr_type(AddressSpace::Generic), + "to_zig_list", + ); + + env.builder.build_store(list_alloca, list); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env.ptr_int().const_zero(); + + let result_alloca = + builder.build_alloca(convert::dict(env.context, env.ptr_bytes), "result_alloca"); + let result_ptr = builder.build_bitcast( + result_alloca, + zig_dict_type.ptr_type(AddressSpace::Generic), + "to_zig_dict", + ); + + let alignment = + Alignment::from_key_value_layout(key_layout, &Layout::Struct(&[]), env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); + let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); + + let dec_key_fn = build_rc_wrapper(env, layout_ids, key_layout, Mode::Dec); + + call_void_bitcode_fn( + env, + &[ + env.builder + .build_load(list_ptr.into_pointer_value(), "as_i128"), + alignment_iv.into(), + key_width.into(), + value_width.into(), + hash_fn.as_global_value().as_pointer_value().into(), + eq_fn.as_global_value().as_pointer_value().into(), + dec_key_fn.as_global_value().as_pointer_value().into(), + result_ptr.into(), + ], + &bitcode::SET_FROM_LIST, + ); + + env.builder.build_load(result_alloca, "load_result") +} + fn build_hash_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index f5a319aaf8..9cefbc3fbc 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -283,6 +283,10 @@ pub fn dict(ctx: &Context, ptr_bytes: u32) -> StructType<'_> { ) } +pub fn dict_ptr(ctx: &Context, ptr_bytes: u32) -> PointerType<'_> { + dict(ctx, ptr_bytes).ptr_type(AddressSpace::Generic) +} + pub fn ptr_int(ctx: &Context, ptr_bytes: u32) -> IntType<'_> { match ptr_bytes { 1 => ctx.i8_type(), diff --git a/compiler/gen/tests/gen_set.rs b/compiler/gen/tests/gen_set.rs index a85f48f6a7..012eaf64ab 100644 --- a/compiler/gen/tests/gen_set.rs +++ b/compiler/gen/tests/gen_set.rs @@ -113,14 +113,11 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - set1 : Set I64 - set1 = fromList [1,2] + set1 = Set.fromList [1,2] set2 : Set I64 - set2 = fromList [1,3,4] + set2 = Set.fromList [1,3,4] Set.union set1 set2 |> Set.toList @@ -136,14 +133,11 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - set1 : Set I64 - set1 = fromList [1,2] + set1 = Set.fromList [1,2] set2 : Set I64 - set2 = fromList [1,3,4] + set2 = Set.fromList [1,3,4] Set.difference set1 set2 |> Set.toList @@ -159,14 +153,11 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - set1 : Set I64 - set1 = fromList [1,2] + set1 = Set.fromList [1,2] set2 : Set I64 - set2 = fromList [1,3,4] + set2 = Set.fromList [1,3,4] Set.intersection set1 set2 |> Set.toList @@ -182,11 +173,7 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - - - Set.walk (fromList [1,2,3]) (\x, y -> x + y) 0 + Set.walk (Set.fromList [1,2,3]) (\x, y -> x + y) 0 "# ), 6, @@ -199,11 +186,7 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - - - Set.contains (fromList [1,3,4]) 4 + Set.contains (Set.fromList [1,3,4]) 4 "# ), true, @@ -213,15 +196,53 @@ mod gen_set { assert_evals_to!( indoc!( r#" - fromList : List a -> Set a - fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty - - - Set.contains (fromList [1,3,4]) 2 + Set.contains (Set.fromList [1,3,4]) 2 "# ), false, bool ); } + + #[test] + fn from_list() { + assert_evals_to!( + indoc!( + r#" + [1,2,2,3,1,4] + |> Set.fromList + |> Set.toList + "# + ), + &[4, 2, 3, 1], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + [] + |> Set.fromList + |> Set.toList + "# + ), + &[], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + empty : List I64 + empty = [] + + empty + |> Set.fromList + |> Set.toList + "# + ), + &[], + &[i64] + ); + } } diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 98873438b8..1b715b8991 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -41,6 +41,7 @@ pub enum LowLevel { DictIntersection, DictDifference, DictWalk, + SetFromList, NumAdd, NumAddWrap, NumAddChecked, diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 1e8b4d47ff..72ed112e89 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -615,5 +615,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { // borrow function argument so we don't have to worry about RC of the closure DictWalk => arena.alloc_slice_copy(&[owned, borrowed, owned]), + + SetFromList => arena.alloc_slice_copy(&[owned]), } }