diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index c474c4c0f8..d9e264ebcf 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -2026,12 +2026,12 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( // This doesn't currently do anything context.i64_type().const_zero().into() } - Inc(symbol, _inc, cont) => { + Inc(symbol, inc_amount, cont) => { let (value, layout) = load_symbol_and_layout(env, scope, symbol); let layout = layout.clone(); if layout.contains_refcounted() { - increment_refcount_layout(env, parent, layout_ids, value, &layout); + increment_refcount_layout(env, parent, layout_ids, *inc_amount, value, &layout); } build_exp_stmt(env, layout_ids, scope, parent, cont) diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 0600170390..480dc34705 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -576,7 +576,7 @@ pub fn list_get_unsafe<'a, 'ctx, 'env>( let result = builder.build_load(elem_ptr, "List.get"); - increment_refcount_layout(env, parent, layout_ids, result, elem_layout); + increment_refcount_layout(env, parent, layout_ids, 1, result, elem_layout); result } @@ -1369,11 +1369,11 @@ macro_rules! list_map_help { let list_ptr = load_list_ptr(builder, list_wrapper, ptr_type); let list_loop = |index, before_elem| { - increment_refcount_layout($env, parent, layout_ids, before_elem, elem_layout); + increment_refcount_layout($env, parent, layout_ids, 1, before_elem, elem_layout); let arguments = match closure_info { Some((closure_data_layout, closure_data)) => { - increment_refcount_layout( $env, parent, layout_ids, closure_data, closure_data_layout); + increment_refcount_layout( $env, parent, layout_ids, 1, closure_data, closure_data_layout); bumpalo::vec![in $env.arena; before_elem, closure_data] } diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index 57fe578bfd..f130fc1327 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -101,6 +101,13 @@ impl<'ctx> PointerToRefcount<'ctx> { env.builder.build_store(self.value, refcount); } + fn modify<'a, 'env>(&self, mode: Mode, layout: &Layout<'a>, env: &Env<'a, 'ctx, 'env>) { + match mode { + Mode::Inc(inc_amount) => self.increment(inc_amount, env), + Mode::Dec => self.decrement(env, layout), + } + } + fn increment<'a, 'env>(&self, amount: u64, env: &Env<'a, 'ctx, 'env>) { debug_assert!(amount > 0); @@ -303,18 +310,29 @@ fn modify_refcount_struct<'a, 'ctx, 'env>( .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") .unwrap(); - match mode { - Mode::Inc(_) => { - increment_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) - } - Mode::Dec => { - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) - } - } + modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout); } } } +pub fn increment_refcount_layout<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + inc_amount: u64, + value: BasicValueEnum<'ctx>, + layout: &Layout<'a>, +) { + modify_refcount_layout( + env, + parent, + layout_ids, + Mode::Inc(inc_amount), + value, + layout, + ); +} + pub fn decrement_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, @@ -449,10 +467,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( ); } - NonRecursive(tags) => match mode { - Mode::Inc(_) => build_inc_union(env, layout_ids, tags, value), - Mode::Dec => build_dec_union(env, layout_ids, tags, value), - }, + NonRecursive(tags) => modify_refcount_union(env, layout_ids, mode, tags, value), } } Closure(_, closure_layout, _) => { @@ -487,16 +502,6 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( } } -pub fn increment_refcount_layout<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - parent: FunctionValue<'ctx>, - layout_ids: &mut LayoutIds<'a>, - value: BasicValueEnum<'ctx>, - layout: &Layout<'a>, -) { - modify_refcount_layout(env, parent, layout_ids, Mode::Inc(1), value, layout); -} - fn modify_refcount_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -605,11 +610,7 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>( builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, original_wrapper); - - match mode { - Mode::Inc(inc_amount) => refcount_ptr.increment(inc_amount, env), - Mode::Dec => refcount_ptr.decrement(env, layout), - } + refcount_ptr.modify(mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -734,11 +735,7 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, str_wrapper); - - match mode { - Mode::Inc(inc_amount) => refcount_ptr.increment(inc_amount, env), - Mode::Dec => refcount_ptr.decrement(env, layout), - } + refcount_ptr.modify(mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -749,7 +746,7 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>( } /// Build an increment or decrement function for a specific layout -pub fn build_header<'a, 'ctx, 'env>( +fn build_header<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, arg_type: BasicTypeEnum<'ctx>, fn_name: &str, @@ -1002,14 +999,7 @@ fn build_rec_union_help<'a, 'ctx, 'env>( pick("increment_struct_field", "decrement_struct_field"), ); - match mode { - Mode::Inc(_) => { - increment_refcount_layout(env, parent, layout_ids, field, field_layout) - } - Mode::Dec => { - decrement_refcount_layout(env, parent, layout_ids, field, field_layout) - } - } + modify_refcount_layout(env, parent, layout_ids, mode, field, field_layout); } } @@ -1044,181 +1034,8 @@ fn build_rec_union_help<'a, 'ctx, 'env>( env.builder.position_at_end(merge_block); // increment/decrement the cons-cell itself - match mode { - Mode::Inc(inc_amount) => { - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.increment(inc_amount, env); - } - Mode::Dec => { - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.decrement(env, &layout); - } - } - - // this function returns void - builder.build_return(None); -} - -pub fn build_dec_union<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - fields: &'a [&'a [Layout<'a>]], - value: BasicValueEnum<'ctx>, -) { - let layout = Layout::Union(UnionLayout::NonRecursive(fields)); - - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::DEC; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); - - build_dec_union_help(env, layout_ids, fields, function_value); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - - let call = env - .builder - .build_call(function, &[value], "decrement_union"); - - call.set_call_convention(FAST_CALL_CONV); -} - -pub fn build_dec_union_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - tags: &[&[Layout<'a>]], - fn_val: FunctionValue<'ctx>, -) { - debug_assert!(!tags.is_empty()); - - let context = &env.context; - let builder = env.builder; - - // Add a basic block for the entry point - let entry = context.append_basic_block(fn_val, "entry"); - - builder.position_at_end(entry); - - let func_scope = fn_val.get_subprogram().unwrap(); - let lexical_block = env.dibuilder.create_lexical_block( - /* scope */ func_scope.as_debug_info_scope(), - /* file */ env.compile_unit.get_file(), - /* line_no */ 0, - /* column_no */ 0, - ); - - let loc = env.dibuilder.create_debug_location( - context, - /* line */ 0, - /* column */ 0, - /* current_scope */ lexical_block.as_debug_info_scope(), - /* inlined_at */ None, - ); - builder.set_current_debug_location(&context, loc); - - // Add args to scope - let arg_symbol = Symbol::ARG_1; - - let arg_val = fn_val.get_param_iter().next().unwrap(); - - set_name(arg_val, arg_symbol.ident_string(&env.interns)); - - let parent = fn_val; - - let before_block = env.builder.get_insert_block().expect("to be in a function"); - - debug_assert!(arg_val.is_struct_value()); - let wrapper_struct = arg_val.into_struct_value(); - - // next, make a jump table for all possible values of the tag_id - let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); - - builder.set_current_debug_location(&context, loc); - - for (tag_id, field_layouts) in tags.iter().enumerate() { - // if none of the fields are or contain anything refcounted, just move on - if !field_layouts - .iter() - .any(|x| x.is_refcounted() || x.contains_refcounted()) - { - continue; - } - - let block = env.context.append_basic_block(parent, "tag_id_decrement"); - env.builder.position_at_end(block); - - let wrapper_type = basic_type_from_layout( - env.arena, - env.context, - &Layout::Struct(field_layouts), - env.ptr_bytes, - ); - - debug_assert!(wrapper_type.is_struct_type()); - let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type); - - for (i, field_layout) in field_layouts.iter().enumerate() { - if let Layout::RecursivePointer = field_layout { - panic!("a non-recursive tag union cannot contain RecursivePointer"); - } else if field_layout.contains_refcounted() { - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") - .unwrap(); - - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout); - } - } - - env.builder.build_unconditional_branch(merge_block); - - cases.push(( - env.context.i64_type().const_int(tag_id as u64, false), - block, - )); - } - - cases.reverse(); - - env.builder.position_at_end(before_block); - - // read the tag_id - let current_tag_id = { - // the first element of the wrapping struct is an array of i64 - let first_array = env - .builder - .build_extract_value(wrapper_struct, 0, "read_tag_id") - .unwrap() - .into_array_value(); - - env.builder - .build_extract_value(first_array, 0, "read_tag_id_2") - .unwrap() - .into_int_value() - }; - - // switch on it - env.builder - .build_switch(current_tag_id, merge_block, &cases); - - env.builder.position_at_end(merge_block); + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + refcount_ptr.modify(mode, &layout, env); // this function returns void builder.build_return(None); @@ -1241,9 +1058,10 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( .into_int_value() } -pub fn build_inc_union<'a, 'ctx, 'env>( +fn modify_refcount_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, fields: &'a [&'a [Layout<'a>]], value: BasicValueEnum<'ctx>, ) { @@ -1263,7 +1081,7 @@ pub fn build_inc_union<'a, 'ctx, 'env>( let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes); let function_value = build_header(env, basic_type, &fn_name); - build_inc_union_help(env, layout_ids, fields, function_value); + modify_refcount_union_help(env, layout_ids, mode, fields, function_value); function_value } @@ -1279,9 +1097,10 @@ pub fn build_inc_union<'a, 'ctx, 'env>( call.set_call_convention(FAST_CALL_CONV); } -pub fn build_inc_union_help<'a, 'ctx, 'env>( +fn modify_refcount_union_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, tags: &[&[Layout<'a>]], fn_val: FunctionValue<'ctx>, ) { @@ -1320,7 +1139,6 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( let parent = fn_val; - let layout = Layout::Union(UnionLayout::Recursive(tags)); let before_block = env.builder.get_insert_block().expect("to be in a function"); let wrapper_struct = arg_val.into_struct_value(); @@ -1347,7 +1165,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - let merge_block = env.context.append_basic_block(parent, "increment_merge"); + let merge_block = env + .context + .append_basic_block(parent, "modify_rc_union_merge"); for (tag_id, field_layouts) in tags.iter().enumerate() { // if none of the fields are or contain anything refcounted, just move on @@ -1358,7 +1178,7 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( continue; } - let block = env.context.append_basic_block(parent, "tag_id_increment"); + let block = env.context.append_basic_block(parent, "tag_id_modify"); env.builder.position_at_end(block); let wrapper_type = basic_type_from_layout( @@ -1373,48 +1193,14 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { - // this field has type `*i64`, but is really a pointer to the data we want - let ptr_as_i64_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "increment_struct_field") - .unwrap(); - - debug_assert!(ptr_as_i64_ptr.is_pointer_value()); - - // therefore we must cast it to our desired type - let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let recursive_field_ptr = env - .builder - .build_bitcast( - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic), - "recursive_to_desired", - ) - .into_pointer_value(); - - let recursive_field = env - .builder - .build_load(recursive_field_ptr, "load_recursive_field"); - - // recursively increment the field - let call = - env.builder - .build_call(fn_val, &[recursive_field], "recursive_tag_increment"); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); - - // TODO do this decrement before the recursive call? - // Then the recursive call is potentially TCE'd - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, recursive_field_ptr); - refcount_ptr.increment(1, env); + panic!("non-recursive tag unions cannot contain naked recursion pointers!"); } else if field_layout.contains_refcounted() { let field_ptr = env .builder - .build_extract_value(wrapper_struct, i as u32, "increment_struct_field") + .build_extract_value(wrapper_struct, i as u32, "modify_tag_field") .unwrap(); - increment_refcount_layout(env, parent, layout_ids, field_ptr, field_layout); + modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout); } } diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 6a4dfb297a..6b67823c86 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -262,7 +262,9 @@ impl<'a> Context<'a> { } } - fn add_inc(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + fn add_inc(&self, symbol: Symbol, inc_amount: u64, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + debug_assert!(inc_amount > 0); + let info = self.get_var_info(symbol); if info.persistent { @@ -275,7 +277,7 @@ impl<'a> Context<'a> { return stmt; } - self.arena.alloc(Stmt::Inc(symbol, 1, stmt)) + self.arena.alloc(Stmt::Inc(symbol, inc_amount, stmt)) } fn add_dec(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { @@ -335,11 +337,8 @@ impl<'a> Context<'a> { num_consumptions - 1 }; - // verify that this is indeed always 1 - debug_assert!(num_incs <= 1); - - if num_incs == 1 { - b = self.add_inc(*x, b) + if num_incs >= 1 { + b = self.add_inc(*x, num_incs as u64, b) } } } @@ -524,7 +523,7 @@ impl<'a> Context<'a> { let b = self.add_dec_if_needed(x, b, b_live_vars); let info_x = self.get_var_info(x); let b = if info_x.consume { - self.add_inc(z, b) + self.add_inc(z, 1, b) } else { b }; @@ -786,7 +785,7 @@ impl<'a> Context<'a> { live_vars.insert(*x); if info.reference && !info.consume { - (self.add_inc(*x, stmt), live_vars) + (self.add_inc(*x, 1, stmt), live_vars) } else { (stmt, live_vars) }