List.keepIf in zig

This commit is contained in:
Folkert 2021-02-17 16:45:21 +01:00
parent 3093fe9e18
commit 7aceb8dc70
5 changed files with 91 additions and 224 deletions

View file

@ -1210,216 +1210,51 @@ pub fn list_contains_help<'a, 'ctx, 'env>(
pub fn list_keep_if<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
output_inplace: InPlace,
parent: FunctionValue<'ctx>,
func: BasicValueEnum<'ctx>,
func_layout: &Layout<'a>,
transform: BasicValueEnum<'ctx>,
transform_layout: &Layout<'a>,
list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>,
element_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let ctx = env.context;
let wrapper_struct = list.into_struct_value();
let (input_inplace, element_layout) = match list_layout.clone() {
Layout::Builtin(Builtin::EmptyList) => (
InPlace::InPlace,
// this pointer will never actually be dereferenced
Layout::Builtin(Builtin::Int64),
),
Layout::Builtin(Builtin::List(memory_mode, elem_layout)) => (
match memory_mode {
MemoryMode::Unique => InPlace::InPlace,
MemoryMode::Refcounted => InPlace::Clone,
},
elem_layout.clone(),
),
let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
_ => unreachable!("Invalid layout {:?} in List.keepIf", list_layout),
};
let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128");
let list_type = basic_type_from_layout(env.arena, env.context, &list_layout, env.ptr_bytes);
let elem_type = basic_type_from_layout(env.arena, env.context, &element_layout, env.ptr_bytes);
let ptr_type = elem_type.ptr_type(AddressSpace::Generic);
let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr");
env.builder.build_store(transform_ptr, transform);
let list_ptr = load_list_ptr(builder, wrapper_struct, ptr_type);
let length = list_len(builder, list.into_struct_value());
let stepper_caller =
build_transform_caller(env, layout_ids, transform_layout, &[element_layout.clone()])
.as_global_value()
.as_pointer_value();
let zero = ctx.i64_type().const_zero();
let element_width = env
.ptr_int()
.const_int(element_layout.stack_size(env.ptr_bytes) as u64, false);
match input_inplace {
InPlace::InPlace => {
let new_length = list_keep_if_help(
env,
input_inplace,
parent,
length,
list_ptr,
list_ptr,
func,
func_layout,
);
let alignment = element_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
store_list(env, list_ptr, new_length)
}
InPlace::Clone => {
let len_0_block = ctx.append_basic_block(parent, "len_0_block");
let len_n_block = ctx.append_basic_block(parent, "len_n_block");
let cont_block = ctx.append_basic_block(parent, "cont_block");
let output = call_bitcode_fn(
env,
&[
list_i128.into(),
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
alignment_iv.into(),
element_width.into(),
],
&bitcode::LIST_KEEP_IF,
);
let result = builder.build_alloca(list_type, "result");
builder.build_switch(length, len_n_block, &[(zero, len_0_block)]);
// build block for length 0
{
builder.position_at_end(len_0_block);
let new_list = store_list(env, ptr_type.const_zero(), zero);
builder.build_store(result, new_list);
builder.build_unconditional_branch(cont_block);
}
// build block for length > 0
{
builder.position_at_end(len_n_block);
let new_list_ptr = allocate_list(env, output_inplace, &element_layout, length);
let new_length = list_keep_if_help(
env,
InPlace::Clone,
parent,
length,
list_ptr,
new_list_ptr,
func,
func_layout,
);
// store new list pointer there
let new_list = store_list(env, new_list_ptr, new_length);
builder.build_store(result, new_list);
builder.build_unconditional_branch(cont_block);
}
builder.position_at_end(cont_block);
// consume the input list
decrement_refcount_layout(env, parent, layout_ids, list, list_layout);
builder.build_load(result, "load_result")
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn list_keep_if_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
_inplace: InPlace,
parent: FunctionValue<'ctx>,
length: IntValue<'ctx>,
source_ptr: PointerValue<'ctx>,
dest_ptr: PointerValue<'ctx>,
func: BasicValueEnum<'ctx>,
func_layout: &Layout<'a>,
) -> IntValue<'ctx> {
match (func, func_layout) {
(
BasicValueEnum::PointerValue(func_ptr),
Layout::FunctionPointer(_, Layout::Builtin(Builtin::Int1)),
) => {
let builder = env.builder;
let ctx = env.context;
let index_alloca = builder.build_alloca(ctx.i64_type(), "index_alloca");
let next_free_index_alloca =
builder.build_alloca(ctx.i64_type(), "next_free_index_alloca");
builder.build_store(index_alloca, ctx.i64_type().const_zero());
builder.build_store(next_free_index_alloca, ctx.i64_type().const_zero());
// while (length > next_index)
let condition_bb = ctx.append_basic_block(parent, "condition");
builder.build_unconditional_branch(condition_bb);
builder.position_at_end(condition_bb);
let index = builder.build_load(index_alloca, "index").into_int_value();
let condition = builder.build_int_compare(IntPredicate::SGT, length, index, "loopcond");
let body_bb = ctx.append_basic_block(parent, "body");
let cont_bb = ctx.append_basic_block(parent, "cont");
builder.build_conditional_branch(condition, body_bb, cont_bb);
// loop body
builder.position_at_end(body_bb);
let elem_ptr = unsafe { builder.build_in_bounds_gep(source_ptr, &[index], "elem_ptr") };
let elem = builder.build_load(elem_ptr, "load_elem");
let call_site_value =
builder.build_call(func_ptr, env.arena.alloc([elem]), "#keep_if_insert_func");
// set the calling convention explicitly for this call
call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV);
let should_keep = call_site_value
.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
.into_int_value();
let filter_pass_bb = ctx.append_basic_block(parent, "loop");
let after_filter_pass_bb = ctx.append_basic_block(parent, "after_loop");
let one = ctx.i64_type().const_int(1, false);
builder.build_conditional_branch(should_keep, filter_pass_bb, after_filter_pass_bb);
builder.position_at_end(filter_pass_bb);
let next_free_index = builder
.build_load(next_free_index_alloca, "load_next_free")
.into_int_value();
// TODO if next_free_index equals index, and we are mutating in place,
// then maybe we should not write this value back into memory
let dest_elem_ptr = unsafe {
builder.build_in_bounds_gep(dest_ptr, &[next_free_index], "dest_elem_ptr")
};
builder.build_store(dest_elem_ptr, elem);
builder.build_store(
next_free_index_alloca,
builder.build_int_add(next_free_index, one, "incremented_next_free_index"),
);
builder.build_unconditional_branch(after_filter_pass_bb);
builder.position_at_end(after_filter_pass_bb);
builder.build_store(
index_alloca,
builder.build_int_add(index, one, "incremented_index"),
);
builder.build_unconditional_branch(condition_bb);
// continuation
builder.position_at_end(cont_bb);
builder
.build_load(next_free_index_alloca, "new_length")
.into_int_value()
}
_ => unreachable!(
"Invalid function basic value enum or layout for List.keepIf : {:?}",
(func, func_layout)
),
}
complex_bitcast(
env.builder,
output,
collection(env.context, env.ptr_bytes).into(),
"from_i128",
)
}
/// List.map : List before, (before -> after) -> List after