diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index c25ccde9bb..05d1cea755 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -326,6 +326,19 @@ pub fn make_main_function<'a, 'ctx, 'env>( (main_fn_name, env.arena.alloc(main_fn)) } +fn get_inplace_from_layout(layout: &Layout<'_>) -> InPlace { + match layout { + Layout::Builtin(Builtin::EmptyList) => InPlace::InPlace, + Layout::Builtin(Builtin::List(memory_mode, _)) => match memory_mode { + MemoryMode::Unique => InPlace::InPlace, + MemoryMode::Refcounted => InPlace::Clone, + }, + Layout::Builtin(Builtin::EmptyStr) => InPlace::InPlace, + Layout::Builtin(Builtin::Str) => InPlace::Clone, + _ => unreachable!("Layout {:?} does not have an inplace", layout), + } +} + pub fn build_exp_literal<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, literal: &roc_mono::ir::Literal<'a>, @@ -353,7 +366,7 @@ pub fn build_exp_literal<'a, 'ctx, 'env>( let len_type = env.ptr_int(); let len = len_type.const_int(bytes_len, false); - allocate_list(env, &CHAR_LAYOUT, len) + allocate_list(env, InPlace::Clone, &CHAR_LAYOUT, len) // TODO check if malloc returned null; if so, runtime error for OOM! }; @@ -412,6 +425,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, + layout: &Layout<'a>, expr: &roc_mono::ir::Expr<'a>, ) -> BasicValueEnum<'ctx> { use roc_mono::ir::CallType::*; @@ -419,7 +433,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( match expr { Literal(literal) => build_exp_literal(env, literal), - RunLowLevel(op, symbols) => run_low_level(env, scope, parent, *op, symbols), + RunLowLevel(op, symbols) => run_low_level(env, scope, parent, layout, *op, symbols), FunctionCall { call_type: ByName(name), @@ -784,7 +798,11 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( } } EmptyArray => empty_polymorphic_list(env), - Array { elem_layout, elems } => list_literal(env, scope, elem_layout, elems), + Array { elem_layout, elems } => { + let inplace = get_inplace_from_layout(layout); + + list_literal(env, inplace, scope, elem_layout, elems) + } FunctionPointer(symbol, layout) => { let fn_name = layout_ids .get(*symbol, layout) @@ -873,6 +891,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>( fn list_literal<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, scope: &Scope<'a, 'ctx>, elem_layout: &Layout<'a>, elems: &&[Symbol], @@ -888,7 +907,7 @@ fn list_literal<'a, 'ctx, 'env>( let len_type = env.ptr_int(); let len = len_type.const_int(bytes_len, false); - allocate_list(env, elem_layout, len) + allocate_list(env, inplace, elem_layout, len) // TODO check if malloc returned null; if so, runtime error for OOM! }; @@ -946,7 +965,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( Let(symbol, expr, layout, cont) => { let context = &env.context; - let val = build_exp_expr(env, layout_ids, &scope, parent, &expr); + let val = build_exp_expr(env, layout_ids, &scope, parent, layout, &expr); let expr_bt = if let Layout::RecursivePointer = layout { match expr { Expr::AccessAtIndex { field_layouts, .. } => { @@ -1582,6 +1601,7 @@ fn call_intrinsic<'a, 'ctx, 'env>( }) } +#[derive(Copy, Clone)] pub enum InPlace { InPlace, Clone, @@ -1610,6 +1630,7 @@ fn run_low_level<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, scope: &Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, + layout: &Layout<'a>, op: LowLevel, args: &[Symbol], ) -> BasicValueEnum<'ctx> { @@ -1624,7 +1645,9 @@ fn run_low_level<'a, 'ctx, 'env>( let second_str = load_symbol(env, scope, &args[1]); - str_concat(env, parent, first_str, second_str) + let inplace = get_inplace_from_layout(layout); + + str_concat(env, inplace, parent, first_str, second_str) } ListLen => { // List.len : List * -> Int @@ -1640,7 +1663,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (arg, arg_layout) = load_symbol_and_layout(env, scope, &args[0]); - list_single(env, arg, arg_layout) + let inplace = get_inplace_from_layout(layout); + + list_single(env, inplace, arg, arg_layout) } ListRepeat => { // List.repeat : Int, elem -> List elem @@ -1649,7 +1674,9 @@ fn run_low_level<'a, 'ctx, 'env>( let list_len = load_symbol(env, scope, &args[0]).into_int_value(); let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); - list_repeat(env, parent, list_len, elem, elem_layout) + let inplace = get_inplace_from_layout(layout); + + list_repeat(env, inplace, parent, list_len, elem, elem_layout) } ListReverse => { // List.reverse : List elem -> List elem @@ -1657,7 +1684,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(env, scope, &args[0]); - list_reverse(env, parent, InPlace::Clone, list, list_layout) + let inplace = get_inplace_from_layout(layout); + + list_reverse(env, parent, inplace, list, list_layout) } ListConcat => { debug_assert_eq!(args.len(), 2); @@ -1666,7 +1695,9 @@ fn run_low_level<'a, 'ctx, 'env>( let second_list = load_symbol(env, scope, &args[1]); - list_concat(env, parent, first_list, second_list, list_layout) + let inplace = get_inplace_from_layout(layout); + + list_concat(env, inplace, parent, first_list, second_list, list_layout) } ListMap => { // List.map : List before, (before -> after) -> List after @@ -1676,7 +1707,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (func, func_layout) = load_symbol_and_layout(env, scope, &args[1]); - list_map(env, parent, func, func_layout, list, list_layout) + let inplace = get_inplace_from_layout(layout); + + list_map(env, inplace, parent, func, func_layout, list, list_layout) } ListKeepIf => { // List.keepIf : List elem, (elem -> Bool) -> List elem @@ -1686,7 +1719,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (func, func_layout) = load_symbol_and_layout(env, scope, &args[1]); - list_keep_if(env, parent, func, func_layout, list, list_layout) + let inplace = get_inplace_from_layout(layout); + + list_keep_if(env, inplace, parent, func, func_layout, list, list_layout) } ListWalkRight => { // List.walkRight : List elem, (elem -> accum -> accum), accum -> accum @@ -1716,7 +1751,9 @@ fn run_low_level<'a, 'ctx, 'env>( let original_wrapper = load_symbol(env, scope, &args[0]).into_struct_value(); let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); - list_append(env, original_wrapper, elem, elem_layout) + let inplace = get_inplace_from_layout(layout); + + list_append(env, inplace, original_wrapper, elem, elem_layout) } ListPrepend => { // List.prepend : List elem, elem -> List elem @@ -1725,7 +1762,9 @@ fn run_low_level<'a, 'ctx, 'env>( let original_wrapper = load_symbol(env, scope, &args[0]).into_struct_value(); let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); - list_prepend(env, original_wrapper, elem, elem_layout) + let inplace = get_inplace_from_layout(layout); + + list_prepend(env, inplace, original_wrapper, elem, elem_layout) } ListJoin => { // List.join : List (List elem) -> List elem @@ -1733,7 +1772,9 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, outer_list_layout) = load_symbol_and_layout(env, scope, &args[0]); - list_join(env, parent, list, outer_list_layout) + let inplace = get_inplace_from_layout(layout); + + list_join(env, inplace, parent, list, outer_list_layout) } NumAbs | NumNeg | NumRound | NumSqrtUnchecked | NumSin | NumCos | NumToFloat => { debug_assert_eq!(args.len(), 1); @@ -1952,6 +1993,8 @@ fn run_low_level<'a, 'ctx, 'env>( ListSetInPlace => { let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]); + let output_inplace = get_inplace_from_layout(layout); + list_set( parent, &[ @@ -1961,6 +2004,7 @@ fn run_low_level<'a, 'ctx, 'env>( ], env, InPlace::InPlace, + output_inplace, ) } ListSet => { @@ -1972,8 +2016,10 @@ fn run_low_level<'a, 'ctx, 'env>( (load_symbol_and_layout(env, scope, &args[2])), ]; - let in_place = || list_set(parent, arguments, env, InPlace::InPlace); - let clone = || list_set(parent, arguments, env, InPlace::Clone); + let output_inplace = get_inplace_from_layout(layout); + + let in_place = || list_set(parent, arguments, env, InPlace::InPlace, output_inplace); + let clone = || list_set(parent, arguments, env, InPlace::Clone, output_inplace); let empty = || list_symbol; maybe_inplace_list( @@ -2037,6 +2083,7 @@ where /// Str.concat : Str, Str -> Str fn str_concat<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, parent: FunctionValue<'ctx>, first_str: BasicValueEnum<'ctx>, second_str: BasicValueEnum<'ctx>, @@ -2067,6 +2114,7 @@ fn str_concat<'a, 'ctx, 'env>( let (new_wrapper, _) = clone_nonempty_list( env, + inplace, second_str_len, load_list_ptr(builder, second_str_wrapper, ptr_type), &CHAR_LAYOUT, @@ -2094,6 +2142,7 @@ fn str_concat<'a, 'ctx, 'env>( let if_second_str_is_empty = || { let (new_wrapper, _) = clone_nonempty_list( env, + inplace, first_str_len, load_list_ptr(builder, first_str_wrapper, ptr_type), &CHAR_LAYOUT, @@ -2111,7 +2160,7 @@ fn str_concat<'a, 'ctx, 'env>( let combined_str_len = builder.build_int_add(first_str_len, second_str_len, "add_list_lengths"); - let combined_str_ptr = allocate_list(env, &CHAR_LAYOUT, combined_str_len); + let combined_str_ptr = allocate_list(env, inplace, &CHAR_LAYOUT, combined_str_len); // FIRST LOOP let first_str_ptr = load_list_ptr(builder, first_str_wrapper, ptr_type); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 47cc194679..0136bc2eae 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -7,16 +7,10 @@ use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, Str use inkwell::{AddressSpace, IntPredicate}; use roc_mono::layout::{Builtin, Layout, MemoryMode}; -fn get_list_element_type<'a, 'b>(layout: &'b Layout<'a>) -> Option<&'b Layout<'a>> { - match layout { - Layout::Builtin(Builtin::List(_, elem_layout)) => Some(elem_layout), - _ => None, - } -} - /// List.single : a -> List a pub fn list_single<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, elem: BasicValueEnum<'ctx>, elem_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { @@ -25,7 +19,8 @@ pub fn list_single<'a, 'ctx, 'env>( // allocate a list of size 1 on the heap let size = ctx.i64_type().const_int(1, false); - let ptr = allocate_list(env, elem_layout, size); + + let ptr = allocate_list(env, inplace, elem_layout, size); // Put the element into the list let elem_ptr = unsafe { @@ -47,6 +42,7 @@ pub fn list_single<'a, 'ctx, 'env>( /// List.repeat : Int, elem -> List elem pub fn list_repeat<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, parent: FunctionValue<'ctx>, list_len: IntValue<'ctx>, elem: BasicValueEnum<'ctx>, @@ -71,7 +67,7 @@ pub fn list_repeat<'a, 'ctx, 'env>( let build_then = || { // Allocate space for the new array that we'll copy into. - let list_ptr = allocate_list(env, elem_layout, list_len); + let list_ptr = allocate_list(env, inplace, elem_layout, list_len); // TODO check if malloc returned null; if so, runtime error for OOM! let index_name = "#index"; @@ -136,6 +132,7 @@ pub fn list_repeat<'a, 'ctx, 'env>( /// List.prepend List elem, elem -> List elem pub fn list_prepend<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, original_wrapper: StructValue<'ctx>, elem: BasicValueEnum<'ctx>, elem_layout: &Layout<'a>, @@ -157,7 +154,7 @@ pub fn list_prepend<'a, 'ctx, 'env>( ); // Allocate space for the new array that we'll copy into. - let clone_ptr = allocate_list(env, elem_layout, new_list_len); + let clone_ptr = allocate_list(env, inplace, elem_layout, new_list_len); builder.build_store(clone_ptr, elem); @@ -198,6 +195,7 @@ pub fn list_prepend<'a, 'ctx, 'env>( /// List.join : List (List elem) -> List elem pub fn list_join<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, parent: FunctionValue<'ctx>, outer_list: BasicValueEnum<'ctx>, outer_list_layout: &Layout<'a>, @@ -275,7 +273,7 @@ pub fn list_join<'a, 'ctx, 'env>( .build_load(list_len_sum_alloca, list_len_sum_name) .into_int_value(); - let final_list_ptr = allocate_list(env, elem_layout, final_list_sum); + let final_list_ptr = allocate_list(env, inplace, elem_layout, final_list_sum); let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem"); @@ -376,7 +374,7 @@ pub fn list_join<'a, 'ctx, 'env>( pub fn list_reverse_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, - in_place: InPlace, + inplace: InPlace, length: IntValue<'ctx>, source_ptr: PointerValue<'ctx>, dest_ptr: PointerValue<'ctx>, @@ -405,7 +403,7 @@ pub fn list_reverse_help<'a, 'ctx, 'env>( // if updating in-place, then the "middle element" can be left untouched // otherwise, the middle element needs to be copied over from the source to the target - let predicate = match in_place { + let predicate = match inplace { InPlace::InPlace => IntPredicate::SGT, InPlace::Clone => IntPredicate::SGE, }; @@ -429,7 +427,7 @@ pub fn list_reverse_help<'a, 'ctx, 'env>( let high_value = builder.build_load(high_ptr, "load_high"); // swap the two values - if let InPlace::Clone = in_place { + if let InPlace::Clone = inplace { low_ptr = unsafe { builder.build_in_bounds_gep(dest_ptr, &[low], "low_ptr") }; high_ptr = unsafe { builder.build_in_bounds_gep(dest_ptr, &[high], "high_ptr") }; } @@ -450,7 +448,7 @@ pub fn list_reverse_help<'a, 'ctx, 'env>( pub fn list_reverse<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, - in_place: InPlace, + output_inplace: InPlace, list: BasicValueEnum<'ctx>, list_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { @@ -460,12 +458,21 @@ pub fn list_reverse<'a, 'ctx, 'env>( let ctx = env.context; let wrapper_struct = list.into_struct_value(); - let element_layout = match get_list_element_type(list_layout) { - Some(element_layout) => element_layout.clone(), - None => { + let (input_inplace, element_layout) = match list_layout.clone() { + Layout::Builtin(Builtin::EmptyList) => ( + InPlace::InPlace, // this pointer will never actually be dereferenced - Layout::Builtin(Builtin::Int64) - } + Layout::Builtin(Builtin::Int64), + ), + Layout::Builtin(Builtin::List(memory_mode, elem_layout)) => ( + match memory_mode { + MemoryMode::Unique => InPlace::InPlace, + MemoryMode::Refcounted => InPlace::Clone, + }, + elem_layout.clone(), + ), + + _ => unreachable!("Invalid layout {:?} in List.reverse", list_layout), }; let list_type = basic_type_from_layout(env.arena, env.context, &element_layout, env.ptr_bytes); @@ -474,9 +481,9 @@ pub fn list_reverse<'a, 'ctx, 'env>( let list_ptr = load_list_ptr(builder, wrapper_struct, ptr_type); let length = list_len(builder, list.into_struct_value()); - match in_place { + match input_inplace { InPlace::InPlace => { - list_reverse_help(env, parent, in_place, length, list_ptr, list_ptr); + list_reverse_help(env, parent, input_inplace, length, list_ptr, list_ptr); list } @@ -511,7 +518,7 @@ pub fn list_reverse<'a, 'ctx, 'env>( { builder.position_at_end(len_1_block); - let new_list_ptr = clone_list(env, &element_layout, one, list_ptr); + let new_list_ptr = clone_list(env, output_inplace, &element_layout, one, list_ptr); builder.build_store(result, new_list_ptr); builder.build_unconditional_branch(cont_block); @@ -521,7 +528,7 @@ pub fn list_reverse<'a, 'ctx, 'env>( { builder.position_at_end(len_n_block); - let new_list_ptr = allocate_list(env, &element_layout, length); + let new_list_ptr = allocate_list(env, output_inplace, &element_layout, length); list_reverse_help(env, parent, InPlace::Clone, length, list_ptr, new_list_ptr); @@ -573,6 +580,7 @@ pub fn list_get_unsafe<'a, 'ctx, 'env>( /// List.append : List elem, elem -> List elem pub fn list_append<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, original_wrapper: StructValue<'ctx>, elem: BasicValueEnum<'ctx>, elem_layout: &Layout<'a>, @@ -608,7 +616,7 @@ pub fn list_append<'a, 'ctx, 'env>( .build_int_mul(elem_bytes, list_len, "mul_old_len_by_elem_bytes"); // Allocate space for the new array that we'll copy into. - let clone_ptr = allocate_list(env, elem_layout, new_list_len); + let clone_ptr = allocate_list(env, inplace, elem_layout, new_list_len); // TODO check if malloc returned null; if so, runtime error for OOM! @@ -634,7 +642,8 @@ pub fn list_set<'a, 'ctx, 'env>( parent: FunctionValue<'ctx>, args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], env: &Env<'a, 'ctx, 'env>, - in_place: InPlace, + input_inplace: InPlace, + output_inplace: InPlace, ) -> BasicValueEnum<'ctx> { let builder = env.builder; @@ -656,13 +665,15 @@ pub fn list_set<'a, 'ctx, 'env>( let ctx = env.context; let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - let (new_wrapper, array_data_ptr) = match in_place { + + let (new_wrapper, array_data_ptr) = match input_inplace { InPlace::InPlace => ( original_wrapper, load_list_ptr(builder, original_wrapper, ptr_type), ), InPlace::Clone => clone_nonempty_list( env, + output_inplace, list_len, load_list_ptr(builder, original_wrapper, ptr_type), elem_layout, @@ -812,186 +823,221 @@ pub fn list_walk_right<'a, 'ctx, 'env>( /// List.keepIf : List elem, (elem -> Bool) -> List elem pub fn list_keep_if<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + output_inplace: InPlace, parent: FunctionValue<'ctx>, func: BasicValueEnum<'ctx>, func_layout: &Layout<'a>, list: BasicValueEnum<'ctx>, list_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { + use inkwell::types::BasicType; + + let builder = env.builder; + let ctx = env.context; + + let wrapper_struct = list.into_struct_value(); + let (input_inplace, element_layout) = match list_layout.clone() { + Layout::Builtin(Builtin::EmptyList) => ( + InPlace::InPlace, + // this pointer will never actually be dereferenced + Layout::Builtin(Builtin::Int64), + ), + Layout::Builtin(Builtin::List(memory_mode, elem_layout)) => ( + match memory_mode { + MemoryMode::Unique => InPlace::InPlace, + MemoryMode::Refcounted => InPlace::Clone, + }, + elem_layout.clone(), + ), + + _ => unreachable!("Invalid layout {:?} in List.reverse", list_layout), + }; + + let list_type = basic_type_from_layout(env.arena, env.context, &list_layout, env.ptr_bytes); + let elem_type = basic_type_from_layout(env.arena, env.context, &element_layout, env.ptr_bytes); + let ptr_type = elem_type.ptr_type(AddressSpace::Generic); + + let list_ptr = load_list_ptr(builder, wrapper_struct, ptr_type); + let length = list_len(builder, list.into_struct_value()); + + let zero = ctx.i64_type().const_zero(); + + match input_inplace { + InPlace::InPlace => { + let new_length = list_keep_if_help( + env, + input_inplace, + parent, + length, + list_ptr, + list_ptr, + func, + func_layout, + ); + + store_list(env, list_ptr, new_length) + } + InPlace::Clone => { + let len_0_block = ctx.append_basic_block(parent, "len_0_block"); + let len_n_block = ctx.append_basic_block(parent, "len_n_block"); + let cont_block = ctx.append_basic_block(parent, "cont_block"); + + let result = builder.build_alloca(list_type, "result"); + + builder.build_switch(length, len_n_block, &[(zero, len_0_block)]); + + // build block for length 0 + { + builder.position_at_end(len_0_block); + + let new_list = store_list(env, ptr_type.const_zero(), zero); + + builder.build_store(result, new_list); + builder.build_unconditional_branch(cont_block); + } + + // build block for length > 0 + { + builder.position_at_end(len_n_block); + + let new_list_ptr = allocate_list(env, output_inplace, &element_layout, length); + + let new_length = list_keep_if_help( + env, + InPlace::Clone, + parent, + length, + list_ptr, + new_list_ptr, + func, + func_layout, + ); + + // store new list pointer there + let new_list = store_list(env, new_list_ptr, new_length); + + builder.build_store(result, new_list); + builder.build_unconditional_branch(cont_block); + } + + builder.position_at_end(cont_block); + + builder.build_load(result, "load_result") + } + } +} + +#[allow(clippy::too_many_arguments)] +pub fn list_keep_if_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + _inplace: InPlace, + parent: FunctionValue<'ctx>, + length: IntValue<'ctx>, + source_ptr: PointerValue<'ctx>, + dest_ptr: PointerValue<'ctx>, + func: BasicValueEnum<'ctx>, + func_layout: &Layout<'a>, +) -> IntValue<'ctx> { match (func, func_layout) { ( BasicValueEnum::PointerValue(func_ptr), Layout::FunctionPointer(_, Layout::Builtin(Builtin::Int1)), ) => { - let non_empty_fn = |elem_layout: &Layout<'a>, - len: IntValue<'ctx>, - list_wrapper: StructValue<'ctx>| { - let ctx = env.context; - let builder = env.builder; + let builder = env.builder; + let ctx = env.context; - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let elem_ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); + let index_alloca = builder.build_alloca(ctx.i64_type(), "index_alloca"); + let next_free_index_alloca = + builder.build_alloca(ctx.i64_type(), "next_free_index_alloca"); - let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type); + builder.build_store(index_alloca, ctx.i64_type().const_zero()); + builder.build_store(next_free_index_alloca, ctx.i64_type().const_zero()); - let ret_list_len_name = "#ret_list_alloca"; - let ret_list_len_alloca = builder.build_alloca(ctx.i64_type(), ret_list_len_name); - builder.build_store( - ret_list_len_alloca, - ctx.i64_type().const_int(0 as u64, false), - ); + // while (length > next_index) + let condition_bb = ctx.append_basic_block(parent, "condition"); + builder.build_unconditional_branch(condition_bb); + builder.position_at_end(condition_bb); - // Return List Length Loop - // This loop goes through the list and counts how many - // elements pass the filter function `elem -> Bool` - let ret_list_len_loop = |_, elem: BasicValueEnum<'ctx>| { - let call_site_value = builder.build_call( - func_ptr, - env.arena.alloc([elem]), - "#keep_if_count_func", - ); + let index = builder.build_load(index_alloca, "index").into_int_value(); - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); + let condition = builder.build_int_compare(IntPredicate::SGT, length, index, "loopcond"); - let should_keep = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) - .into_int_value(); + let body_bb = ctx.append_basic_block(parent, "body"); + let cont_bb = ctx.append_basic_block(parent, "cont"); + builder.build_conditional_branch(condition, body_bb, cont_bb); - let loop_bb = ctx.append_basic_block(parent, "loop"); - let after_bb = ctx.append_basic_block(parent, "after_loop"); + // loop body + builder.position_at_end(body_bb); - builder.build_conditional_branch(should_keep, loop_bb, after_bb); - builder.position_at_end(loop_bb); + let elem_ptr = unsafe { builder.build_in_bounds_gep(source_ptr, &[index], "elem_ptr") }; - // If the `elem` passes the `elem -> Bool` function - // then increment the return list length variable by 1 - let next_ret_list_len = builder.build_int_add( - builder - .build_load(ret_list_len_alloca, ret_list_len_name) - .into_int_value(), - ctx.i64_type().const_int(1, false), - "next_ret_list_len", - ); + let elem = builder.build_load(elem_ptr, "load_elem"); - // ..and store that incremented length in memory - builder.build_store(ret_list_len_alloca, next_ret_list_len); + let call_site_value = + builder.build_call(func_ptr, env.arena.alloc([elem]), "#keep_if_insert_func"); - builder.build_unconditional_branch(after_bb); - builder.position_at_end(after_bb); - }; + // set the calling convention explicitly for this call + call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - let index_alloca = incrementing_elem_loop( - builder, - ctx, - parent, - list_ptr, - len, - "#index", - ret_list_len_loop, - ); + let should_keep = call_site_value + .try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) + .into_int_value(); - // Reset the index variable to 0. - builder.build_store(index_alloca, ctx.i64_type().const_int(0 as u64, false)); + let filter_pass_bb = ctx.append_basic_block(parent, "loop"); + let after_filter_pass_bb = ctx.append_basic_block(parent, "after_loop"); - let final_ret_list_len = builder - .build_load(ret_list_len_alloca, ret_list_len_name) - .into_int_value(); + let one = ctx.i64_type().const_int(1, false); - // Make a new list, with a length equal to the number - // of `elem` that passed the `elem -> Bool` function. - let ret_list_ptr = allocate_list(env, elem_layout, final_ret_list_len); + builder.build_conditional_branch(should_keep, filter_pass_bb, after_filter_pass_bb); + builder.position_at_end(filter_pass_bb); - // Make a pointer into the return list. This pointer is used - // below to store elements into return list. - let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem"); - // Store this new return list element pointer in memory as the - // pointer to the return list as a whole (`ret_list_ptr`). This - // is kind of a trick to point to the first elem in the list, - // because the pointer to the list is also the pointer to the first - // element. - builder.build_store(dest_elem_ptr_alloca, ret_list_ptr); + let next_free_index = builder + .build_load(next_free_index_alloca, "load_next_free") + .into_int_value(); - // Return List Loop - // This loop goes through the list and adds each - // `elem` only if it passes the `elem -> Bool` function - let ret_list_loop = |_, elem| { - let call_site_value = builder.build_call( - func_ptr, - env.arena.alloc([elem]), - "#keep_if_insert_func", - ); - - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - - let should_keep = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) - .into_int_value(); - - let loop_bb = ctx.append_basic_block(parent, "loop"); - let after_bb = ctx.append_basic_block(parent, "after_loop"); - - builder.build_conditional_branch(should_keep, loop_bb, after_bb); - builder.position_at_end(loop_bb); - - // If the `elem` passes the `elem -> Bool` function - // then load the destination pointer.. - let dest_elem_ptr = builder - .build_load(dest_elem_ptr_alloca, "load_dest_elem_ptr") - .into_pointer_value(); - - // .. save the element into the return list at the - // destination pointer .. - builder.build_store(dest_elem_ptr, elem); - - // .. and then increment the destination pointer by one .. - let inc_dest_elem_ptr = BasicValueEnum::PointerValue(unsafe { - builder.build_in_bounds_gep( - dest_elem_ptr, - &[env.ptr_int().const_int(1 as u64, false)], - "increment_dest_elem", - ) - }); - - // .. and then finally, save the incremented value in memory. - builder.build_store(dest_elem_ptr_alloca, inc_dest_elem_ptr); - - builder.build_unconditional_branch(after_bb); - builder.position_at_end(after_bb); - }; - - incrementing_elem_loop( - builder, - ctx, - parent, - list_ptr, - len, - "#index", - ret_list_loop, - ); - - store_list(env, ret_list_ptr, final_ret_list_len) + // TODO if next_free_index equals index, and we are mutating in place, + // then maybe we should not write this value back into memory + let dest_elem_ptr = unsafe { + builder.build_in_bounds_gep(dest_ptr, &[next_free_index], "dest_elem_ptr") }; - if_list_is_not_empty(env, parent, non_empty_fn, list, list_layout, "List.keepIf") - } - _ => { - unreachable!( - "Invalid function basic value enum or layout for List.keepIf : {:?}", - (func, func_layout) + builder.build_store(dest_elem_ptr, elem); + + builder.build_store( + next_free_index_alloca, + builder.build_int_add(next_free_index, one, "incremented_next_free_index"), ); + + builder.build_unconditional_branch(after_filter_pass_bb); + builder.position_at_end(after_filter_pass_bb); + + builder.build_store( + index_alloca, + builder.build_int_add(index, one, "incremented_index"), + ); + + builder.build_unconditional_branch(condition_bb); + + // continuation + builder.position_at_end(cont_bb); + + builder + .build_load(next_free_index_alloca, "new_length") + .into_int_value() } + _ => unreachable!( + "Invalid function basic value enum or layout for List.keepIf : {:?}", + (func, func_layout) + ), } } /// List.map : List before, (before -> after) -> List after pub fn list_map<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, parent: FunctionValue<'ctx>, func: BasicValueEnum<'ctx>, func_layout: &Layout<'a>, @@ -1006,7 +1052,7 @@ pub fn list_map<'a, 'ctx, 'env>( let ctx = env.context; let builder = env.builder; - let ret_list_ptr = allocate_list(env, ret_elem_layout, len); + let ret_list_ptr = allocate_list(env, inplace, ret_elem_layout, len); let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); @@ -1053,6 +1099,7 @@ pub fn list_map<'a, 'ctx, 'env>( /// List.concat : List elem, List elem -> List elem pub fn list_concat<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, parent: FunctionValue<'ctx>, first_list: BasicValueEnum<'ctx>, second_list: BasicValueEnum<'ctx>, @@ -1094,6 +1141,7 @@ pub fn list_concat<'a, 'ctx, 'env>( let (new_wrapper, _) = clone_nonempty_list( env, + inplace, second_list_len, load_list_ptr(builder, second_list_wrapper, ptr_type), elem_layout, @@ -1121,6 +1169,7 @@ pub fn list_concat<'a, 'ctx, 'env>( let if_second_list_is_empty = || { let (new_wrapper, _) = clone_nonempty_list( env, + inplace, first_list_len, load_list_ptr(builder, first_list_wrapper, ptr_type), elem_layout, @@ -1139,7 +1188,8 @@ pub fn list_concat<'a, 'ctx, 'env>( let combined_list_len = builder.build_int_add(first_list_len, second_list_len, "add_list_lengths"); - let combined_list_ptr = allocate_list(env, elem_layout, combined_list_len); + let combined_list_ptr = + allocate_list(env, inplace, elem_layout, combined_list_len); let first_list_ptr = load_list_ptr(builder, first_list_wrapper, ptr_type); @@ -1539,6 +1589,7 @@ pub fn load_list_ptr<'ctx>( pub fn clone_nonempty_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, list_len: IntValue<'ctx>, elems_ptr: PointerValue<'ctx>, elem_layout: &Layout<'_>, @@ -1556,7 +1607,7 @@ pub fn clone_nonempty_list<'a, 'ctx, 'env>( .build_int_mul(elem_bytes, list_len, "clone_mul_len_by_elem_bytes"); // Allocate space for the new array that we'll copy into. - let clone_ptr = allocate_list(env, elem_layout, list_len); + let clone_ptr = allocate_list(env, inplace, elem_layout, list_len); let int_type = ptr_int(ctx, ptr_bytes); let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr"); @@ -1606,6 +1657,7 @@ pub fn clone_nonempty_list<'a, 'ctx, 'env>( pub fn clone_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + output_inplace: InPlace, elem_layout: &Layout<'a>, length: IntValue<'ctx>, old_ptr: PointerValue<'ctx>, @@ -1614,7 +1666,7 @@ pub fn clone_list<'a, 'ctx, 'env>( let ptr_bytes = env.ptr_bytes; // allocate new empty list (with refcount 1) - let new_ptr = allocate_list(env, elem_layout, length); + let new_ptr = allocate_list(env, output_inplace, elem_layout, length); let stack_size = elem_layout.stack_size(env.ptr_bytes); let bytes = builder.build_int_mul( @@ -1631,6 +1683,7 @@ pub fn clone_list<'a, 'ctx, 'env>( pub fn allocate_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + inplace: InPlace, elem_layout: &Layout<'a>, length: IntValue<'ctx>, ) -> PointerValue<'ctx> { @@ -1683,11 +1736,16 @@ pub fn allocate_list<'a, 'ctx, 'env>( "make ptr", ); - // the refcount of a new list is initially 1 - // we assume that the list is indeed used (dead variables are eliminated) - let ref_count_one = ctx - .i64_type() - .const_int(crate::llvm::refcounting::REFCOUNT_1 as _, false); + let ref_count_one = match inplace { + InPlace::InPlace => length, + InPlace::Clone => { + // the refcount of a new list is initially 1 + // we assume that the list is indeed used (dead variables are eliminated) + ctx.i64_type() + .const_int(crate::llvm::refcounting::REFCOUNT_1 as _, false) + } + }; + builder.build_store(refcount_ptr, ref_count_one); list_element_ptr