diff --git a/compiler/mono/src/alias_analysis.rs b/compiler/mono/src/alias_analysis.rs index 831bdd8afd..2684f2884e 100644 --- a/compiler/mono/src/alias_analysis.rs +++ b/compiler/mono/src/alias_analysis.rs @@ -845,6 +845,19 @@ fn lowlevel_spec( } } +fn recursive_tag_variant( + builder: &mut impl TypeContext, + union_layout: &UnionLayout, + fields: &[Layout], +) -> Result { + let when_recursive = WhenRecursive::Loop(*union_layout); + + let data_id = build_recursive_tuple_type(builder, fields, &when_recursive)?; + let cell_id = builder.add_heap_cell_type(); + + builder.add_tuple_type(&[cell_id, data_id]) +} + fn build_variant_types( builder: &mut impl TypeContext, union_layout: &UnionLayout, @@ -864,17 +877,12 @@ fn build_variant_types( Recursive(tags) => { result = Vec::with_capacity(tags.len()); - let when_recursive = WhenRecursive::Loop(*union_layout); - for tag in tags.iter() { - let data_id = build_recursive_tuple_type(builder, tag, &when_recursive)?; - let cell_id = builder.add_heap_cell_type(); - let value_id = builder.add_tuple_type(&[cell_id, data_id])?; - result.push(value_id); + result.push(recursive_tag_variant(builder, union_layout, tag)?); } } NonNullableUnwrapped(fields) => { - result = vec![build_tuple_type(builder, fields)?]; + result = vec![recursive_tag_variant(builder, union_layout, fields)?]; } NullableWrapped { nullable_id, @@ -885,14 +893,14 @@ fn build_variant_types( let cutoff = *nullable_id as usize; for tag in tags[..cutoff].iter() { - result.push(build_tuple_type(builder, tag)?); + result.push(recursive_tag_variant(builder, union_layout, tag)?); } let unit = builder.add_tuple_type(&[])?; result.push(unit); for tag in tags[cutoff..].iter() { - result.push(build_tuple_type(builder, tag)?); + result.push(recursive_tag_variant(builder, union_layout, tag)?); } } NullableUnwrapped { @@ -900,7 +908,7 @@ fn build_variant_types( other_fields: fields, } => { let unit = builder.add_tuple_type(&[])?; - let other_type = build_tuple_type(builder, fields)?; + let other_type = recursive_tag_variant(builder, union_layout, fields)?; if *nullable_id { // nullable_id == 1 @@ -959,17 +967,33 @@ fn expr_spec<'a>( env.type_names.insert(*tag_layout); - let named_id = builder.add_make_named(block, MOD_APP, type_name, union_id)?; - - Ok(named_id) + builder.add_make_named(block, MOD_APP, type_name, union_id) } - UnionLayout::NonNullableUnwrapped(_) - | UnionLayout::NullableWrapped { .. } - | UnionLayout::NullableUnwrapped { .. } => { + UnionLayout::NonNullableUnwrapped(_) | UnionLayout::NullableWrapped { .. } => { let result_type = worst_case_type(builder)?; let value_id = build_tuple_value(builder, env, block, arguments)?; builder.add_unknown_with(block, &[value_id], result_type) } + UnionLayout::NullableUnwrapped { nullable_id, .. } => { + let union_id = if *tag_id == *nullable_id as u8 { + let value_id = builder.add_make_tuple(block, &[])?; + builder.add_make_union(block, &variant_types, *tag_id as u32, value_id)? + } else { + let data_id = build_tuple_value(builder, env, block, arguments)?; + let cell_id = builder.add_new_heap_cell(block)?; + + let value_id = builder.add_make_tuple(block, &[cell_id, data_id])?; + + builder.add_make_union(block, &variant_types, *tag_id as u32, value_id)? + }; + + let type_name_bytes = recursive_tag_union_name_bytes(tag_layout).as_bytes(); + let type_name = TypeName(&type_name_bytes); + + env.type_names.insert(*tag_layout); + + builder.add_make_named(block, MOD_APP, type_name, union_id) + } } } Struct(fields) => build_tuple_value(builder, env, block, fields), @@ -1005,6 +1029,24 @@ fn expr_spec<'a>( builder.add_get_tuple_field(block, tuple_value_id, index) } + UnionLayout::NullableUnwrapped { .. } => { + let index = (*index) as u32; + let tag_value_id = env.symbols[structure]; + + let type_name_bytes = recursive_tag_union_name_bytes(&union_layout).as_bytes(); + let type_name = TypeName(&type_name_bytes); + + let union_id = builder.add_unwrap_named(block, MOD_APP, type_name, tag_value_id)?; + let variant_id = builder.add_unwrap_union(block, union_id, *tag_id as u32)?; + + // we're reading from this value, so touch the heap cell + let heap_cell = builder.add_get_tuple_field(block, variant_id, 0)?; + builder.add_touch(block, heap_cell)?; + + let tuple_value_id = builder.add_get_tuple_field(block, variant_id, 1)?; + + builder.add_get_tuple_field(block, tuple_value_id, index) + } _ => { // for the moment recursive tag unions don't quite work let value_id = env.symbols[structure]; @@ -1097,7 +1139,6 @@ fn layout_spec_help( match union_layout { UnionLayout::NonRecursive(_) => builder.add_union_type(&variant_types), UnionLayout::Recursive(_) => { - // worst_case_type(builder), let type_name_bytes = recursive_tag_union_name_bytes(&union_layout).as_bytes(); let type_name = TypeName(&type_name_bytes); @@ -1111,7 +1152,12 @@ fn layout_spec_help( UnionLayout::NullableUnwrapped { nullable_id: _, other_fields: _, - } => worst_case_type(builder), + } => { + let type_name_bytes = recursive_tag_union_name_bytes(&union_layout).as_bytes(); + let type_name = TypeName(&type_name_bytes); + + Ok(builder.add_named_type(MOD_APP, type_name)) + } } } RecursivePointer => match when_recursive { @@ -1120,15 +1166,18 @@ fn layout_spec_help( // unreachable!(), worst_case_type(builder) } - WhenRecursive::Loop(union_layout) - if matches!(union_layout, UnionLayout::Recursive(_)) => - { - let type_name_bytes = recursive_tag_union_name_bytes(union_layout).as_bytes(); - let type_name = TypeName(&type_name_bytes); + WhenRecursive::Loop(union_layout) => match union_layout { + UnionLayout::NonRecursive(_) => unreachable!(), + UnionLayout::Recursive(_) | UnionLayout::NullableUnwrapped { .. } => { + let type_name_bytes = recursive_tag_union_name_bytes(union_layout).as_bytes(); + let type_name = TypeName(&type_name_bytes); - Ok(builder.add_named_type(MOD_APP, type_name)) - } - WhenRecursive::Loop(_union_layout) => worst_case_type(builder), + Ok(builder.add_named_type(MOD_APP, type_name)) + } + UnionLayout::NonNullableUnwrapped(_) | UnionLayout::NullableWrapped { .. } => { + worst_case_type(builder) + } + }, }, Closure(_, lambda_set, _) => layout_spec_help( builder,