diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index e9d87c032b..7ddbfd3e9f 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -1,6 +1,6 @@ use crate::llvm::build::{ - cast_basic_basic, cast_block_of_memory_to_tag, create_entry_block_alloca, set_name, Env, Scope, - FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, + cast_basic_basic, cast_block_of_memory_to_tag, set_name, Env, FAST_CALL_CONV, + LLVM_SADD_WITH_OVERFLOW_I64, }; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; use crate::llvm::convert::{ @@ -591,25 +591,12 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&ctx, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; let original_wrapper = arg_val.into_struct_value(); @@ -711,25 +698,12 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&ctx, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; let str_wrapper = arg_val.into_struct_value(); @@ -744,7 +718,7 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( IntPredicate::SGT, len, ptr_int(ctx, env.ptr_bytes).const_zero(), - "len > 0", + "is_big_str", ); // the block we'll always jump to when we're done @@ -939,6 +913,10 @@ fn build_rec_union_help<'a, 'ctx, 'env>( false })(); + // to increment/decrement the cons-cell itself + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + let call_mode = mode_to_call_mode(fn_val, mode); + let ctx = env.context; let cont_block = ctx.append_basic_block(parent, "cont"); if is_nullable { @@ -963,10 +941,6 @@ fn build_rec_union_help<'a, 'ctx, 'env>( // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - let merge_block = env - .context - .append_basic_block(parent, pick("increment_merge", "decrement_merge")); - builder.set_current_debug_location(&context, loc); for (tag_id, field_layouts) in tags.iter().enumerate() { @@ -1001,6 +975,11 @@ fn build_rec_union_help<'a, 'ctx, 'env>( ) .into_pointer_value(); + // defer actually performing the refcount modifications until after the current cell has + // been decremented, see below + let mut deferred_rec = Vec::new_in(env.arena); + let mut deferred_nonrec = Vec::new_in(env.arena); + for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { // this field has type `*i64`, but is really a pointer to the data we want @@ -1023,13 +1002,8 @@ fn build_rec_union_help<'a, 'ctx, 'env>( union_type.ptr_type(AddressSpace::Generic).into(), ); - // recursively decrement the field - let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); - call_help(env, fn_val, mode, recursive_field_ptr, call_name); + deferred_rec.push(recursive_field_ptr); } else if field_layout.contains_refcounted() { - // TODO this loads the whole field onto the stack; - // that's wasteful if e.g. the field is a big record, where only - // some fields are actually refcounted. let elem_pointer = env .builder .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") @@ -1040,11 +1014,31 @@ fn build_rec_union_help<'a, 'ctx, 'env>( pick("increment_struct_field", "decrement_struct_field"), ); - modify_refcount_layout(env, parent, layout_ids, mode, field, field_layout); + deferred_nonrec.push((field, field_layout)); } } - env.builder.build_unconditional_branch(merge_block); + // OPTIMIZATION + // + // We really would like `inc/dec` to be tail-recursive; it gives roughly a 2X speedup on linked + // lists. To achieve it, we must first load all fields that we want to inc/dec (done above) + // and store them on the stack, then modify (and potentially free) the current cell, then + // actually inc/dec the fields. + refcount_ptr.modify(call_mode, &layout, env); + + for (field, field_layout) in deferred_nonrec { + modify_refcount_layout(env, parent, layout_ids, mode, field, field_layout); + } + + let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); + for ptr in deferred_rec { + // recursively decrement the field + let call = call_help(env, fn_val, mode, ptr, call_name); + call.set_tail_call(true); + } + + // this function returns void + builder.build_return(None); cases.push(( env.context.i64_type().const_int(tag_id as u64, false), @@ -1067,20 +1061,22 @@ fn build_rec_union_help<'a, 'ctx, 'env>( // read the tag_id let current_tag_id = rec_union_read_tag(env, value_ptr); + let merge_block = env + .context + .append_basic_block(parent, pick("increment_merge", "decrement_merge")); + // switch on it env.builder .build_switch(current_tag_id, merge_block, &cases); + + env.builder.position_at_end(merge_block); + + // increment/decrement the cons-cell itself + refcount_ptr.modify(call_mode, &layout, env); + + // this function returns void + builder.build_return(None); } - - env.builder.position_at_end(merge_block); - - // increment/decrement the cons-cell itself - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - let call_mode = mode_to_call_mode(fn_val, mode); - refcount_ptr.modify(call_mode, &layout, env); - - // this function returns void - builder.build_return(None); } fn rec_union_read_tag<'a, 'ctx, 'env>( @@ -1106,7 +1102,7 @@ fn call_help<'a, 'ctx, 'env>( mode: Mode, value: BasicValueEnum<'ctx>, call_name: &str, -) { +) -> inkwell::values::CallSiteValue<'ctx> { let call = match mode { Mode::Inc(inc_amount) => { let rc_increment = ptr_int(env.context, env.ptr_bytes).const_int(inc_amount, false); @@ -1118,6 +1114,8 @@ fn call_help<'a, 'ctx, 'env>( }; call.set_call_convention(FAST_CALL_CONV); + + call } fn modify_refcount_union<'a, 'ctx, 'env>(