Merge pull request #1049 from rtfeldman/map2

List.map2
This commit is contained in:
Richard Feldman 2021-03-05 22:55:12 -05:00 committed by GitHub
commit 1211fa93f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 230 additions and 2 deletions

View file

@ -153,6 +153,66 @@ pub fn listMapWithIndex(list: RocList, transform: Opaque, caller: Caller2, align
}
}
pub fn listMap2(list1: RocList, list2: RocList, transform: Opaque, caller: Caller2, alignment: usize, a_width: usize, b_width: usize, c_width: usize, dec_a: Dec, dec_b: Dec) callconv(.C) RocList {
const output_length = std.math.min(list1.len(), list2.len());
if (list1.bytes) |source_a| {
if (list2.bytes) |source_b| {
const output = RocList.allocate(std.heap.c_allocator, alignment, output_length, c_width);
const target_ptr = output.bytes orelse unreachable;
var i: usize = 0;
while (i < output_length) : (i += 1) {
const element_a = source_a + i * a_width;
const element_b = source_b + i * b_width;
const target = target_ptr + i * c_width;
caller(transform, element_a, element_b, target);
}
// if the lists don't have equal length, we must consume the remaining elements
// In this case we consume by (recursively) decrementing the elements
if (list1.len() > output_length) {
while (i < list1.len()) : (i += 1) {
const element_a = source_a + i * a_width;
dec_a(element_a);
}
} else if (list2.len() > output_length) {
while (i < list2.len()) : (i += 1) {
const element_b = source_b + i * b_width;
dec_b(element_b);
}
}
utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width);
utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width);
return output;
} else {
// consume list1 elements (we know there is at least one because the list1.bytes pointer is non-null
var i: usize = 0;
while (i < list1.len()) : (i += 1) {
const element_a = source_a + i * a_width;
dec_a(element_a);
}
utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width);
return RocList.empty();
}
} else {
// consume list2 elements (if any)
if (list2.bytes) |source_b| {
var i: usize = 0;
while (i < list2.len()) : (i += 1) {
const element_b = source_b + i * b_width;
dec_b(element_b);
}
utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width);
}
return RocList.empty();
}
}
pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize, inc: Inc, dec: Dec) callconv(.C) RocList {
if (list.bytes) |source_ptr| {
const size = list.len();

View file

@ -7,6 +7,7 @@ const list = @import("list.zig");
comptime {
exportListFn(list.listMap, "map");
exportListFn(list.listMap2, "map2");
exportListFn(list.listMapWithIndex, "map_with_index");
exportListFn(list.listKeepIf, "keep_if");
exportListFn(list.listWalk, "walk");

View file

@ -63,6 +63,7 @@ pub const DICT_WALK: &str = "roc_builtins.dict.walk";
pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list";
pub const LIST_MAP: &str = "roc_builtins.list.map";
pub const LIST_MAP2: &str = "roc_builtins.list.map2";
pub const LIST_MAP_WITH_INDEX: &str = "roc_builtins.list.map_with_index";
pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if";
pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks";

View file

@ -804,6 +804,19 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
)
});
// map2 : List a, List b, (a, b -> c) -> List c
add_type(Symbol::LIST_MAP2, {
let_tvars! {a, b, c, cvar};
top_level_function(
vec![
list_type(flex(a)),
list_type(flex(b)),
closure(vec![flex(a), flex(b)], cvar, Box::new(flex(c))),
],
Box::new(list_type(flex(c))),
)
});
// append : List elem, elem -> List elem
add_type(
Symbol::LIST_APPEND,

View file

@ -80,6 +80,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
LIST_PREPEND => list_prepend,
LIST_JOIN => list_join,
LIST_MAP => list_map,
LIST_MAP2 => list_map2,
LIST_MAP_WITH_INDEX => list_map_with_index,
LIST_KEEP_IF => list_keep_if,
LIST_KEEP_OKS => list_keep_oks,
@ -215,6 +216,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::LIST_PREPEND => list_prepend,
Symbol::LIST_JOIN => list_join,
Symbol::LIST_MAP => list_map,
Symbol::LIST_MAP2 => list_map2,
Symbol::LIST_MAP_WITH_INDEX => list_map_with_index,
Symbol::LIST_KEEP_IF => list_keep_if,
Symbol::LIST_KEEP_OKS => list_keep_oks,
@ -2113,6 +2115,11 @@ fn list_map_with_index(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_2(symbol, LowLevel::ListMapWithIndex, var_store)
}
/// List.map2 : List a, List b, (a, b -> c) -> List c
fn list_map2(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_3(symbol, LowLevel::ListMap2, var_store)
}
/// Dict.hashTestOnly : k, v -> Nat
pub fn dict_hash_test_only(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_2(symbol, LowLevel::Hash, var_store)

View file

@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash;
use crate::llvm::build_list::{
allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains,
list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map,
list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, list_single, list_sum,
list_walk, list_walk_backwards,
list_map2, list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, list_single,
list_sum, list_walk, list_walk_backwards,
};
use crate::llvm::build_str::{
str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8,
@ -3719,6 +3719,33 @@ fn run_low_level<'a, 'ctx, 'env>(
_ => unreachable!("invalid list layout"),
}
}
ListMap2 => {
debug_assert_eq!(args.len(), 3);
let (list1, list1_layout) = load_symbol_and_layout(scope, &args[0]);
let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]);
let (func, func_layout) = load_symbol_and_layout(scope, &args[2]);
match (list1_layout, list2_layout) {
(
Layout::Builtin(Builtin::List(_, element1_layout)),
Layout::Builtin(Builtin::List(_, element2_layout)),
) => list_map2(
env,
layout_ids,
func,
func_layout,
list1,
list2,
element1_layout,
element2_layout,
),
(Layout::Builtin(Builtin::EmptyList), _)
| (_, Layout::Builtin(Builtin::EmptyList)) => empty_list(env),
_ => unreachable!("invalid list layout"),
}
}
ListMapWithIndex => {
// List.map : List before, (before -> after) -> List after
debug_assert_eq!(args.len(), 2);

View file

@ -1218,6 +1218,93 @@ fn list_map_generic<'a, 'ctx, 'env>(
)
}
pub fn list_map2<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
transform: BasicValueEnum<'ctx>,
transform_layout: &Layout<'a>,
list1: BasicValueEnum<'ctx>,
list2: BasicValueEnum<'ctx>,
element1_layout: &Layout<'a>,
element2_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let return_layout = match transform_layout {
Layout::FunctionPointer(_, ret) => ret,
Layout::Closure(_, _, ret) => ret,
_ => unreachable!("not a callable layout"),
};
let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
let list1_i128 = complex_bitcast(
env.builder,
list1,
env.context.i128_type().into(),
"to_i128",
);
let list2_i128 = complex_bitcast(
env.builder,
list2,
env.context.i128_type().into(),
"to_i128",
);
let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr");
env.builder.build_store(transform_ptr, transform);
let argument_layouts = [element1_layout.clone(), element2_layout.clone()];
let stepper_caller =
build_transform_caller(env, layout_ids, transform_layout, &argument_layouts)
.as_global_value()
.as_pointer_value();
let a_width = env
.ptr_int()
.const_int(element1_layout.stack_size(env.ptr_bytes) as u64, false);
let b_width = env
.ptr_int()
.const_int(element2_layout.stack_size(env.ptr_bytes) as u64, false);
let c_width = env
.ptr_int()
.const_int(return_layout.stack_size(env.ptr_bytes) as u64, false);
let alignment = return_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
let dec_a = build_dec_wrapper(env, layout_ids, element1_layout);
let dec_b = build_dec_wrapper(env, layout_ids, element2_layout);
let output = call_bitcode_fn(
env,
&[
list1_i128,
list2_i128,
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
alignment_iv.into(),
a_width.into(),
b_width.into(),
c_width.into(),
dec_a.as_global_value().as_pointer_value().into(),
dec_b.as_global_value().as_pointer_value().into(),
],
bitcode::LIST_MAP2,
);
complex_bitcast(
env.builder,
output,
collection(env.context, env.ptr_bytes).into(),
"from_i128",
)
}
/// List.concat : List elem, List elem -> List elem
pub fn list_concat<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,

View file

@ -27,6 +27,7 @@ pub enum LowLevel {
ListPrepend,
ListJoin,
ListMap,
ListMap2,
ListMapWithIndex,
ListKeepIf,
ListWalk,

View file

@ -909,6 +909,7 @@ define_builtins! {
21 LIST_KEEP_OKS: "keepOks"
22 LIST_KEEP_ERRS: "keepErrs"
23 LIST_MAP_WITH_INDEX: "mapWithIndex"
24 LIST_MAP2: "map2"
}
5 RESULT: "Result" => {
0 RESULT_RESULT: "Result" imported // the Result.Result type alias

View file

@ -651,6 +651,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
StrJoinWith => arena.alloc_slice_copy(&[borrowed, borrowed]),
ListJoin => arena.alloc_slice_copy(&[irrelevant]),
ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]),
ListMap2 => arena.alloc_slice_copy(&[owned, owned, irrelevant]),
ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]),
ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]),
ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]),

View file

@ -568,6 +568,35 @@ fn list_map_closure() {
);
}
#[test]
fn list_map2_pair() {
assert_evals_to!(
indoc!(
r#"
List.map2 [1,2,3] [3,2,1] (\a,b -> Pair a b)
"#
),
RocList::from_slice(&[(1, 3), (2, 2), (3, 1)]),
RocList<(i64, i64)>
);
}
#[test]
fn list_map2_different_lengths() {
assert_evals_to!(
indoc!(
r#"
List.map2
["a", "b", "lllllllllllllongnggg" ]
["b"]
Str.concat
"#
),
RocList::from_slice(&[RocStr::from_slice("ab".as_bytes()),]),
RocList<RocStr>
);
}
#[test]
fn list_join_empty_list() {
assert_evals_to!("List.join []", RocList::from_slice(&[]), RocList<i64>);