fix List.map3

This commit is contained in:
Folkert 2021-05-19 16:41:12 +02:00
parent d2b0ecdd04
commit e3b102e0c3
4 changed files with 97 additions and 169 deletions

View file

@ -299,10 +299,35 @@ pub fn listMap2(
}
}
pub fn listMap3(list1: RocList, list2: RocList, list3: RocList, transform: Opaque, caller: Caller3, alignment: usize, a_width: usize, b_width: usize, c_width: usize, d_width: usize, dec_a: Dec, dec_b: Dec, dec_c: Dec) callconv(.C) RocList {
pub fn listMap3(
list1: RocList,
list2: RocList,
list3: RocList,
transform: Opaque,
caller: Caller3,
data: Opaque,
inc_n_data: IncN,
data_is_owned: bool,
alignment: usize,
a_width: usize,
b_width: usize,
c_width: usize,
d_width: usize,
dec_a: Dec,
dec_b: Dec,
dec_c: Dec,
) callconv(.C) RocList {
const smaller_length = std.math.min(list1.len(), list2.len());
const output_length = std.math.min(smaller_length, list3.len());
decrementTail(list1, output_length, a_width, dec_a);
decrementTail(list2, output_length, b_width, dec_b);
decrementTail(list3, output_length, c_width, dec_c);
if (data_is_owned) {
inc_n_data(data, output_length);
}
if (list1.bytes) |source_a| {
if (list2.bytes) |source_b| {
if (list3.bytes) |source_c| {
@ -316,108 +341,17 @@ pub fn listMap3(list1: RocList, list2: RocList, list3: RocList, transform: Opaqu
const element_c = source_c + i * c_width;
const target = target_ptr + i * d_width;
caller(transform, element_a, element_b, element_c, target);
caller(data, element_a, element_b, element_c, target);
}
// if the lists don't have equal length, we must consume the remaining elements
// In this case we consume by (recursively) decrementing the elements
if (list1.len() > output_length) {
i = output_length;
while (i < list1.len()) : (i += 1) {
const element_a = source_a + i * a_width;
dec_a(element_a);
}
}
if (list2.len() > output_length) {
i = output_length;
while (i < list2.len()) : (i += 1) {
const element_b = source_b + i * b_width;
dec_b(element_b);
}
}
if (list3.len() > output_length) {
i = output_length;
while (i < list3.len()) : (i += 1) {
const element_c = source_c + i * c_width;
dec_c(element_c);
}
}
utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width);
utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width);
utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width);
return output;
} else {
// consume list1 elements (we know there is at least one because the list1.bytes pointer is non-null
var i: usize = 0;
while (i < list1.len()) : (i += 1) {
const element_a = source_a + i * a_width;
dec_a(element_a);
}
utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width);
// consume list2 elements (we know there is at least one because the list1.bytes pointer is non-null
i = 0;
while (i < list2.len()) : (i += 1) {
const element_b = source_b + i * b_width;
dec_b(element_b);
}
utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width);
return RocList.empty();
}
} else {
// consume list1 elements (we know there is at least one because the list1.bytes pointer is non-null
var i: usize = 0;
while (i < list1.len()) : (i += 1) {
const element_a = source_a + i * a_width;
dec_a(element_a);
}
utils.decref(std.heap.c_allocator, alignment, list1.bytes, list1.len() * a_width);
// consume list3 elements (if any)
if (list3.bytes) |source_c| {
i = 0;
while (i < list2.len()) : (i += 1) {
const element_c = source_c + i * c_width;
dec_c(element_c);
}
utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width);
}
return RocList.empty();
}
} else {
// consume list2 elements (if any)
if (list2.bytes) |source_b| {
var i: usize = 0;
while (i < list2.len()) : (i += 1) {
const element_b = source_b + i * b_width;
dec_b(element_b);
}
utils.decref(std.heap.c_allocator, alignment, list2.bytes, list2.len() * b_width);
}
// consume list3 elements (if any)
if (list3.bytes) |source_c| {
var i: usize = 0;
while (i < list2.len()) : (i += 1) {
const element_c = source_c + i * c_width;
dec_c(element_c);
}
utils.decref(std.heap.c_allocator, alignment, list3.bytes, list3.len() * c_width);
}
return RocList.empty();
}
}

View file

@ -3787,20 +3787,33 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>(
Layout::Builtin(Builtin::List(_, element1_layout)),
Layout::Builtin(Builtin::List(_, element2_layout)),
Layout::Builtin(Builtin::List(_, element3_layout)),
) => list_map3(
) => {
let argument_layouts =
&[**element1_layout, **element2_layout, **element3_layout];
let roc_function_call = roc_function_call(
env,
layout_ids,
function,
function_layout,
closure,
*closure_layout,
function_owns_closure_data,
argument_layouts,
);
list_map3(
env,
layout_ids,
roc_function_call,
list1,
list2,
list3,
element1_layout,
element2_layout,
element3_layout,
),
return_layout,
)
}
(Layout::Builtin(Builtin::EmptyList), _, _)
| (_, Layout::Builtin(Builtin::EmptyList), _)
| (_, _, Layout::Builtin(Builtin::EmptyList)) => empty_list(env),

View file

@ -893,18 +893,6 @@ pub fn list_map2<'a, 'ctx, 'env>(
element2_layout: &Layout<'a>,
return_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let a_width = env
.ptr_int()
.const_int(element1_layout.stack_size(env.ptr_bytes) as u64, false);
let b_width = env
.ptr_int()
.const_int(element2_layout.stack_size(env.ptr_bytes) as u64, false);
let c_width = env
.ptr_int()
.const_int(return_layout.stack_size(env.ptr_bytes) as u64, false);
let dec_a = build_dec_wrapper(env, layout_ids, element1_layout);
let dec_b = build_dec_wrapper(env, layout_ids, element2_layout);
@ -918,9 +906,9 @@ pub fn list_map2<'a, 'ctx, 'env>(
roc_function_call.inc_n_data.into(),
roc_function_call.data_is_owned.into(),
alignment_intvalue(env, return_layout),
a_width.into(),
b_width.into(),
c_width.into(),
layout_width(env, element1_layout),
layout_width(env, element2_layout),
layout_width(env, return_layout),
dec_a.as_global_value().as_pointer_value().into(),
dec_b.as_global_value().as_pointer_value().into(),
],
@ -931,53 +919,15 @@ pub fn list_map2<'a, 'ctx, 'env>(
pub fn list_map3<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
transform: FunctionValue<'ctx>,
transform_layout: Layout<'a>,
closure_data: BasicValueEnum<'ctx>,
closure_data_layout: Layout<'a>,
roc_function_call: RocFunctionCall<'ctx>,
list1: BasicValueEnum<'ctx>,
list2: BasicValueEnum<'ctx>,
list3: BasicValueEnum<'ctx>,
element1_layout: &Layout<'a>,
element2_layout: &Layout<'a>,
element3_layout: &Layout<'a>,
return_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let return_layout = match transform_layout {
Layout::FunctionPointer(_, ret) => ret,
Layout::Closure(_, _, ret) => ret,
_ => unreachable!("not a callable layout"),
};
let closure_data_ptr = builder.build_alloca(closure_data.get_type(), "closure_data_ptr");
env.builder.build_store(closure_data_ptr, closure_data);
let stepper_caller = build_transform_caller_new(
env,
transform,
closure_data_layout,
&[*element1_layout, *element2_layout, *element3_layout],
)
.as_global_value()
.as_pointer_value();
let a_width = env
.ptr_int()
.const_int(element1_layout.stack_size(env.ptr_bytes) as u64, false);
let b_width = env
.ptr_int()
.const_int(element2_layout.stack_size(env.ptr_bytes) as u64, false);
let c_width = env
.ptr_int()
.const_int(element3_layout.stack_size(env.ptr_bytes) as u64, false);
let d_width = env
.ptr_int()
.const_int(return_layout.stack_size(env.ptr_bytes) as u64, false);
let dec_a = build_dec_wrapper(env, layout_ids, element1_layout);
let dec_b = build_dec_wrapper(env, layout_ids, element2_layout);
let dec_c = build_dec_wrapper(env, layout_ids, element3_layout);
@ -988,13 +938,15 @@ pub fn list_map3<'a, 'ctx, 'env>(
pass_list_as_i128(env, list1),
pass_list_as_i128(env, list2),
pass_list_as_i128(env, list3),
pass_as_opaque(env, closure_data_ptr),
stepper_caller.into(),
alignment_intvalue(env, &transform_layout),
a_width.into(),
b_width.into(),
c_width.into(),
d_width.into(),
roc_function_call.caller.into(),
pass_as_opaque(env, roc_function_call.data),
roc_function_call.inc_n_data.into(),
roc_function_call.data_is_owned.into(),
alignment_intvalue(env, return_layout),
layout_width(env, element1_layout),
layout_width(env, element2_layout),
layout_width(env, element3_layout),
layout_width(env, return_layout),
dec_a.as_global_value().as_pointer_value().into(),
dec_b.as_global_value().as_pointer_value().into(),
dec_c.as_global_value().as_pointer_value().into(),

View file

@ -560,6 +560,35 @@ impl<'a> Context<'a> {
None => unreachable!(),
}
}
roc_module::low_level::LowLevel::ListMap3 => {
match self.param_map.get_symbol(arguments[3], *closure_layout) {
Some(function_ps) => {
let borrows = [
function_ps[0].borrow,
function_ps[1].borrow,
function_ps[2].borrow,
FUNCTION,
CLOSURE_DATA,
];
let b = self.add_dec_after_lowlevel(
arguments,
&borrows,
b,
b_live_vars,
);
let b = decref_if_owned!(function_ps[0].borrow, arguments[0], b);
let b = decref_if_owned!(function_ps[1].borrow, arguments[1], b);
let b = decref_if_owned!(function_ps[2].borrow, arguments[2], b);
let v = create_call!(function_ps.get(3));
&*self.arena.alloc(Stmt::Let(z, v, l, b))
}
None => unreachable!(),
}
}
_ => {
let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op);
let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars);