diff --git a/cli/src/repl/eval.rs b/cli/src/repl/eval.rs index 93c08cac80..31853aed1f 100644 --- a/cli/src/repl/eval.rs +++ b/cli/src/repl/eval.rs @@ -164,7 +164,10 @@ fn jit_to_ast_help<'a>( let size = layout.stack_size(env.ptr_bytes); match union_variant { - UnionVariant::Wrapped(tags_and_layouts) => { + UnionVariant::Wrapped { + sorted_tag_layouts: tags_and_layouts, + .. + } => { Ok(run_jit_function_dynamic_type!( lib, main_fn_name, diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 884909a0c8..84cd50b019 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -808,6 +808,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( layout: &Layout<'a>, expr: &roc_mono::ir::Expr<'a>, ) -> BasicValueEnum<'ctx> { + use inkwell::types::BasicType; use roc_mono::ir::Expr::*; match expr { @@ -960,15 +961,9 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( field_types.push(field_type); if let Layout::RecursivePointer = tag_field_layout { - let ptr = allocate_with_refcount(env, &tag_layout, val); - - let ptr = cast_basic_basic( - builder, - ptr.into(), - ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + panic!( + r"non-recursive tag unions cannot directly contain a recursive pointer" ); - - field_vals.push(ptr); } else { // this check fails for recursive tag unions, but can be helpful while debugging debug_assert_eq!(tag_field_layout, val_layout); @@ -1027,7 +1022,88 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( internal_type, ) } - Tag { .. } => unreachable!("tags should have a union layout"), + Tag { + arguments, + tag_layout: Layout::RecursiveUnion(fields), + union_size, + tag_id, + tag_name, + .. + } => { + let tag_layout = Layout::Union(fields); + + debug_assert!(*union_size > 1); + let ptr_size = env.ptr_bytes; + + let ctx = env.context; + let builder = env.builder; + + // Determine types + let num_fields = arguments.len() + 1; + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + let tag_field_layouts = if let TagName::Closure(_) = tag_name { + // closures ignore (and do not store) the discriminant + &fields[*tag_id as usize][1..] + } else { + &fields[*tag_id as usize] + }; + + for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { + let (val, val_layout) = load_symbol_and_layout(env, scope, field_symbol); + + // Zero-sized fields have no runtime representation. + // The layout of the struct expects them to be dropped! + if !tag_field_layout.is_dropped_because_empty() { + let field_type = + basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size); + + field_types.push(field_type); + + if let Layout::RecursivePointer = tag_field_layout { + debug_assert!(val.is_pointer_value()); + + // we store recursive pointers as `i64*` + let ptr = cast_basic_basic( + builder, + val, + ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + ); + + field_vals.push(ptr); + } else { + // this check fails for recursive tag unions, but can be helpful while debugging + debug_assert_eq!(tag_field_layout, val_layout); + + field_vals.push(val); + } + } + } + + // Create the struct_type + let data_ptr = reserve_with_refcount(env, &tag_layout); + let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); + let struct_ptr = cast_basic_basic( + builder, + data_ptr.into(), + struct_type.ptr_type(AddressSpace::Generic).into(), + ) + .into_pointer_value(); + + // Insert field exprs into struct_val + for (index, field_val) in field_vals.into_iter().enumerate() { + let field_ptr = builder + .build_struct_gep(struct_ptr, index as u32, "struct_gep") + .unwrap(); + + builder.build_store(field_ptr, field_val); + } + + data_ptr.into() + } + + Tag { .. } => unreachable!("tags should have a Union or RecursiveUnion layout"), Reset(_) => todo!(), Reuse { .. } => todo!(), @@ -1092,6 +1168,8 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( field_layouts, .. } => { + use BasicValueEnum::*; + let builder = env.builder; // Determine types, assumes the descriminant is in the field layouts @@ -1111,28 +1189,61 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( .struct_type(field_types.into_bump_slice(), false); // cast the argument bytes into the desired shape for this tag - let argument = load_symbol(env, scope, structure).into_struct_value(); + let argument = load_symbol(env, scope, structure); - let struct_value = cast_struct_struct(builder, argument, struct_type); + let struct_layout = Layout::Struct(field_layouts); + match argument { + StructValue(value) => { + let struct_value = cast_struct_struct(builder, value, struct_type); - let result = builder - .build_extract_value(struct_value, *index as u32, "") - .expect("desired field did not decode"); + let result = builder + .build_extract_value(struct_value, *index as u32, "") + .expect("desired field did not decode"); - if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { - let struct_layout = Layout::Struct(field_layouts); - let desired_type = block_of_memory(env.context, &struct_layout, env.ptr_bytes); + if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { + let desired_type = + block_of_memory(env.context, &struct_layout, env.ptr_bytes); - // the value is a pointer to the actual value; load that value! - use inkwell::types::BasicType; - let ptr = cast_basic_basic( - builder, - result, - desired_type.ptr_type(AddressSpace::Generic).into(), - ); - builder.build_load(ptr.into_pointer_value(), "load_recursive_field") - } else { - result + // the value is a pointer to the actual value; load that value! + let ptr = cast_basic_basic( + builder, + result, + desired_type.ptr_type(AddressSpace::Generic).into(), + ); + builder.build_load(ptr.into_pointer_value(), "load_recursive_field") + } else { + result + } + } + PointerValue(value) => { + let ptr = cast_basic_basic( + builder, + value.into(), + struct_type.ptr_type(AddressSpace::Generic).into(), + ) + .into_pointer_value(); + + let elem_ptr = builder + .build_struct_gep(ptr, *index as u32, "at_index_struct_gep") + .unwrap(); + + let result = builder.build_load(elem_ptr, "load_at_index_ptr"); + + if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { + // a recursive field is stored as a `i64*`, to use it we must cast it to + // a pointer to the block of memory representation + cast_basic_basic( + builder, + result, + block_of_memory(env.context, &struct_layout, env.ptr_bytes) + .ptr_type(AddressSpace::Generic) + .into(), + ) + } else { + result + } + } + _ => panic!("cannot look up index in {:?}", argument), } } EmptyArray => empty_polymorphic_list(env), @@ -1178,12 +1289,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( } } -pub fn allocate_with_refcount<'a, 'ctx, 'env>( +pub fn reserve_with_refcount<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout: &Layout<'a>, - value: BasicValueEnum<'ctx>, ) -> PointerValue<'ctx> { - let builder = env.builder; let ctx = env.context; let len_type = env.ptr_int(); @@ -1193,10 +1302,18 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>( let rc1 = crate::llvm::refcounting::refcount_1(ctx, env.ptr_bytes); - let data_ptr = allocate_with_refcount_help(env, layout, value_bytes_intvalue, rc1); + allocate_with_refcount_help(env, layout, value_bytes_intvalue, rc1) +} + +pub fn allocate_with_refcount<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout: &Layout<'a>, + value: BasicValueEnum<'ctx>, +) -> PointerValue<'ctx> { + let data_ptr = reserve_with_refcount(env, layout); // store the value in the pointer - builder.build_store(data_ptr, value); + env.builder.build_store(data_ptr, value); data_ptr } @@ -1454,6 +1571,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( let mut stack = Vec::with_capacity_in(queue.len(), env.arena); for (symbol, expr, layout) in queue { + debug_assert!(layout != &Layout::RecursivePointer); let context = &env.context; let val = build_exp_expr(env, layout_ids, &scope, parent, layout, &expr); diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index 958a986c42..9f05862b64 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -137,7 +137,10 @@ pub fn basic_type_from_layout<'ctx>( basic_type_from_record(arena, context, sorted_fields, ptr_bytes) } - RecursiveUnion(_) | Union(_) => block_of_memory(context, layout, ptr_bytes), + RecursiveUnion(_) => block_of_memory(context, layout, ptr_bytes) + .ptr_type(AddressSpace::Generic) + .into(), + Union(_) => block_of_memory(context, layout, ptr_bytes), RecursivePointer => { // TODO make this dynamic context diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index 78c86b28b1..8d4abaa791 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -343,7 +343,8 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( } RecursiveUnion(tags) => { - build_dec_union(env, layout_ids, tags, value); + debug_assert!(value.is_pointer_value()); + build_dec_rec_union(env, layout_ids, tags, value.into_pointer_value()); } FunctionPointer(_, _) | Pointer(_) => {} @@ -427,7 +428,8 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( } RecursiveUnion(tags) => { - build_inc_union(env, layout_ids, tags, value); + debug_assert!(value.is_pointer_value()); + build_inc_rec_union(env, layout_ids, tags, value.into_pointer_value()); } Closure(_, closure_layout, _) => { if closure_layout.contains_refcounted() { @@ -1032,6 +1034,202 @@ pub fn build_header_help<'a, 'ctx, 'env>( fn_val } +pub fn build_dec_rec_union<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + fields: &'a [&'a [Layout<'a>]], + value: PointerValue<'ctx>, +) { + let layout = Layout::RecursiveUnion(fields); + + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::DEC; + let fn_name = layout_ids + .get(symbol, &layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let function_value = build_header(env, &layout, &fn_name); + + build_dec_rec_union_help(env, layout_ids, fields, function_value); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + let call = env + .builder + .build_call(function, &[value.into()], "decrement_union"); + + call.set_call_convention(FAST_CALL_CONV); +} + +pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + tags: &[&[Layout<'a>]], + fn_val: FunctionValue<'ctx>, +) { + debug_assert!(!tags.is_empty()); + + use inkwell::types::BasicType; + + let context = &env.context; + let builder = env.builder; + + // Add a basic block for the entry point + let entry = context.append_basic_block(fn_val, "entry"); + + builder.position_at_end(entry); + + let func_scope = fn_val.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + context, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(&context, loc); + + // Add args to scope + let arg_symbol = Symbol::ARG_1; + + let arg_val = fn_val.get_param_iter().next().unwrap(); + + set_name(arg_val, arg_symbol.ident_string(&env.interns)); + + let parent = fn_val; + + let layout = Layout::RecursiveUnion(tags); + let before_block = env.builder.get_insert_block().expect("to be in a function"); + + debug_assert!(arg_val.is_pointer_value()); + let value_ptr = arg_val.into_pointer_value(); + + // next, make a jump table for all possible values of the tag_id + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + let merge_block = env.context.append_basic_block(parent, "decrement_merge"); + + builder.set_current_debug_location(&context, loc); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + // if none of the fields are or contain anything refcounted, just move on + if !field_layouts + .iter() + .any(|x| x.is_refcounted() || x.contains_refcounted()) + { + continue; + } + + let block = env.context.append_basic_block(parent, "tag_id_decrement"); + env.builder.position_at_end(block); + + let wrapper_type = basic_type_from_layout( + env.arena, + env.context, + &Layout::Struct(field_layouts), + env.ptr_bytes, + ); + + // cast the opaque pointer to a pointer of the correct shape + let struct_ptr = cast_basic_basic( + env.builder, + value_ptr.into(), + wrapper_type.ptr_type(AddressSpace::Generic).into(), + ) + .into_pointer_value(); + + for (i, field_layout) in field_layouts.iter().enumerate() { + if let Layout::RecursivePointer = field_layout { + // this field has type `*i64`, but is really a pointer to the data we want + let elem_pointer = env + .builder + .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") + .unwrap(); + + let ptr_as_i64_ptr = env + .builder + .build_load(elem_pointer, "load_recursive_pointer"); + + debug_assert!(ptr_as_i64_ptr.is_pointer_value()); + + // therefore we must cast it to our desired type + let union_type = + basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); + let recursive_field_ptr = cast_basic_basic(env.builder, ptr_as_i64_ptr, union_type); + + // recursively decrement the field + let call = env.builder.build_call( + fn_val, + &[recursive_field_ptr], + "recursive_tag_decrement", + ); + + // Because it's an internal-only function, use the fast calling convention. + call.set_call_convention(FAST_CALL_CONV); + } else if field_layout.contains_refcounted() { + // TODO this loads the whole field onto the stack; + // that's wasteful if e.g. the field is a big record, where only + // some fields are actually refcounted. + let elem_pointer = env + .builder + .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") + .unwrap(); + + let field = env + .builder + .build_load(elem_pointer, "decrement_struct_field"); + + decrement_refcount_layout(env, parent, layout_ids, field, field_layout); + } + } + + env.builder.build_unconditional_branch(merge_block); + + cases.push(( + env.context.i64_type().const_int(tag_id as u64, false), + block, + )); + } + + cases.reverse(); + + env.builder.position_at_end(before_block); + + // read the tag_id + let current_tag_id = rec_union_read_tag(env, value_ptr); + + // switch on it + env.builder + .build_switch(current_tag_id, merge_block, &cases); + + env.builder.position_at_end(merge_block); + + // decrement this cons-cell itself + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + refcount_ptr.decrement(env, &layout); + + // this function returns void + builder.build_return(None); +} + pub fn build_dec_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -1078,8 +1276,6 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( ) { debug_assert!(!tags.is_empty()); - use inkwell::types::BasicType; - let context = &env.context; let builder = env.builder; @@ -1105,32 +1301,18 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&context, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; - let layout = Layout::Union(tags); let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; - let layout = Layout::RecursiveUnion(tags); let before_block = env.builder.get_insert_block().expect("to be in a function"); + debug_assert!(arg_val.is_struct_value()); let wrapper_struct = arg_val.into_struct_value(); // next, make a jump table for all possible values of the tag_id @@ -1164,39 +1346,7 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { - // this field has type `*i64`, but is really a pointer to the data we want - let ptr_as_i64_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") - .unwrap(); - - debug_assert!(ptr_as_i64_ptr.is_pointer_value()); - - // therefore we must cast it to our desired type - let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let recursive_field_ptr = cast_basic_basic( - env.builder, - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); - - let recursive_field = env - .builder - .build_load(recursive_field_ptr, "load_recursive_field"); - - // recursively decrement the field - let call = - env.builder - .build_call(fn_val, &[recursive_field], "recursive_tag_decrement"); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); - - // TODO do this decrement before the recursive call? - // Then the recursive call is potentially TCE'd - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, recursive_field_ptr); - refcount_ptr.decrement(env, &layout); + panic!("a non-recursive tag union cannot contain RecursivePointer"); } else if field_layout.contains_refcounted() { let field_ptr = env .builder @@ -1244,6 +1394,213 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( builder.build_return(None); } +pub fn build_inc_rec_union<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + fields: &'a [&'a [Layout<'a>]], + value: PointerValue<'ctx>, +) { + let layout = Layout::RecursiveUnion(fields); + + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::INC; + let fn_name = layout_ids + .get(symbol, &layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let function_value = build_header(env, &layout, &fn_name); + + build_inc_rec_union_help(env, layout_ids, fields, function_value); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + let call = env + .builder + .build_call(function, &[value.into()], "increment_union"); + + call.set_call_convention(FAST_CALL_CONV); +} + +fn rec_union_read_tag<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + value_ptr: PointerValue<'ctx>, +) -> IntValue<'ctx> { + // Assumption: the tag is the first thing stored + // so cast the pointer to the data to a `i64*` + let tag_ptr = cast_basic_basic( + env.builder, + value_ptr.into(), + env.context + .i64_type() + .ptr_type(AddressSpace::Generic) + .into(), + ) + .into_pointer_value(); + + env.builder + .build_load(tag_ptr, "load_tag_id") + .into_int_value() +} + +pub fn build_inc_rec_union_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + tags: &[&[Layout<'a>]], + fn_val: FunctionValue<'ctx>, +) { + debug_assert!(!tags.is_empty()); + + use inkwell::types::BasicType; + + let context = &env.context; + let builder = env.builder; + + // Add a basic block for the entry point + let entry = context.append_basic_block(fn_val, "entry"); + + builder.position_at_end(entry); + + let func_scope = fn_val.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + context, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(&context, loc); + + // Add args to scope + let arg_symbol = Symbol::ARG_1; + let arg_val = fn_val.get_param_iter().next().unwrap(); + + set_name(arg_val, arg_symbol.ident_string(&env.interns)); + + let parent = fn_val; + + let layout = Layout::RecursiveUnion(tags); + let before_block = env.builder.get_insert_block().expect("to be in a function"); + + debug_assert!(arg_val.is_pointer_value()); + let value_ptr = arg_val.into_pointer_value(); + + // read the tag_id + let tag_id = rec_union_read_tag(env, value_ptr); + + let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); + + // next, make a jump table for all possible values of the tag_id + let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + + let merge_block = env.context.append_basic_block(parent, "increment_merge"); + + for (tag_id, field_layouts) in tags.iter().enumerate() { + // if none of the fields are or contain anything refcounted, just move on + if !field_layouts + .iter() + .any(|x| x.is_refcounted() || x.contains_refcounted()) + { + continue; + } + + let block = env.context.append_basic_block(parent, "tag_id_increment"); + env.builder.position_at_end(block); + + let wrapper_type = basic_type_from_layout( + env.arena, + env.context, + &Layout::Struct(field_layouts), + env.ptr_bytes, + ); + + // cast the opaque pointer to a pointer of the correct shape + let struct_ptr = cast_basic_basic( + env.builder, + value_ptr.into(), + wrapper_type.ptr_type(AddressSpace::Generic).into(), + ) + .into_pointer_value(); + + for (i, field_layout) in field_layouts.iter().enumerate() { + if let Layout::RecursivePointer = field_layout { + // this field has type `*i64`, but is really a pointer to the data we want + let elem_pointer = env + .builder + .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") + .unwrap(); + + let ptr_as_i64_ptr = env + .builder + .build_load(elem_pointer, "load_recursive_pointer"); + + debug_assert!(ptr_as_i64_ptr.is_pointer_value()); + + // therefore we must cast it to our desired type + let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); + let recursive_field_ptr = cast_basic_basic( + env.builder, + ptr_as_i64_ptr, + union_type.ptr_type(AddressSpace::Generic).into(), + ); + + // recursively increment the field + let call = env.builder.build_call( + fn_val, + &[recursive_field_ptr], + "recursive_tag_increment", + ); + + // Because it's an internal-only function, use the fast calling convention. + call.set_call_convention(FAST_CALL_CONV); + } else if field_layout.contains_refcounted() { + let elem_pointer = env + .builder + .build_struct_gep(struct_ptr, i as u32, "gep_field") + .unwrap(); + + let field = env.builder.build_load(elem_pointer, "load_field"); + + increment_refcount_layout(env, parent, layout_ids, field, field_layout); + } + } + + env.builder.build_unconditional_branch(merge_block); + + cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); + } + + env.builder.position_at_end(before_block); + + env.builder + .build_switch(tag_id_u8.into_int_value(), merge_block, &cases); + + env.builder.position_at_end(merge_block); + + // increment this cons cell + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + refcount_ptr.increment(env); + + // this function returns void + builder.build_return(None); +} + pub fn build_inc_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -1316,26 +1673,12 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&context, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; - let layout = Layout::Union(tags); let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; let layout = Layout::RecursiveUnion(tags); diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index eea164e70a..b5b0f1ec8f 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -564,7 +564,7 @@ mod gen_primitives { Cons _ rest -> 1 + len rest main = - nil : LinkedList (Int *) + nil : LinkedList {} nil = Nil len nil @@ -1122,6 +1122,40 @@ mod gen_primitives { ); } + #[test] + fn linked_list_is_singleton() { + assert_non_opt_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + ConsList a : [ Cons a (ConsList a), Nil ] + + empty : ConsList a + empty = Nil + + isSingleton : ConsList a -> Bool + isSingleton = \list -> + when list is + Cons _ Nil -> + True + + _ -> + False + + main : Bool + main = + myList : ConsList I64 + myList = empty + + isSingleton myList + "# + ), + false, + bool + ); + } + #[test] fn linked_list_is_empty_1() { assert_non_opt_evals_to!( @@ -1176,7 +1210,7 @@ mod gen_primitives { main : Bool main = - myList : ConsList (Int *) + myList : ConsList I64 myList = Cons 0x1 Nil isEmpty myList @@ -1187,6 +1221,26 @@ mod gen_primitives { ); } + #[test] + fn linked_list_singleton() { + // verifies only that valid llvm is produced + assert_non_opt_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + ConsList a : [ Cons a (ConsList a), Nil ] + + main : ConsList I64 + main = Cons 0x1 Nil + "# + ), + 0, + i64, + |_| 0 + ); + } + #[test] fn recursive_functon_with_rigid() { assert_non_opt_evals_to!( @@ -1354,7 +1408,8 @@ mod gen_primitives { "# ), 1, - i64 + &i64, + |x: &i64| *x ); } @@ -1451,7 +1506,8 @@ mod gen_primitives { "# ), 1, - i64 + &i64, + |x: &i64| *x ); } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index f011c2857e..b6b460609a 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -2731,7 +2731,10 @@ pub fn with_hole<'a>( let iter = field_symbols_temp.into_iter().map(|(_, _, data)| data); assign_to_symbols(env, procs, layout_cache, iter, stmt) } - Wrapped(sorted_tag_layouts) => { + Wrapped { + sorted_tag_layouts, + is_recursive, + } => { let union_size = sorted_tag_layouts.len() as u8; let (tag_id, (_, _)) = sorted_tag_layouts .iter() @@ -2786,7 +2789,12 @@ pub fn with_hole<'a>( } let field_symbols = field_symbols.into_bump_slice(); - let layout = Layout::Union(layouts.into_bump_slice()); + let layout = if is_recursive { + Layout::RecursiveUnion(layouts.into_bump_slice()) + } else { + Layout::Union(layouts.into_bump_slice()) + }; + let tag = Expr::Tag { tag_layout: layout.clone(), tag_name, @@ -4794,6 +4802,12 @@ fn store_pattern<'a>( for (index, (argument, arg_layout)) in arguments.iter().enumerate().rev() { let index = if write_tag { index + 1 } else { index }; + let mut arg_layout = arg_layout; + + if let Layout::RecursivePointer = arg_layout { + arg_layout = layout; + } + let load = Expr::AccessAtIndex { wrapped, index: index as u64, @@ -5857,7 +5871,10 @@ fn from_can_pattern_help<'a>( layout, } } - Wrapped(tags) => { + Wrapped { + sorted_tag_layouts: tags, + is_recursive, + } => { let mut ctors = std::vec::Vec::with_capacity(tags.len()); for (i, (tag_name, args)) in tags.iter().enumerate() { ctors.push(Ctor { @@ -5912,7 +5929,11 @@ fn from_can_pattern_help<'a>( layouts.push(arg_layouts); } - let layout = Layout::Union(layouts.into_bump_slice()); + let layout = if is_recursive { + Layout::RecursiveUnion(layouts.into_bump_slice()) + } else { + Layout::Union(layouts.into_bump_slice()) + }; Pattern::AppliedTag { tag_name: tag_name.clone(), diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 1a595b5301..1b19d92cb7 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -158,7 +158,13 @@ impl<'a> ClosureLayout<'a> { Ok(Some(closure_layout)) } - Wrapped(tags) => { + Wrapped { + sorted_tag_layouts: tags, + is_recursive, + } => { + // TODO handle recursive closures + debug_assert!(!is_recursive); + let closure_layout = ClosureLayout::from_tag_union(arena, tags.into_bump_slice()); @@ -1095,10 +1101,16 @@ pub enum UnionVariant<'a> { Never, Unit, UnitWithArguments, - BoolUnion { ttrue: TagName, ffalse: TagName }, + BoolUnion { + ttrue: TagName, + ffalse: TagName, + }, ByteUnion(Vec<'a, TagName>), Unwrapped(Vec<'a, Layout<'a>>), - Wrapped(Vec<'a, (TagName, &'a [Layout<'a>])>), + Wrapped { + sorted_tag_layouts: Vec<'a, (TagName, &'a [Layout<'a>])>, + is_recursive: bool, + }, } pub fn union_sorted_tags<'a>(arena: &'a Bump, var: Variable, subs: &Subs) -> UnionVariant<'a> { @@ -1277,7 +1289,10 @@ pub fn union_sorted_tags_help<'a>( UnionVariant::ByteUnion(tag_names) } - _ => UnionVariant::Wrapped(answer), + _ => UnionVariant::Wrapped { + sorted_tag_layouts: answer, + is_recursive: opt_rec_var.is_some(), + }, } } } @@ -1316,13 +1331,21 @@ pub fn layout_from_tag_union<'a>( Layout::Struct(field_layouts.into_bump_slice()) } } - Wrapped(tags) => { + Wrapped { + sorted_tag_layouts: tags, + is_recursive, + } => { let mut tag_layouts = Vec::with_capacity_in(tags.len(), arena); for (_, tag_layout) in tags { tag_layouts.push(tag_layout); } - Layout::Union(tag_layouts.into_bump_slice()) + + if is_recursive { + Layout::RecursiveUnion(tag_layouts.into_bump_slice()) + } else { + Layout::Union(tag_layouts.into_bump_slice()) + } } } } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 8186a86251..0b14ca7e13 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -1893,6 +1893,7 @@ mod test_mono { let Test.2 = S Test.9 Test.8; let Test.5 = 1i64; let Test.6 = Index 0 Test.2; + dec Test.2; let Test.7 = lowlevel Eq Test.5 Test.6; if Test.7 then let Test.3 = 0i64; @@ -1944,12 +1945,16 @@ mod test_mono { let Test.10 = lowlevel Eq Test.8 Test.9; if Test.10 then let Test.4 = Index 1 Test.2; + inc Test.4; + dec Test.2; let Test.3 = 1i64; ret Test.3; else + dec Test.2; let Test.5 = 0i64; ret Test.5; else + dec Test.2; let Test.6 = 0i64; ret Test.6; "#