diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index 584759ad5d..0d8980ac30 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -101,16 +101,19 @@ impl<'ctx> PointerToRefcount<'ctx> { env.builder.build_store(self.value, refcount); } - fn modify<'a, 'env>(&self, mode: Mode, layout: &Layout<'a>, env: &Env<'a, 'ctx, 'env>) { + fn modify<'a, 'env>( + &self, + mode: CallMode<'ctx>, + layout: &Layout<'a>, + env: &Env<'a, 'ctx, 'env>, + ) { match mode { - Mode::Inc(inc_amount) => self.increment(inc_amount, env), - Mode::Dec => self.decrement(env, layout), + CallMode::Inc(_, inc_amount) => self.increment(inc_amount, env), + CallMode::Dec => self.decrement(env, layout), } } - fn increment<'a, 'env>(&self, amount: u64, env: &Env<'a, 'ctx, 'env>) { - debug_assert!(amount > 0); - + fn increment<'a, 'env>(&self, amount: IntValue<'ctx>, env: &Env<'a, 'ctx, 'env>) { let refcount = self.get_refcount(env); let builder = env.builder; let refcount_type = ptr_int(env.context, env.ptr_bytes); @@ -121,11 +124,7 @@ impl<'ctx> PointerToRefcount<'ctx> { refcount_type.const_int(REFCOUNT_MAX as u64, false), "refcount_max_check", ); - let incremented = builder.build_int_add( - refcount, - refcount_type.const_int(amount, false), - "increment_refcount", - ); + let incremented = builder.build_int_add(refcount, amount, "increment_refcount"); let new_refcount = builder .build_select(max, refcount, incremented, "select_refcount") @@ -173,6 +172,7 @@ impl<'ctx> PointerToRefcount<'ctx> { env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); + let call = env .builder .build_call(function, &[refcount_ptr.into()], fn_name); @@ -525,7 +525,7 @@ fn modify_refcount_list<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); modify_refcount_list_help(env, mode, layout, function_value); @@ -536,11 +536,15 @@ fn modify_refcount_list<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], call_name); - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, original_wrapper.into(), call_name); +} + +fn mode_to_call_mode<'a, 'ctx, 'env>(function: FunctionValue<'ctx>, mode: Mode) -> CallMode<'ctx> { + match mode { + Mode::Dec => CallMode::Dec, + Mode::Inc(num) => CallMode::Inc(num, function.get_nth_param(1).unwrap().into_int_value()), + } } fn modify_refcount_list_help<'a, 'ctx, 'env>( @@ -614,7 +618,8 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>( builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, original_wrapper); - refcount_ptr.modify(mode, layout, env); + let call_mode = mode_to_call_mode(fn_val, mode); + refcount_ptr.modify(call_mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -647,7 +652,7 @@ fn modify_refcount_str<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); modify_refcount_str_help(env, mode, layout, function_value); @@ -658,10 +663,8 @@ fn modify_refcount_str<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], call_name); - call.set_call_convention(FAST_CALL_CONV); + + call_help(env, function, mode, original_wrapper.into(), call_name); } fn modify_refcount_str_help<'a, 'ctx, 'env>( @@ -739,7 +742,8 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, str_wrapper); - refcount_ptr.modify(mode, layout, env); + let call_mode = mode_to_call_mode(fn_val, mode); + refcount_ptr.modify(call_mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -753,9 +757,18 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( fn build_header<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, arg_type: BasicTypeEnum<'ctx>, + mode: Mode, fn_name: &str, ) -> FunctionValue<'ctx> { - build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type]) + match mode { + Mode::Inc(_) => build_header_help( + env, + fn_name, + env.context.void_type().into(), + &[arg_type, ptr_int(env.context, env.ptr_bytes).into()], + ), + Mode::Dec => build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type]), + } } /// Build an increment or decrement function for a specific layout @@ -798,6 +811,12 @@ enum Mode { Dec, } +#[derive(Clone, Copy)] +enum CallMode<'ctx> { + Inc(u64, IntValue<'ctx>), + Dec, +} + fn build_rec_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -808,9 +827,6 @@ fn build_rec_union<'a, 'ctx, 'env>( ) { let layout = Layout::Union(UnionLayout::Recursive(fields)); - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, symbol) = match mode { Mode::Inc(_) => ("increment_rec_union", Symbol::INC), Mode::Dec => ("decrement_rec_union", Symbol::DEC), @@ -823,23 +839,25 @@ fn build_rec_union<'a, 'ctx, 'env>( let function = match env.module.get_function(fn_name.as_str()) { Some(function_value) => function_value, None => { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + let basic_type = block_of_memory_slices(env.context, fields, env.ptr_bytes) .ptr_type(AddressSpace::Generic) .into(); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); build_rec_union_help(env, layout_ids, mode, fields, function_value, is_nullable); + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + function_value } }; - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - let call = env.builder.build_call(function, &[value.into()], call_name); - - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, value.into(), call_name); } fn build_rec_union_help<'a, 'ctx, 'env>( @@ -981,14 +999,9 @@ fn build_rec_union_help<'a, 'ctx, 'env>( ); // recursively decrement the field - let call = env.builder.build_call( - fn_val, - &[recursive_field_ptr], - pick("recursive_tag_increment", "recursive_tag_decrement"), - ); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); + let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); + let call_mode = mode_to_call_mode(fn_val, mode); + recursive_call_help(env, fn_val, call_mode, recursive_field_ptr, call_name); } else if field_layout.contains_refcounted() { // TODO this loads the whole field onto the stack; // that's wasteful if e.g. the field is a big record, where only @@ -1039,7 +1052,8 @@ fn build_rec_union_help<'a, 'ctx, 'env>( // increment/decrement the cons-cell itself let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.modify(mode, &layout, env); + let call_mode = mode_to_call_mode(fn_val, mode); + refcount_ptr.modify(call_mode, &layout, env); // this function returns void builder.build_return(None); @@ -1062,6 +1076,44 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( .into_int_value() } +fn call_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + function: FunctionValue<'ctx>, + mode: Mode, + value: BasicValueEnum<'ctx>, + call_name: &str, +) { + let call = match mode { + Mode::Inc(inc_amount) => { + let rc_increment = ptr_int(env.context, env.ptr_bytes).const_int(inc_amount, false); + + env.builder + .build_call(function, &[value, rc_increment.into()], call_name) + } + Mode::Dec => env.builder.build_call(function, &[value], call_name), + }; + + call.set_call_convention(FAST_CALL_CONV); +} + +fn recursive_call_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + function: FunctionValue<'ctx>, + mode: CallMode, + value: BasicValueEnum<'ctx>, + call_name: &str, +) { + let call = match mode { + CallMode::Inc(_, inc_amount) => { + env.builder + .build_call(function, &[value, inc_amount.into()], call_name) + } + CallMode::Dec => env.builder.build_call(function, &[value], call_name), + }; + + call.set_call_convention(FAST_CALL_CONV); +} + fn modify_refcount_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -1087,7 +1139,7 @@ fn modify_refcount_union<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); modify_refcount_union_help(env, layout_ids, mode, fields, function_value); @@ -1098,9 +1150,8 @@ fn modify_refcount_union<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env.builder.build_call(function, &[value], call_name); - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, value, call_name); } fn modify_refcount_union_help<'a, 'ctx, 'env>(