diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index b43e329405..1791a8f61b 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -11,6 +11,7 @@ const CompareFn = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) u8; const Opaque = ?[*]u8; const Inc = fn (?[*]u8) callconv(.C) void; +const IncN = fn (?[*]u8, usize) callconv(.C) void; const Dec = fn (?[*]u8) callconv(.C) void; pub const RocList = extern struct { @@ -615,7 +616,7 @@ pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) c return false; } -pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width: usize, inc_n_element: Inc) callconv(.C) RocList { +pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width: usize, inc_n_element: IncN) callconv(.C) RocList { if (count == 0) { return RocList.empty(); } @@ -624,18 +625,15 @@ pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width var output = RocList.allocate(allocator, alignment, count, element_width); if (output.bytes) |target_ptr| { + // increment the element's RC N times + inc_n_element(element, count); + var i: usize = 0; const source = element orelse unreachable; while (i < count) : (i += 1) { @memcpy(target_ptr + i * element_width, source, element_width); } - // TODO do all increments at once! - i = 0; - while (i < count) : (i += 1) { - inc_n_element(element); - } - return output; } else { unreachable; diff --git a/compiler/builtins/bitcode/src/utils.zig b/compiler/builtins/bitcode/src/utils.zig index f36eee65d5..ab3e6951f7 100644 --- a/compiler/builtins/bitcode/src/utils.zig +++ b/compiler/builtins/bitcode/src/utils.zig @@ -1,6 +1,10 @@ const std = @import("std"); const Allocator = std.mem.Allocator; +pub const Inc = fn (?[*]u8) callconv(.C) void; +pub const IncN = fn (?[*]u8, u64) callconv(.C) void; +pub const Dec = fn (?[*]u8) callconv(.C) void; + const REFCOUNT_MAX_ISIZE: comptime isize = 0; const REFCOUNT_ONE_ISIZE: comptime isize = std.math.minInt(isize); pub const REFCOUNT_ONE: usize = @bitCast(usize, REFCOUNT_ONE_ISIZE); diff --git a/compiler/gen/src/llvm/bitcode.rs b/compiler/gen/src/llvm/bitcode.rs index b80f51375e..8545e8770a 100644 --- a/compiler/gen/src/llvm/bitcode.rs +++ b/compiler/gen/src/llvm/bitcode.rs @@ -2,7 +2,9 @@ use crate::debug_info_init; use crate::llvm::build::{set_name, Env, C_CALL_CONV, FAST_CALL_CONV}; use crate::llvm::convert::basic_type_from_layout; -use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout, Mode}; +use crate::llvm::refcounting::{ + decrement_refcount_layout, increment_n_refcount_layout, increment_refcount_layout, +}; use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::types::{BasicType, BasicTypeEnum}; use inkwell::values::{BasicValueEnum, CallSiteValue, FunctionValue, InstructionValue}; @@ -204,21 +206,28 @@ fn build_transform_caller_help<'a, 'ctx, 'env>( function_value } +enum Mode { + Inc, + IncN, + Dec, +} + +/// a functin that accepts two arguments: the value to increment, and an amount to increment by pub fn build_inc_n_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, - n: u64, ) -> FunctionValue<'ctx> { - build_rc_wrapper(env, layout_ids, layout, Mode::Inc(n)) + build_rc_wrapper(env, layout_ids, layout, Mode::IncN) } +/// a functin that accepts two arguments: the value to increment; increments by 1 pub fn build_inc_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, ) -> FunctionValue<'ctx> { - build_rc_wrapper(env, layout_ids, layout, Mode::Inc(1)) + build_rc_wrapper(env, layout_ids, layout, Mode::Inc) } pub fn build_dec_wrapper<'a, 'ctx, 'env>( @@ -229,7 +238,7 @@ pub fn build_dec_wrapper<'a, 'ctx, 'env>( build_rc_wrapper(env, layout_ids, layout, Mode::Dec) } -pub fn build_rc_wrapper<'a, 'ctx, 'env>( +fn build_rc_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, @@ -244,7 +253,8 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( .to_symbol_string(symbol, &env.interns); let fn_name = match rc_operation { - Mode::Inc(n) => format!("{}_inc_{}", fn_name, n), + Mode::IncN => format!("{}_inc_n", fn_name), + Mode::Inc => format!("{}_inc", fn_name), Mode::Dec => format!("{}_dec", fn_name), }; @@ -253,12 +263,20 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( None => { let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); - let function_value = crate::llvm::refcounting::build_header_help( - env, - &fn_name, - env.context.void_type().into(), - &[arg_type.into()], - ); + let function_value = match rc_operation { + Mode::Inc | Mode::Dec => crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &[arg_type.into()], + ), + Mode::IncN => crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &[arg_type.into(), env.ptr_int().into()], + ), + }; let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); debug_assert!(kind_id > 0); @@ -285,9 +303,16 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( let value = env.builder.build_load(value_cast, "load_opaque"); match rc_operation { - Mode::Inc(n) => { + Mode::Inc => { + let n = 1; increment_refcount_layout(env, function_value, layout_ids, n, value, layout); } + Mode::IncN => { + let n = it.next().unwrap().into_int_value(); + set_name(n.into(), Symbol::ARG_2.ident_string(&env.interns)); + + increment_n_refcount_layout(env, function_value, layout_ids, n, value, layout); + } Mode::Dec => { decrement_refcount_layout(env, function_value, layout_ids, value, layout); } diff --git a/compiler/gen/src/llvm/build_dict.rs b/compiler/gen/src/llvm/build_dict.rs index 7ef54fd055..b282f77ddd 100644 --- a/compiler/gen/src/llvm/build_dict.rs +++ b/compiler/gen/src/llvm/build_dict.rs @@ -397,9 +397,16 @@ pub fn dict_elements_rc<'a, 'ctx, 'env>( let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); - use crate::llvm::bitcode::build_rc_wrapper; - let inc_key_fn = build_rc_wrapper(env, layout_ids, key_layout, rc_operation); - let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, rc_operation); + let (key_fn, value_fn) = match rc_operation { + Mode::Inc => ( + build_inc_wrapper(env, layout_ids, key_layout), + build_inc_wrapper(env, layout_ids, value_layout), + ), + Mode::Dec => ( + build_dec_wrapper(env, layout_ids, key_layout), + build_dec_wrapper(env, layout_ids, value_layout), + ), + }; call_void_bitcode_fn( env, @@ -408,8 +415,8 @@ pub fn dict_elements_rc<'a, 'ctx, 'env>( alignment_iv.into(), key_width.into(), value_width.into(), - inc_key_fn.as_global_value().as_pointer_value().into(), - inc_value_fn.as_global_value().as_pointer_value().into(), + key_fn.as_global_value().as_pointer_value().into(), + value_fn.as_global_value().as_pointer_value().into(), ], &bitcode::DICT_ELEMENTS_RC, ); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index b3a759e641..6634c51862 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use crate::llvm::bitcode::{ - build_compare_wrapper, build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, - build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, + build_compare_wrapper, build_dec_wrapper, build_eq_wrapper, build_inc_n_wrapper, + build_inc_wrapper, build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, }; use crate::llvm::build::{ allocate_with_refcount_help, cast_basic_basic, complex_bitcast, Env, InPlace, @@ -118,7 +118,7 @@ pub fn list_repeat<'a, 'ctx, 'env>( element: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout); + let inc_element_fn = build_inc_n_wrapper(env, layout_ids, element_layout); call_bitcode_fn_returns_list( env, diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index ad6c1d24de..a03398322e 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -109,7 +109,7 @@ impl<'ctx> PointerToRefcount<'ctx> { env: &Env<'a, 'ctx, 'env>, ) { match mode { - CallMode::Inc(_, inc_amount) => self.increment(inc_amount, env), + CallMode::Inc(inc_amount) => self.increment(inc_amount, env), CallMode::Dec => self.decrement(env, layout), } } @@ -315,14 +315,85 @@ impl<'ctx> PointerToRefcount<'ctx> { fn modify_refcount_struct<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - value: BasicValueEnum<'ctx>, - layouts: &[Layout<'a>], + layouts: &'a [Layout<'a>], mode: Mode, when_recursive: &WhenRecursive<'a>, +) -> FunctionValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let layout = Layout::Struct(layouts); + + let (_, fn_name) = function_name_from_mode( + layout_ids, + &env.interns, + "increment_struct", + "decrement_struct", + &layout, + mode, + ); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let basic_type = basic_type_from_layout(env, &layout); + let function_value = build_header(env, basic_type, mode, &fn_name); + + modify_refcount_struct_help( + env, + layout_ids, + mode, + when_recursive, + layouts, + function_value, + ); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function +} + +#[allow(clippy::too_many_arguments)] +fn modify_refcount_struct_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + mode: Mode, + when_recursive: &WhenRecursive<'a>, + layouts: &[Layout<'a>], + fn_val: FunctionValue<'ctx>, ) { - let wrapper_struct = value.into_struct_value(); + debug_assert_eq!( + when_recursive, + &WhenRecursive::Unreachable, + "TODO pipe when_recursive through the dict key/value inc/dec" + ); + + let builder = env.builder; + let ctx = env.context; + + // Add a basic block for the entry point + let entry = ctx.append_basic_block(fn_val, "entry"); + + builder.position_at_end(entry); + + debug_info_init!(env, fn_val); + + // Add args to scope + let arg_symbol = Symbol::ARG_1; + let arg_val = fn_val.get_param_iter().next().unwrap(); + + set_name(arg_val, arg_symbol.ident_string(&env.interns)); + + let parent = fn_val; + + let wrapper_struct = arg_val.into_struct_value(); for (i, field_layout) in layouts.iter().enumerate() { if field_layout.contains_refcounted() { @@ -335,13 +406,15 @@ fn modify_refcount_struct<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field_ptr, field_layout, ); } } + // this function returns void + builder.build_return(None); } pub fn increment_refcount_layout<'a, 'ctx, 'env>( @@ -351,12 +424,24 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( inc_amount: u64, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, +) { + let amount = env.ptr_int().const_int(inc_amount, false); + increment_n_refcount_layout(env, parent, layout_ids, amount, value, layout); +} + +pub fn increment_n_refcount_layout<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + amount: IntValue<'ctx>, + value: BasicValueEnum<'ctx>, + layout: &Layout<'a>, ) { modify_refcount_layout( env, parent, layout_ids, - Mode::Inc(inc_amount), + CallMode::Inc(amount), value, layout, ); @@ -369,7 +454,7 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { - modify_refcount_layout(env, parent, layout_ids, Mode::Dec, value, layout); + modify_refcount_layout(env, parent, layout_ids, CallMode::Dec, value, layout); } fn modify_refcount_builtin<'a, 'ctx, 'env>( @@ -377,37 +462,33 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, mode: Mode, when_recursive: &WhenRecursive<'a>, - value: BasicValueEnum<'ctx>, layout: &Layout<'a>, builtin: &Builtin<'a>, -) { +) -> Option> { use Builtin::*; match builtin { List(memory_mode, element_layout) => { - let wrapper_struct = value.into_struct_value(); - if let MemoryMode::Refcounted = memory_mode { - modify_refcount_list( + let function = modify_refcount_list( env, layout_ids, mode, when_recursive, layout, element_layout, - wrapper_struct, ); + + Some(function) + } else { + None } } Set(element_layout) => { - if element_layout.contains_refcounted() { - // TODO decrement all values - } - todo!(); - } - Dict(key_layout, value_layout) => { - let wrapper_struct = value.into_struct_value(); - modify_refcount_dict( + let key_layout = &Layout::Struct(&[]); + let value_layout = element_layout; + + let function = modify_refcount_dict( env, layout_ids, mode, @@ -415,16 +496,29 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>( layout, key_layout, value_layout, - wrapper_struct, ); + + Some(function) + } + Dict(key_layout, value_layout) => { + let function = modify_refcount_dict( + env, + layout_ids, + mode, + when_recursive, + layout, + key_layout, + value_layout, + ); + + Some(function) } - Str => { - let wrapper_struct = value.into_struct_value(); - modify_refcount_str(env, layout_ids, mode, layout, wrapper_struct); - } + Str => Some(modify_refcount_str(env, layout_ids, mode, layout)), + _ => { debug_assert!(!builtin.is_refcounted()); + None } } } @@ -433,7 +527,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - mode: Mode, + call_mode: CallMode<'ctx>, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { @@ -441,7 +535,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + call_mode, &WhenRecursive::Unreachable, value, layout, @@ -458,127 +552,29 @@ fn modify_refcount_layout_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - mode: Mode, + call_mode: CallMode<'ctx>, when_recursive: &WhenRecursive<'a>, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { - use Layout::*; + let mode = match call_mode { + CallMode::Inc(_) => Mode::Inc, + CallMode::Dec => Mode::Dec, + }; + + let function = match modify_refcount_layout_build_function( + env, + parent, + layout_ids, + mode, + when_recursive, + layout, + ) { + Some(f) => f, + None => return, + }; match layout { - Builtin(builtin) => modify_refcount_builtin( - env, - layout_ids, - mode, - when_recursive, - value, - layout, - builtin, - ), - - Union(variant) => { - use UnionLayout::*; - - match variant { - NullableWrapped { - other_tags: tags, .. - } => { - debug_assert!(value.is_pointer_value()); - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - tags, - value.into_pointer_value(), - true, - ); - } - - NullableUnwrapped { other_fields, .. } => { - debug_assert!(value.is_pointer_value()); - - let other_fields = &other_fields[1..]; - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - &*env.arena.alloc([other_fields]), - value.into_pointer_value(), - true, - ); - } - - NonNullableUnwrapped(fields) => { - debug_assert!(value.is_pointer_value()); - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - &*env.arena.alloc([*fields]), - value.into_pointer_value(), - true, - ); - } - - Recursive(tags) => { - debug_assert!(value.is_pointer_value()); - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - tags, - value.into_pointer_value(), - false, - ); - } - - NonRecursive(tags) => { - modify_refcount_union(env, layout_ids, mode, when_recursive, tags, value) - } - } - } - Closure(_, closure_layout, _) => { - if closure_layout.contains_refcounted() { - let wrapper_struct = value.into_struct_value(); - - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, 1, "modify_rc_closure_data") - .unwrap(); - - modify_refcount_layout_help( - env, - parent, - layout_ids, - mode, - when_recursive, - field_ptr, - &closure_layout.as_block_of_memory_layout(), - ) - } - } - - Struct(layouts) => { - modify_refcount_struct( - env, - parent, - layout_ids, - value, - layouts, - mode, - when_recursive, - ); - } - - PhantomEmptyStruct => {} - Layout::RecursivePointer => match when_recursive { WhenRecursive::Unreachable => { unreachable!("recursion pointers should never be hashed directly") @@ -594,19 +590,172 @@ fn modify_refcount_layout_help<'a, 'ctx, 'env>( .build_bitcast(value, bt, "i64_to_opaque") .into_pointer_value(); - modify_refcount_layout_help( + call_help(env, function, call_mode, field_cast.into()); + } + }, + _ => { + call_help(env, function, call_mode, value); + } + } +} + +fn call_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + function: FunctionValue<'ctx>, + call_mode: CallMode<'ctx>, + value: BasicValueEnum<'ctx>, +) -> inkwell::values::CallSiteValue<'ctx> { + let call = match call_mode { + CallMode::Inc(inc_amount) => { + env.builder + .build_call(function, &[value, inc_amount.into()], "increment") + } + CallMode::Dec => env.builder.build_call(function, &[value], "decrement"), + }; + + call.set_call_convention(FAST_CALL_CONV); + + call +} + +fn modify_refcount_layout_build_function<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + mode: Mode, + when_recursive: &WhenRecursive<'a>, + layout: &Layout<'a>, +) -> Option> { + use Layout::*; + + match layout { + Builtin(builtin) => { + modify_refcount_builtin(env, layout_ids, mode, when_recursive, layout, builtin) + } + + Union(variant) => { + use UnionLayout::*; + + match variant { + NullableWrapped { + other_tags: tags, .. + } => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + tags, + true, + ); + + Some(function) + } + + NullableUnwrapped { other_fields, .. } => { + let other_fields = &other_fields[1..]; + + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + &*env.arena.alloc([other_fields]), + true, + ); + + Some(function) + } + + NonNullableUnwrapped(fields) => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + &*env.arena.alloc([*fields]), + true, + ); + Some(function) + } + + Recursive(tags) => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + tags, + false, + ); + Some(function) + } + + NonRecursive(tags) => { + let function = + modify_refcount_union(env, layout_ids, mode, when_recursive, tags); + + Some(function) + } + } + } + Closure(argument_layouts, closure_layout, return_layout) => { + if closure_layout.contains_refcounted() { + // Temporary hack to make this work for now. With defunctionalization, none of this + // will matter + let p2 = closure_layout.as_block_of_memory_layout(); + let mut argument_layouts = + Vec::from_iter_in(argument_layouts.iter().copied(), env.arena); + argument_layouts.push(p2); + let argument_layouts = argument_layouts.into_bump_slice(); + + let p1 = Layout::FunctionPointer(argument_layouts, return_layout); + let actual_layout = Layout::Struct(env.arena.alloc([p1, p2])); + + let function = modify_refcount_layout_build_function( + env, + parent, + layout_ids, + mode, + when_recursive, + &actual_layout, + )?; + + Some(function) + } else { + None + } + } + + Struct(layouts) => { + let function = modify_refcount_struct(env, layout_ids, layouts, mode, when_recursive); + + Some(function) + } + + PhantomEmptyStruct => None, + + Layout::RecursivePointer => match when_recursive { + WhenRecursive::Unreachable => { + unreachable!("recursion pointers should never be hashed directly") + } + WhenRecursive::Loop(union_layout) => { + let layout = Layout::Union(*union_layout); + + let function = modify_refcount_layout_build_function( env, parent, layout_ids, mode, when_recursive, - field_cast.into(), &layout, - ) + )?; + + Some(function) } }, - FunctionPointer(_, _) | Pointer(_) => {} + FunctionPointer(_, _) | Pointer(_) => None, } } @@ -617,12 +766,11 @@ fn modify_refcount_list<'a, 'ctx, 'env>( when_recursive: &WhenRecursive<'a>, layout: &Layout<'a>, element_layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_list", @@ -655,13 +803,13 @@ fn modify_refcount_list<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } fn mode_to_call_mode(function: FunctionValue<'_>, mode: Mode) -> CallMode<'_> { match mode { Mode::Dec => CallMode::Dec, - Mode::Inc(num) => CallMode::Inc(num, function.get_nth_param(1).unwrap().into_int_value()), + Mode::Inc => CallMode::Inc(function.get_nth_param(1).unwrap().into_int_value()), } } @@ -720,7 +868,7 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, element, element_layout, @@ -755,12 +903,11 @@ fn modify_refcount_str<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, mode: Mode, layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_str", @@ -785,7 +932,7 @@ fn modify_refcount_str<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } fn modify_refcount_str_help<'a, 'ctx, 'env>( @@ -855,12 +1002,11 @@ fn modify_refcount_dict<'a, 'ctx, 'env>( layout: &Layout<'a>, key_layout: &Layout<'a>, value_layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_dict", @@ -894,7 +1040,7 @@ fn modify_refcount_dict<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } #[allow(clippy::too_many_arguments)] @@ -990,7 +1136,7 @@ fn build_header<'a, 'ctx, 'env>( fn_name: &str, ) -> FunctionValue<'ctx> { match mode { - Mode::Inc(_) => build_header_help( + Mode::Inc => build_header_help( env, fn_name, env.context.void_type().into(), @@ -1036,13 +1182,26 @@ pub fn build_header_help<'a, 'ctx, 'env>( #[derive(Clone, Copy)] pub enum Mode { - Inc(u64), + Inc, Dec, } +impl Mode { + fn to_call_mode<'ctx>(&self, function: FunctionValue<'ctx>) -> CallMode<'ctx> { + match self { + Mode::Inc => { + let amount = function.get_nth_param(1).unwrap().into_int_value(); + + CallMode::Inc(amount) + } + Mode::Dec => CallMode::Dec, + } + } +} + #[derive(Clone, Copy)] enum CallMode<'ctx> { - Inc(u64, IntValue<'ctx>), + Inc(IntValue<'ctx>), Dec, } @@ -1052,12 +1211,11 @@ fn build_rec_union<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, fields: &'a [&'a [Layout<'a>]], - value: PointerValue<'ctx>, is_nullable: bool, -) { +) -> FunctionValue<'ctx> { let layout = Layout::Union(UnionLayout::Recursive(fields)); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_rec_union", @@ -1095,7 +1253,7 @@ fn build_rec_union<'a, 'ctx, 'env>( } }; - call_help(env, function, mode, value.into(), call_name); + function } fn build_rec_union_help<'a, 'ctx, 'env>( @@ -1112,7 +1270,7 @@ fn build_rec_union_help<'a, 'ctx, 'env>( let context = &env.context; let builder = env.builder; - let pick = |a, b| if let Mode::Inc(_) = mode { a } else { b }; + let pick = |a, b| if let Mode::Inc = mode { a } else { b }; // Add a basic block for the entry point let entry = context.append_basic_block(fn_val, "entry"); @@ -1258,17 +1416,16 @@ fn build_rec_union_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field, field_layout, ); } - let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); for ptr in deferred_rec { // recursively decrement the field - let call = call_help(env, fn_val, mode, ptr, call_name); + let call = call_help(env, fn_val, mode.to_call_mode(fn_val), ptr); call.set_tail_call(true); } @@ -1331,28 +1488,6 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( .into_int_value() } -fn call_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - function: FunctionValue<'ctx>, - mode: Mode, - value: BasicValueEnum<'ctx>, - call_name: &str, -) -> inkwell::values::CallSiteValue<'ctx> { - let call = match mode { - Mode::Inc(inc_amount) => { - let rc_increment = ptr_int(env.context, env.ptr_bytes).const_int(inc_amount, false); - - env.builder - .build_call(function, &[value, rc_increment.into()], call_name) - } - Mode::Dec => env.builder.build_call(function, &[value], call_name), - }; - - call.set_call_convention(FAST_CALL_CONV); - - call -} - fn function_name_from_mode<'a>( layout_ids: &mut LayoutIds<'a>, interns: &Interns, @@ -1368,7 +1503,7 @@ fn function_name_from_mode<'a>( // rather confusing, so now `inc_x` always corresponds to `dec_x` let layout_id = layout_ids.get(Symbol::DEC, layout); match mode { - Mode::Inc(_) => (if_inc, layout_id.to_symbol_string(Symbol::INC, interns)), + Mode::Inc => (if_inc, layout_id.to_symbol_string(Symbol::INC, interns)), Mode::Dec => (if_dec, layout_id.to_symbol_string(Symbol::DEC, interns)), } } @@ -1379,14 +1514,13 @@ fn modify_refcount_union<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, fields: &'a [&'a [Layout<'a>]], - value: BasicValueEnum<'ctx>, -) { +) -> FunctionValue<'ctx> { let layout = Layout::Union(UnionLayout::NonRecursive(fields)); let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_union", @@ -1418,7 +1552,7 @@ fn modify_refcount_union<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, value, call_name); + function } fn modify_refcount_union_help<'a, 'ctx, 'env>( @@ -1509,7 +1643,7 @@ fn modify_refcount_union_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field_ptr, field_layout,