diff --git a/compiler/builtins/bitcode/src/dict.zig b/compiler/builtins/bitcode/src/dict.zig index 6968d265dc..fc99439bc8 100644 --- a/compiler/builtins/bitcode/src/dict.zig +++ b/compiler/builtins/bitcode/src/dict.zig @@ -472,6 +472,24 @@ pub fn dictContains(dict: RocDict, alignment: Alignment, key: Opaque, key_width: } } +// Dict.get : Dict k v, k -> { flag: bool, value: Opaque } +pub fn dictGet(dict: RocDict, alignment: Alignment, key: Opaque, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, inc_value: Inc) callconv(.C) extern struct { value: Opaque, flag: bool } { + const capacity: usize = dict.dict_slot_len; + const n: usize = capacity; + const seed: u64 = 0; + + switch (dict.findIndex(capacity, seed, alignment, key, key_width, value_width, hash_fn, is_eq)) { + MaybeIndex.not_found => { + return .{ .flag = false, .value = null }; + }, + MaybeIndex.index => |index| { + var value = dict.getValue(n, index, alignment, key_width, value_width); + inc_value(value); + return .{ .flag = true, .value = value }; + }, + } +} + test "RocDict.init() contains nothing" { const key_size = @sizeOf(usize); const value_size = @sizeOf(usize); diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index 13ffca6ad0..c3cd7d34a1 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -12,6 +12,7 @@ comptime { exportDictFn(dict.dictInsert, "insert"); exportDictFn(dict.dictRemove, "remove"); exportDictFn(dict.dictContains, "contains"); + exportDictFn(dict.dictGet, "get"); exportDictFn(hash.wyhash, "hash"); exportDictFn(hash.wyhash_rocstr, "hash_str"); diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index 2743ee26ed..bfd2f1a442 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -42,3 +42,4 @@ pub const DICT_EMPTY: &str = "roc_builtins.dict.empty"; pub const DICT_INSERT: &str = "roc_builtins.dict.insert"; pub const DICT_REMOVE: &str = "roc_builtins.dict.remove"; pub const DICT_CONTAINS: &str = "roc_builtins.dict.contains"; +pub const DICT_GET: &str = "roc_builtins.dict.get"; diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 5619018471..f383af25a5 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -84,6 +84,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option DICT_EMPTY => dict_empty, DICT_INSERT => dict_insert, DICT_REMOVE => dict_remove, + DICT_GET => dict_get, DICT_CONTAINS => dict_contains, NUM_ADD => num_add, NUM_ADD_CHECKED => num_add_checked, @@ -185,6 +186,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::DICT_EMPTY => dict_empty, Symbol::DICT_INSERT => dict_insert, Symbol::DICT_REMOVE => dict_remove, + Symbol::DICT_GET => dict_get, Symbol::DICT_CONTAINS => dict_contains, Symbol::NUM_ADD => num_add, Symbol::NUM_ADD_CHECKED => num_add_checked, @@ -1978,6 +1980,88 @@ fn dict_contains(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } +/// Dict.get : Dict k v, k -> Result v [ KeyNotFound ]* +fn dict_get(symbol: Symbol, var_store: &mut VarStore) -> Def { + let arg_dict = Symbol::ARG_1; + let arg_key = Symbol::ARG_2; + + let temp_record = Symbol::ARG_3; + let temp_flag = Symbol::ARG_4; + + let bool_var = var_store.fresh(); + let flag_var = var_store.fresh(); + let key_var = var_store.fresh(); + let dict_var = var_store.fresh(); + let value_var = var_store.fresh(); + let ret_var = var_store.fresh(); + + let temp_record_var = var_store.fresh(); + let ext_var1 = var_store.fresh(); + let ext_var2 = var_store.fresh(); + + // NOTE DictGetUnsafe returns a { flag: Bool, value: v } + // when the flag is True, the value is found and defined; + // otherwise it is not and `Dict.get` should return `Err ...` + let def_body = RunLowLevel { + op: LowLevel::DictGetUnsafe, + args: vec![(dict_var, Var(arg_dict)), (key_var, Var(arg_key))], + ret_var: temp_record_var, + }; + + let def = Def { + annotation: None, + expr_var: temp_record_var, + loc_expr: Located::at_zero(def_body), + loc_pattern: Located::at_zero(Pattern::Identifier(temp_record)), + pattern_vars: Default::default(), + }; + + let get_value = Access { + record_var: temp_record_var, + ext_var: ext_var1, + field_var: value_var, + loc_expr: Box::new(no_region(Var(temp_record))), + field: "value".into(), + }; + + let get_flag = Access { + record_var: temp_record_var, + ext_var: ext_var2, + field_var: flag_var, + loc_expr: Box::new(no_region(Var(temp_record))), + field: "zflag".into(), + }; + + let make_ok = tag("Ok", vec![get_value], var_store); + + let make_err = tag( + "Err", + vec![tag("OutOfBounds", Vec::new(), var_store)], + var_store, + ); + + let inspect = If { + cond_var: bool_var, + branch_var: ret_var, + branches: vec![( + // if-condition + no_region(get_flag), + no_region(make_ok.clone()), + )], + final_else: Box::new(no_region(make_err)), + }; + + let body = LetNonRec(Box::new(def), Box::new(no_region(inspect)), ret_var); + + defn( + symbol, + vec![(dict_var, Symbol::ARG_1), (key_var, Symbol::ARG_2)], + var_store, + body, + ret_var, + ) +} + /// Num.rem : Int, Int -> Result Int [ DivByZero ]* fn num_rem(symbol: Symbol, var_store: &mut VarStore) -> Def { let num_var = var_store.fresh(); diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index a097018339..a228643127 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -1,4 +1,6 @@ -use crate::llvm::build_dict::{dict_contains, dict_empty, dict_insert, dict_len, dict_remove}; +use crate::llvm::build_dict::{ + dict_contains, dict_empty, dict_get, dict_insert, dict_len, dict_remove, +}; use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, @@ -4005,6 +4007,23 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid dict layout"), } } + DictGetUnsafe => { + debug_assert_eq!(args.len(), 2); + + let (dict, dict_layout) = load_symbol_and_layout(scope, &args[0]); + let (key, key_layout) = load_symbol_and_layout(scope, &args[1]); + + match dict_layout { + Layout::Builtin(Builtin::EmptyDict) => { + unreachable!("we can't make up a layout for the return value"); + // in other words, make sure to check whether the dict is empty first + } + Layout::Builtin(Builtin::Dict(_, value_layout)) => { + dict_get(env, layout_ids, dict, key, key_layout, value_layout) + } + _ => unreachable!("invalid dict layout"), + } + } } } diff --git a/compiler/gen/src/llvm/build_dict.rs b/compiler/gen/src/llvm/build_dict.rs index e0d06bf533..87ec218c83 100644 --- a/compiler/gen/src/llvm/build_dict.rs +++ b/compiler/gen/src/llvm/build_dict.rs @@ -3,7 +3,7 @@ use crate::llvm::build::{ call_bitcode_fn, call_void_bitcode_fn, complex_bitcast, load_symbol, load_symbol_and_layout, set_name, Env, Scope, }; -use crate::llvm::convert::basic_type_from_layout; +use crate::llvm::convert::{as_const_zero, basic_type_from_layout}; use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout}; use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::types::BasicType; @@ -284,6 +284,120 @@ pub fn dict_contains<'a, 'ctx, 'env>( ) } +#[allow(clippy::too_many_arguments)] +pub fn dict_get<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict: BasicValueEnum<'ctx>, + key: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let dict_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + let key_ptr = builder.build_alloca(key.get_type(), "key_ptr"); + + env.builder + .build_store(dict_ptr, struct_to_zig_dict(env, dict.into_struct_value())); + env.builder.build_store(key_ptr, key); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env + .ptr_int() + .const_int(value_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); + let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); + + let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, RCOperation::Inc); + + // { flag: bool, value: *const u8 } + let result = call_bitcode_fn( + env, + &[ + dict_ptr.into(), + alignment_iv.into(), + env.builder.build_bitcast(key_ptr, u8_ptr, "to_u8_ptr"), + key_width.into(), + value_width.into(), + hash_fn.as_global_value().as_pointer_value().into(), + eq_fn.as_global_value().as_pointer_value().into(), + inc_value_fn.as_global_value().as_pointer_value().into(), + ], + &bitcode::DICT_GET, + ) + .into_struct_value(); + + let flag = env + .builder + .build_extract_value(result, 1, "get_flag") + .unwrap() + .into_int_value(); + + let value_u8_ptr = env + .builder + .build_extract_value(result, 0, "get_value_ptr") + .unwrap() + .into_pointer_value(); + + let start_block = env.builder.get_insert_block().unwrap(); + let parent = start_block.get_parent().unwrap(); + + let if_not_null = env.context.append_basic_block(parent, "if_not_null"); + let done_block = env.context.append_basic_block(parent, "done"); + + let value_bt = basic_type_from_layout(env.arena, env.context, value_layout, env.ptr_bytes); + let default = as_const_zero(&value_bt); + + env.builder + .build_conditional_branch(flag, if_not_null, done_block); + + env.builder.position_at_end(if_not_null); + let value_ptr = env + .builder + .build_bitcast( + value_u8_ptr, + value_bt.ptr_type(AddressSpace::Generic), + "from_opaque", + ) + .into_pointer_value(); + let loaded = env.builder.build_load(value_ptr, "load_value"); + env.builder.build_unconditional_branch(done_block); + + env.builder.position_at_end(done_block); + let result_phi = env.builder.build_phi(value_bt, "result"); + + result_phi.add_incoming(&[(&default, start_block), (&loaded, if_not_null)]); + + let value = result_phi.as_basic_value(); + + let result = env + .context + .struct_type(&[value_bt, env.context.bool_type().into()], false) + .const_zero(); + + let result = env + .builder + .build_insert_value(result, flag, 1, "insert_flag") + .unwrap(); + + env.builder + .build_insert_value(result, value, 0, "insert_value") + .unwrap() + .into_struct_value() + .into() +} + fn build_hash_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index 590776bcbf..f5a319aaf8 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -3,6 +3,7 @@ use bumpalo::Bump; use inkwell::context::Context; use inkwell::types::BasicTypeEnum::{self, *}; use inkwell::types::{ArrayType, BasicType, FunctionType, IntType, PointerType, StructType}; +use inkwell::values::BasicValueEnum; use inkwell::AddressSpace; use roc_mono::layout::{Builtin, Layout, UnionLayout}; @@ -48,6 +49,18 @@ pub fn get_array_type<'ctx>(bt_enum: &BasicTypeEnum<'ctx>, size: u32) -> ArrayTy } } +/// TODO could this be added to Inkwell itself as a method on BasicValueEnum? +pub fn as_const_zero<'ctx>(bt_enum: &BasicTypeEnum<'ctx>) -> BasicValueEnum<'ctx> { + match bt_enum { + ArrayType(typ) => typ.const_zero().into(), + IntType(typ) => typ.const_zero().into(), + FloatType(typ) => typ.const_zero().into(), + PointerType(typ) => typ.const_zero().into(), + StructType(typ) => typ.const_zero().into(), + VectorType(typ) => typ.const_zero().into(), + } +} + fn basic_type_from_function_layout<'ctx>( arena: &Bump, context: &'ctx Context, diff --git a/compiler/gen/tests/gen_dict.rs b/compiler/gen/tests/gen_dict.rs index abe43f3c16..12ee27b19d 100644 --- a/compiler/gen/tests/gen_dict.rs +++ b/compiler/gen/tests/gen_dict.rs @@ -107,4 +107,46 @@ mod gen_dict { i64 ); } + + #[test] + fn dict_nonempty_get() { + assert_evals_to!( + indoc!( + r#" + empty : Dict I64 F64 + empty = Dict.insert Dict.empty 42 3.14 + + withDefault = \x, def -> + when x is + Ok v -> v + Err _ -> def + + empty + |> Dict.insert 42 3.14 + |> Dict.get 42 + |> withDefault 0 + "# + ), + 3.14, + f64 + ); + + assert_evals_to!( + indoc!( + r#" + withDefault = \x, def -> + when x is + Ok v -> v + Err _ -> def + + Dict.empty + |> Dict.insert 42 3.14 + |> Dict.get 43 + |> withDefault 0 + "# + ), + 0.0, + f64 + ); + } } diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 12a92e7ca7..402eab7dc3 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -33,6 +33,7 @@ pub enum LowLevel { DictInsert, DictRemove, DictContains, + DictGetUnsafe, NumAdd, NumAddWrap, NumAddChecked, diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 6c4272243d..67c687e83a 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -594,5 +594,6 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { DictInsert => arena.alloc_slice_copy(&[owned, owned, owned]), DictRemove => arena.alloc_slice_copy(&[owned, borrowed]), DictContains => arena.alloc_slice_copy(&[borrowed, borrowed]), + DictGetUnsafe => arena.alloc_slice_copy(&[borrowed, borrowed]), } }