add difference

This commit is contained in:
Folkert 2021-02-14 20:01:24 +01:00
parent 02db8f1a05
commit 39c4353554
5 changed files with 131 additions and 84 deletions

View file

@ -683,6 +683,31 @@ pub fn dictIntersection(dict1: RocDict, dict2: RocDict, alignment: Alignment, ke
} }
} }
pub fn dictDifference(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Inc, dec_value: Inc, output: *RocDict) callconv(.C) void {
output.* = dict1.makeUnique(std.heap.c_allocator, alignment, key_width, value_width);
var i: usize = 0;
const size = dict1.capacity();
while (i < size) : (i += 1) {
switch (output.getSlot(i, key_width, value_width)) {
Slot.Filled => {
const key = dict1.getKey(i, alignment, key_width, value_width);
switch (dict2.findIndex(alignment, key, key_width, value_width, hash_fn, is_eq)) {
MaybeIndex.not_found => {
// keep this key/value
continue;
},
MaybeIndex.index => |_| {
dictRemove(output.*, alignment, key, key_width, value_width, hash_fn, is_eq, dec_key, dec_value, output);
},
}
},
else => {},
}
}
}
fn decref( fn decref(
allocator: *Allocator, allocator: *Allocator,
alignment: Alignment, alignment: Alignment,

View file

@ -18,7 +18,7 @@ comptime {
exportDictFn(dict.dictValues, "values"); exportDictFn(dict.dictValues, "values");
exportDictFn(dict.dictUnion, "union"); exportDictFn(dict.dictUnion, "union");
exportDictFn(dict.dictIntersection, "intersection"); exportDictFn(dict.dictIntersection, "intersection");
// exportDictFn(dict.dictValues, "values"); exportDictFn(dict.dictDifference, "difference");
exportDictFn(hash.wyhash, "hash"); exportDictFn(hash.wyhash, "hash");
exportDictFn(hash.wyhash_rocstr, "hash_str"); exportDictFn(hash.wyhash_rocstr, "hash_str");

View file

@ -5,7 +5,6 @@ use crate::pattern::Pattern;
use roc_collections::all::{MutMap, SendMap}; use roc_collections::all::{MutMap, SendMap};
use roc_module::ident::TagName; use roc_module::ident::TagName;
use roc_module::low_level::LowLevel; use roc_module::low_level::LowLevel;
use roc_module::operator::CalledVia;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
use roc_region::all::{Located, Region}; use roc_region::all::{Located, Region};
use roc_types::subs::{VarStore, Variable}; use roc_types::subs::{VarStore, Variable};
@ -2169,21 +2168,17 @@ fn dict_dict_dict(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def
/// Dict.union : Dict k v, Dict k v -> Dict k v /// Dict.union : Dict k v, Dict k v -> Dict k v
fn dict_union(symbol: Symbol, var_store: &mut VarStore) -> Def { fn dict_union(symbol: Symbol, var_store: &mut VarStore) -> Def {
dict_dict_dict(Symbol::DICT_UNION, LowLevel::DictUnion, var_store) dict_dict_dict(symbol, LowLevel::DictUnion, var_store)
} }
/// Dict.difference : Dict k v, Dict k v -> Dict k v /// Dict.difference : Dict k v, Dict k v -> Dict k v
fn dict_difference(symbol: Symbol, var_store: &mut VarStore) -> Def { fn dict_difference(symbol: Symbol, var_store: &mut VarStore) -> Def {
dict_dict_dict(Symbol::DICT_DIFFERENCE, LowLevel::DictDifference, var_store) dict_dict_dict(symbol, LowLevel::DictDifference, var_store)
} }
/// Dict.intersection : Dict k v, Dict k v -> Dict k v /// Dict.intersection : Dict k v, Dict k v -> Dict k v
fn dict_intersection(symbol: Symbol, var_store: &mut VarStore) -> Def { fn dict_intersection(symbol: Symbol, var_store: &mut VarStore) -> Def {
dict_dict_dict( dict_dict_dict(symbol, LowLevel::DictIntersection, var_store)
Symbol::DICT_INTERSECTION,
LowLevel::DictIntersection,
var_store,
)
} }
/// Num.rem : Int, Int -> Result Int [ DivByZero ]* /// Num.rem : Int, Int -> Result Int [ DivByZero ]*

View file

@ -570,80 +570,6 @@ pub fn dict_union<'a, 'ctx, 'env>(
env.builder.build_load(output_ptr, "load_output_ptr") env.builder.build_load(output_ptr, "load_output_ptr")
} }
#[allow(clippy::too_many_arguments)]
pub fn dict_intersection<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
dict1: BasicValueEnum<'ctx>,
dict2: BasicValueEnum<'ctx>,
key_layout: &Layout<'a>,
value_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap();
let dict1_ptr = builder.build_alloca(zig_dict_type, "dict_ptr");
let dict2_ptr = builder.build_alloca(zig_dict_type, "dict_ptr");
env.builder.build_store(
dict1_ptr,
struct_to_zig_dict(env, dict1.into_struct_value()),
);
env.builder.build_store(
dict2_ptr,
struct_to_zig_dict(env, dict2.into_struct_value()),
);
let key_width = env
.ptr_int()
.const_int(key_layout.stack_size(env.ptr_bytes) as u64, false);
let value_width = env
.ptr_int()
.const_int(value_layout.stack_size(env.ptr_bytes) as u64, false);
let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes);
let alignment_iv = env.context.i8_type().const_int(alignment as u64, false);
let hash_fn = build_hash_wrapper(env, layout_ids, key_layout);
let eq_fn = build_eq_wrapper(env, layout_ids, key_layout);
let dec_key_fn = build_rc_wrapper(env, layout_ids, key_layout, Mode::Dec);
let dec_value_fn = build_rc_wrapper(env, layout_ids, value_layout, Mode::Dec);
let output_ptr = builder.build_alloca(zig_dict_type, "output_ptr");
call_void_bitcode_fn(
env,
&[
dict1_ptr.into(),
dict2_ptr.into(),
alignment_iv.into(),
key_width.into(),
value_width.into(),
hash_fn.as_global_value().as_pointer_value().into(),
eq_fn.as_global_value().as_pointer_value().into(),
dec_key_fn.as_global_value().as_pointer_value().into(),
dec_value_fn.as_global_value().as_pointer_value().into(),
output_ptr.into(),
],
&bitcode::DICT_INTERSECTION,
);
let output_ptr = env
.builder
.build_bitcast(
output_ptr,
convert::dict(env.context, env.ptr_bytes).ptr_type(AddressSpace::Generic),
"to_roc_dict",
)
.into_pointer_value();
env.builder.build_load(output_ptr, "load_output_ptr")
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn dict_difference<'a, 'ctx, 'env>( pub fn dict_difference<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
@ -652,6 +578,47 @@ pub fn dict_difference<'a, 'ctx, 'env>(
dict2: BasicValueEnum<'ctx>, dict2: BasicValueEnum<'ctx>,
key_layout: &Layout<'a>, key_layout: &Layout<'a>,
value_layout: &Layout<'a>, value_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
dict_intersect_or_difference(
env,
layout_ids,
dict1,
dict2,
key_layout,
value_layout,
&bitcode::DICT_DIFFERENCE,
)
}
#[allow(clippy::too_many_arguments)]
pub fn dict_intersection<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
dict1: BasicValueEnum<'ctx>,
dict2: BasicValueEnum<'ctx>,
key_layout: &Layout<'a>,
value_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
dict_intersect_or_difference(
env,
layout_ids,
dict1,
dict2,
key_layout,
value_layout,
&bitcode::DICT_INTERSECTION,
)
}
#[allow(clippy::too_many_arguments)]
fn dict_intersect_or_difference<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
dict1: BasicValueEnum<'ctx>,
dict2: BasicValueEnum<'ctx>,
key_layout: &Layout<'a>,
value_layout: &Layout<'a>,
op: &str,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let builder = env.builder; let builder = env.builder;
@ -703,7 +670,7 @@ pub fn dict_difference<'a, 'ctx, 'env>(
dec_value_fn.as_global_value().as_pointer_value().into(), dec_value_fn.as_global_value().as_pointer_value().into(),
output_ptr.into(), output_ptr.into(),
], ],
&bitcode::DICT_DIFFERENCE, op,
); );
let output_ptr = env let output_ptr = env

View file

@ -437,4 +437,64 @@ mod gen_dict {
&[i64] &[i64]
); );
} }
#[test]
fn difference() {
assert_evals_to!(
indoc!(
r#"
dict1 : Dict I64 {}
dict1 =
Dict.empty
|> Dict.insert 1 {}
|> Dict.insert 2 {}
|> Dict.insert 3 {}
|> Dict.insert 4 {}
|> Dict.insert 5 {}
dict2 : Dict I64 {}
dict2 =
Dict.empty
|> Dict.insert 0 {}
|> Dict.insert 2 {}
|> Dict.insert 4 {}
Dict.difference dict1 dict2
|> Dict.len
"#
),
3,
i64
);
}
#[test]
fn difference_prefer_first() {
assert_evals_to!(
indoc!(
r#"
dict1 : Dict I64 I64
dict1 =
Dict.empty
|> Dict.insert 1 1
|> Dict.insert 2 2
|> Dict.insert 3 3
|> Dict.insert 4 4
|> Dict.insert 5 5
dict2 : Dict I64 I64
dict2 =
Dict.empty
|> Dict.insert 0 100
|> Dict.insert 2 200
|> Dict.insert 4 300
Dict.difference dict1 dict2
|> Dict.values
"#
),
&[5, 3, 1],
&[i64]
);
}
} }