List.map in zig!

This commit is contained in:
Folkert 2021-02-17 15:38:55 +01:00
parent 821e750876
commit 42c5662872
6 changed files with 250 additions and 206 deletions

View file

@ -1,15 +1,19 @@
const std = @import("std"); const std = @import("std");
const utils = @import("utils.zig"); const utils = @import("utils.zig");
const mem = std.mem;
const Allocator = mem.Allocator;
const Opaque = ?[*]u8;
pub const RocList = extern struct { pub const RocList = extern struct {
bytes: ?[*]u8, bytes: ?[*]u8,
length: usize, length: usize,
pub fn len(self: RocDict) usize { pub fn len(self: RocList) usize {
return self.length; return self.length;
} }
pub fn isEmpty(self: RocDict) bool { pub fn isEmpty(self: RocList) bool {
return self.len() == 0; return self.len() == 0;
} }
@ -17,7 +21,7 @@ pub const RocList = extern struct {
return RocList{ .bytes = null, .length = 0 }; return RocList{ .bytes = null, .length = 0 };
} }
pub fn isUnique(self: RocDict) bool { pub fn isUnique(self: RocList) bool {
// the empty list is unique (in the sense that copying it will not leak memory) // the empty list is unique (in the sense that copying it will not leak memory)
if (self.isEmpty()) { if (self.isEmpty()) {
return true; return true;
@ -101,3 +105,23 @@ pub const RocList = extern struct {
return result; return result;
} }
}; };
const ListMapCaller = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void;
pub fn listMap(list: RocList, transform: Opaque, caller: ListMapCaller, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList {
if (list.bytes) |source_ptr| {
const size = list.len();
var i: usize = 0;
const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width);
const target_ptr = output.bytes orelse unreachable;
while (i < size) : (i += 1) {
caller(transform, source_ptr + (i * old_element_width), target_ptr + (i * new_element_width));
}
utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width);
return output;
} else {
return RocList.empty();
}
}

View file

@ -2,6 +2,13 @@ const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const testing = std.testing; const testing = std.testing;
// List Module
const list = @import("list.zig");
comptime {
exportListFn(list.listMap, "map");
}
// Dict Module // Dict Module
const dict = @import("dict.zig"); const dict = @import("dict.zig");
const hash = @import("hash.zig"); const hash = @import("hash.zig");
@ -67,6 +74,10 @@ fn exportDictFn(comptime func: anytype, comptime func_name: []const u8) void {
exportBuiltinFn(func, "dict." ++ func_name); exportBuiltinFn(func, "dict." ++ func_name);
} }
fn exportListFn(comptime func: anytype, comptime func_name: []const u8) void {
exportBuiltinFn(func, "list." ++ func_name);
}
// Run all tests in imported modules // Run all tests in imported modules
// https://github.com/ziglang/zig/blob/master/lib/std/std.zig#L94 // https://github.com/ziglang/zig/blob/master/lib/std/std.zig#L94
test "" { test "" {

View file

@ -53,3 +53,5 @@ pub const DICT_INTERSECTION: &str = "roc_builtins.dict.intersection";
pub const DICT_WALK: &str = "roc_builtins.dict.walk"; pub const DICT_WALK: &str = "roc_builtins.dict.walk";
pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list";
pub const LIST_MAP: &str = "roc_builtins.list.map";

View file

@ -3630,18 +3630,16 @@ fn run_low_level<'a, 'ctx, 'env>(
let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); let (func, func_layout) = load_symbol_and_layout(scope, &args[1]);
let inplace = get_inplace_from_layout(layout); match list_layout {
Layout::Builtin(Builtin::EmptyList) => {
list_map( // no elements, so `key` is not in here
env, panic!("key type unknown")
layout_ids, }
inplace, Layout::Builtin(Builtin::List(_, element_layout)) => {
parent, list_map(env, layout_ids, func, func_layout, list, element_layout)
func, }
func_layout, _ => unreachable!("invalid list layout"),
list, }
list_layout,
)
} }
ListKeepIf => { ListKeepIf => {
// List.keepIf : List elem, (elem -> Bool) -> List elem // List.keepIf : List elem, (elem -> Bool) -> List elem

View file

@ -1,5 +1,7 @@
use crate::debug_info_init;
use crate::llvm::build::{ use crate::llvm::build::{
allocate_with_refcount_help, build_num_binop, cast_basic_basic, Env, InPlace, allocate_with_refcount_help, build_num_binop, call_bitcode_fn, cast_basic_basic,
complex_bitcast, set_name, Env, InPlace, FAST_CALL_CONV,
}; };
use crate::llvm::compare::generic_eq; use crate::llvm::compare::generic_eq;
use crate::llvm::convert::{basic_type_from_layout, collection, get_ptr_type}; use crate::llvm::convert::{basic_type_from_layout, collection, get_ptr_type};
@ -7,11 +9,15 @@ use crate::llvm::refcounting::{
decrement_refcount_layout, increment_refcount_layout, refcount_is_one_comparison, decrement_refcount_layout, increment_refcount_layout, refcount_is_one_comparison,
PointerToRefcount, PointerToRefcount,
}; };
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::builder::Builder; use inkwell::builder::Builder;
use inkwell::context::Context; use inkwell::context::Context;
use inkwell::types::BasicType;
use inkwell::types::{BasicTypeEnum, PointerType}; use inkwell::types::{BasicTypeEnum, PointerType};
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue};
use inkwell::{AddressSpace, IntPredicate}; use inkwell::{AddressSpace, IntPredicate};
use roc_builtins::bitcode;
use roc_module::symbol::Symbol;
use roc_mono::layout::{Builtin, Layout, LayoutIds, MemoryMode}; use roc_mono::layout::{Builtin, Layout, LayoutIds, MemoryMode};
/// List.single : a -> List a /// List.single : a -> List a
@ -461,8 +467,6 @@ pub fn list_reverse<'a, 'ctx, 'env>(
list: BasicValueEnum<'ctx>, list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>, list_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use inkwell::types::BasicType;
let builder = env.builder; let builder = env.builder;
let ctx = env.context; let ctx = env.context;
@ -1092,8 +1096,6 @@ pub fn list_contains<'a, 'ctx, 'env>(
list: BasicValueEnum<'ctx>, list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>, list_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use inkwell::types::BasicType;
let builder = env.builder; let builder = env.builder;
let wrapper_struct = list.into_struct_value(); let wrapper_struct = list.into_struct_value();
@ -1215,8 +1217,6 @@ pub fn list_keep_if<'a, 'ctx, 'env>(
list: BasicValueEnum<'ctx>, list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>, list_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use inkwell::types::BasicType;
let builder = env.builder; let builder = env.builder;
let ctx = env.context; let ctx = env.context;
@ -1422,151 +1422,210 @@ pub fn list_keep_if_help<'a, 'ctx, 'env>(
} }
} }
/// List.map : List before, (before -> after) -> List after
macro_rules! list_map_help {
($env:expr, $layout_ids:expr, $inplace:expr, $parent:expr, $func:expr, $func_layout:expr, $list:expr, $list_layout:expr, $function_ptr:expr, $function_return_layout: expr, $closure_info:expr) => {{
let layout_ids = $layout_ids;
let inplace = $inplace;
let parent = $parent;
let func = $func;
let func_layout = $func_layout;
let list = $list;
let list_layout = $list_layout;
let function_ptr = $function_ptr;
let function_return_layout = $function_return_layout;
let closure_info : Option<(&Layout, BasicValueEnum)> = $closure_info;
let non_empty_fn = |elem_layout: &Layout<'a>,
len: IntValue<'ctx>,
list_wrapper: StructValue<'ctx>| {
let ctx = $env.context;
let builder = $env.builder;
let ret_list_ptr = allocate_list($env, inplace, function_return_layout, len);
let elem_type = basic_type_from_layout($env.arena, ctx, elem_layout, $env.ptr_bytes);
let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic);
let list_ptr = load_list_ptr(builder, list_wrapper, ptr_type);
let list_loop = |index, before_elem| {
increment_refcount_layout($env, parent, layout_ids, 1, before_elem, elem_layout);
let arguments = match closure_info {
Some((closure_data_layout, closure_data)) => {
increment_refcount_layout( $env, parent, layout_ids, 1, closure_data, closure_data_layout);
bumpalo::vec![in $env.arena; before_elem, closure_data]
}
None => bumpalo::vec![in $env.arena; before_elem],
};
let call_site_value = builder.build_call(function_ptr, &arguments, "map_func");
// set the calling convention explicitly for this call
call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV);
let after_elem = call_site_value
.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."));
// The pointer to the element in the mapped-over list
let after_elem_ptr = unsafe {
builder.build_in_bounds_gep(ret_list_ptr, &[index], "load_index_after_list")
};
// Mutate the new array in-place to change the element.
builder.build_store(after_elem_ptr, after_elem);
};
incrementing_elem_loop(builder, ctx, parent, list_ptr, len, "#index", list_loop);
let result = store_list($env, ret_list_ptr, len);
// decrement the input list and function (if it's a closure)
decrement_refcount_layout($env, parent, layout_ids, list, list_layout);
decrement_refcount_layout($env, parent, layout_ids, func, func_layout);
if let Some((closure_data_layout, closure_data)) = closure_info {
decrement_refcount_layout( $env, parent, layout_ids, closure_data, closure_data_layout);
}
result
};
if_list_is_not_empty($env, parent, non_empty_fn, list, list_layout, "List.map")
}};
}
/// List.map : List before, (before -> after) -> List after /// List.map : List before, (before -> after) -> List after
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn list_map<'a, 'ctx, 'env>( pub fn list_map<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>, layout_ids: &mut LayoutIds<'a>,
inplace: InPlace, transform: BasicValueEnum<'ctx>,
parent: FunctionValue<'ctx>, transform_layout: &Layout<'a>,
func: BasicValueEnum<'ctx>,
func_layout: &Layout<'a>,
list: BasicValueEnum<'ctx>, list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>, element_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
match (func, func_layout) { let builder = env.builder;
(BasicValueEnum::PointerValue(func_ptr), Layout::FunctionPointer(_, ret_elem_layout)) => {
list_map_help!(
env,
layout_ids,
inplace,
parent,
func,
func_layout,
list,
list_layout,
func_ptr,
ret_elem_layout,
None
)
}
(
BasicValueEnum::StructValue(ptr_and_data),
Layout::Closure(_, closure_layout, ret_elem_layout),
) => {
let builder = env.builder;
let func_ptr = builder let return_layout = match transform_layout {
.build_extract_value(ptr_and_data, 0, "function_ptr") Layout::FunctionPointer(_, ret) => ret,
.unwrap() Layout::Closure(_, _, ret) => ret,
.into_pointer_value(); _ => unreachable!("not a callable layout"),
};
let closure_data = builder let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
.build_extract_value(ptr_and_data, 1, "closure_data")
let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128");
let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr");
env.builder.build_store(transform_ptr, transform);
let stepper_caller =
build_transform_caller(env, layout_ids, transform_layout, &[element_layout.clone()])
.as_global_value()
.as_pointer_value();
let old_element_width = env
.ptr_int()
.const_int(element_layout.stack_size(env.ptr_bytes) as u64, false);
let new_element_width = env
.ptr_int()
.const_int(return_layout.stack_size(env.ptr_bytes) as u64, false);
let alignment = element_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
let output = call_bitcode_fn(
env,
&[
list_i128.into(),
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
alignment_iv.into(),
old_element_width.into(),
new_element_width.into(),
],
&bitcode::LIST_MAP,
);
complex_bitcast(
env.builder,
output,
collection(env.context, env.ptr_bytes).into(),
"from_i128",
)
}
const ARGUMENT_SYMBOLS: [Symbol; 8] = [
Symbol::ARG_1,
Symbol::ARG_2,
Symbol::ARG_3,
Symbol::ARG_4,
Symbol::ARG_5,
Symbol::ARG_6,
Symbol::ARG_7,
Symbol::ARG_8,
];
fn build_transform_caller<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
function_layout: &Layout<'a>,
argument_layouts: &[Layout<'a>],
) -> FunctionValue<'ctx> {
debug_assert!(argument_layouts.len() <= 7);
let block = env.builder.get_insert_block().expect("to be in a function");
let di_location = env.builder.get_current_debug_location().unwrap();
let symbol = Symbol::ZIG_FUNCTION_CALLER;
let fn_name = layout_ids
.get(symbol, &function_layout)
.to_symbol_string(symbol, &env.interns);
let fn_name = format!("transform_caller_{}", fn_name);
let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic);
let function_value = crate::llvm::refcounting::build_header_help(
env,
&fn_name,
env.context.void_type().into(),
&(bumpalo::vec![ in env.arena; BasicTypeEnum::PointerType(arg_type); argument_layouts.len() + 2 ]),
);
let kind_id = Attribute::get_named_enum_kind_id("alwaysinline");
debug_assert!(kind_id > 0);
let attr = env.context.create_enum_attribute(kind_id, 1);
function_value.add_attribute(AttributeLoc::Function, attr);
let entry = env.context.append_basic_block(function_value, "entry");
env.builder.position_at_end(entry);
debug_info_init!(env, function_value);
let mut it = function_value.get_param_iter();
let closure_ptr = it.next().unwrap().into_pointer_value();
set_name(closure_ptr.into(), Symbol::ARG_1.ident_string(&env.interns));
let arguments =
bumpalo::collections::Vec::from_iter_in(it.take(argument_layouts.len()), env.arena);
for (argument, name) in arguments.iter().zip(ARGUMENT_SYMBOLS[1..].iter()) {
set_name(*argument, name.ident_string(&env.interns));
}
let closure_type =
basic_type_from_layout(env.arena, env.context, function_layout, env.ptr_bytes)
.ptr_type(AddressSpace::Generic);
let mut arguments_cast =
bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena);
for (argument_ptr, layout) in arguments.iter().zip(argument_layouts) {
let basic_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes)
.ptr_type(AddressSpace::Generic);
let argument_cast = env
.builder
.build_bitcast(*argument_ptr, basic_type, "load_opaque")
.into_pointer_value();
let argument = env.builder.build_load(argument_cast, "load_opaque");
arguments_cast.push(argument);
}
let closure_cast = env
.builder
.build_bitcast(closure_ptr, closure_type, "load_opaque")
.into_pointer_value();
let fpointer = env.builder.build_load(closure_cast, "load_opaque");
let call = match function_layout {
Layout::FunctionPointer(_, _) => env.builder.build_call(
fpointer.into_pointer_value(),
arguments_cast.as_slice(),
"tmp",
),
Layout::Closure(_, _, _) | Layout::Struct(_) => {
let pair = fpointer.into_struct_value();
let fpointer = env
.builder
.build_extract_value(pair, 0, "get_fpointer")
.unwrap(); .unwrap();
let closure_data_layout = closure_layout.as_block_of_memory_layout(); let closure_data = env
.builder
.build_extract_value(pair, 1, "get_closure_data")
.unwrap();
list_map_help!( arguments_cast.push(closure_data);
env, env.builder.build_call(
layout_ids, fpointer.into_pointer_value(),
inplace, arguments_cast.as_slice(),
parent, "tmp",
func,
func_layout,
list,
list_layout,
func_ptr,
ret_elem_layout,
Some((&closure_data_layout, closure_data))
) )
} }
_ => { _ => unreachable!("layout is not callable {:?}", function_layout),
unreachable!( };
"Invalid function basic value enum or layout for List.map : {:?}", call.set_call_convention(FAST_CALL_CONV);
(func, func_layout)
); let result = call
} .try_as_basic_value()
} .left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."));
let result_u8_ptr = function_value
.get_nth_param(argument_layouts.len() as u32 + 1)
.unwrap();
let result_ptr = env
.builder
.build_bitcast(
result_u8_ptr,
result.get_type().ptr_type(AddressSpace::Generic),
"write_result",
)
.into_pointer_value();
env.builder.build_store(result_ptr, result);
env.builder.build_return(None);
env.builder.position_at_end(block);
env.builder
.set_current_debug_location(env.context, di_location);
function_value
} }
/// List.concat : List elem, List elem -> List elem /// List.concat : List elem, List elem -> List elem
@ -1910,59 +1969,6 @@ where
index_alloca index_alloca
} }
// This function checks if the list is empty, and
// if it is, it returns an empty list, and if not
// it runs whatever code is passed in under `build_non_empty`
// This is the avoid allocating memory if the list is empty.
fn if_list_is_not_empty<'a, 'ctx, 'env, 'b, NonEmptyFn>(
env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>,
mut build_non_empty: NonEmptyFn,
list: BasicValueEnum<'ctx>,
list_layout: &Layout<'a>,
list_fn_name: &str,
) -> BasicValueEnum<'ctx>
where
NonEmptyFn: FnMut(&Layout<'a>, IntValue<'ctx>, StructValue<'ctx>) -> BasicValueEnum<'ctx>,
{
match list_layout {
Layout::Builtin(Builtin::EmptyList) => empty_list(env),
Layout::Builtin(Builtin::List(_, elem_layout)) => {
let builder = env.builder;
let ctx = env.context;
let wrapper_struct = list.into_struct_value();
let len = list_len(builder, wrapper_struct);
// list_len > 0
let comparison = builder.build_int_compare(
IntPredicate::UGT,
len,
ctx.i64_type().const_int(0, false),
"greaterthanzero",
);
let build_empty = || empty_list(env);
let struct_type = collection(ctx, env.ptr_bytes);
build_basic_phi2(
env,
parent,
comparison,
|| build_non_empty(elem_layout, len, wrapper_struct),
build_empty,
BasicTypeEnum::StructType(struct_type),
)
}
_ => {
unreachable!("Invalid List layout for {} {:?}", list_fn_name, list_layout);
}
}
}
pub fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>( pub fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,

View file

@ -753,6 +753,9 @@ define_builtins! {
// a user-defined function that we need to capture in a closure // a user-defined function that we need to capture in a closure
// see e.g. Set.walk // see e.g. Set.walk
19 USER_FUNCTION: "#user_function" 19 USER_FUNCTION: "#user_function"
// A caller (wrapper) that we pass to zig for it to be able to call Roc functions
20 ZIG_FUNCTION_CALLER: "#zig_function_caller"
} }
1 NUM: "Num" => { 1 NUM: "Num" => {
0 NUM_NUM: "Num" imported // the Num.Num type alias 0 NUM_NUM: "Num" imported // the Num.Num type alias