diff --git a/crates/glue/src/rust_glue.rs b/crates/glue/src/rust_glue.rs index d392ee0c99..693532cbf9 100644 --- a/crates/glue/src/rust_glue.rs +++ b/crates/glue/src/rust_glue.rs @@ -248,9 +248,12 @@ fn add_type(target_info: TargetInfo, id: TypeId, types: &Types, impls: &mut Impl RocTagUnion::RecursiveSingleTag { name, tag_name, - payload: payload_id, + payload_fields, + } => { + // TODO + todo!(); } - | RocTagUnion::NonRecursiveSingleTag { + RocTagUnion::NonRecursiveSingleTag { name, tag_name, payload: Some(payload_id), @@ -269,13 +272,21 @@ fn add_type(target_info: TargetInfo, id: TypeId, types: &Types, impls: &mut Impl types, false, ); - let mut buf = format!("#[repr(C)]\n{derive}\npub struct {name} {{"); - write_tag_union_field(tag_name, &Some(*payload_id), types, &mut buf); - - buf.push_str("}\n"); - - add_decl(impls, None, target_info, buf); + add_decl( + impls, + None, + target_info, + format!( + r#"#[repr(C)] +{derive} +pub struct {name} {{ + {tag_name}: *mut {} +}} +"#, + type_name(*payload_id, types) + ), + ); } } RocTagUnion::NonRecursiveSingleTag { @@ -486,7 +497,27 @@ pub struct {name} {{ let mut buf = format!("#[repr(C)]\n{pub_str}union {decl_union_name} {{\n"); for (tag_name, opt_payload_id) in tags { - write_tag_union_field(tag_name, opt_payload_id, types, &mut buf); + // If there's no payload, we don't need a discriminant for it. + if let Some(payload_id) = opt_payload_id { + let payload_type = types.get_type(*payload_id); + + write!(buf, "{INDENT}{tag_name}: ").unwrap(); + + if cannot_derive_copy(payload_type, types) { + // types with pointers need ManuallyDrop + // because rust unions don't (and can't) + // know how to drop them automatically! + writeln!( + buf, + "core::mem::ManuallyDrop<{}>,", + type_name(*payload_id, types) + ) + .unwrap(); + } else { + buf.push_str(&type_name(*payload_id, types)); + buf.push_str(",\n"); + } + } } if tags.len() > 1 { @@ -1278,36 +1309,6 @@ pub struct {name} {{ } } -#[inline(always)] -fn write_tag_union_field( - tag_name: &str, - opt_payload_id: &Option, - types: &Types, - buf: &mut String, -) { - // If there's no payload, we don't need a discriminant for it. - if let Some(payload_id) = opt_payload_id { - let payload_type = types.get_type(*payload_id); - - write!(buf, "{INDENT}{tag_name}: ").unwrap(); - - if cannot_derive_copy(payload_type, types) { - // types with pointers need ManuallyDrop - // because rust unions don't (and can't) - // know how to drop them automatically! - writeln!( - buf, - "core::mem::ManuallyDrop<{}>,", - type_name(*payload_id, types) - ) - .unwrap(); - } else { - buf.push_str(&type_name(*payload_id, types)); - buf.push_str(",\n"); - } - } -} - fn write_impl_tags< 'a, I: IntoIterator)>, @@ -2145,6 +2146,9 @@ fn has_float_help(roc_type: &RocType, types: &Types, do_not_recurse: &[TypeId]) RocType::TagUnionPayload { fields, .. } => fields .iter() .any(|(_, type_id)| has_float_help(types.get_type(*type_id), types, do_not_recurse)), + RocType::TagUnion(RocTagUnion::RecursiveSingleTag { payload_fields, .. }) => payload_fields + .iter() + .any(|type_id| has_float_help(types.get_type(*type_id), types, do_not_recurse)), RocType::TagUnion(RocTagUnion::Recursive { tags, .. }) | RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => { tags.iter().any(|(_, payloads)| { @@ -2164,9 +2168,6 @@ fn has_float_help(roc_type: &RocType, types: &Types, do_not_recurse: &[TypeId]) non_null_payload: content, .. }) - | RocType::TagUnion(RocTagUnion::RecursiveSingleTag { - payload: content, .. - }) | RocType::RecursivePointer(content) => { if do_not_recurse.contains(content) { false diff --git a/crates/glue/src/types.rs b/crates/glue/src/types.rs index 7cae87b50e..224d593062 100644 --- a/crates/glue/src/types.rs +++ b/crates/glue/src/types.rs @@ -147,12 +147,17 @@ impl Types { } ( RecursiveSingleTag { - payload: content_a, .. + name: _, + tag_name: tag_name_a, + payload_fields: payload_fields_a, }, RecursiveSingleTag { - payload: content_b, .. + name: _, + tag_name: tag_name_b, + payload_fields: payload_fields_b, }, - ) => content_a == content_b, + ) => tag_name_a == tag_name_b && payload_fields_a == payload_fields_b, + ( NullableWrapped { tags: tags_a, .. }, NullableWrapped { tags: tags_b, .. }, @@ -543,7 +548,7 @@ pub enum RocTagUnion { RecursiveSingleTag { name: String, tag_name: String, - payload: TypeId, + payload_fields: Vec, }, /// A recursive tag union that has an empty variant @@ -972,6 +977,48 @@ fn add_tag_union<'a>( layout: Layout<'a>, ) -> TypeId { let subs = env.subs; + let name = match opt_name { + Some(sym) => sym.as_str(env.interns).to_string(), + None => env.enum_names.get_name(var), + }; + + // This one needs an early return. That's because the the outermost representation + // of an unwrapped, non-nullable tag union is a struct, not a pointer. We must not + // create structs for its payloads like we do with other tag union types, + // because this one *is* the struct! + if let Layout::Union(UnionLayout::NonNullableUnwrapped(payload_field_layouts)) = layout { + let mut iter = union_tags.iter_from_subs(subs); + let (tag_name, payload_vars) = iter.next().unwrap(); + + // NonNullableUnwrapped should always have exactly 1 payload. + debug_assert_eq!(iter.next(), None); + + let payload_fields: Vec = payload_vars + .into_iter() + .zip(payload_field_layouts.into_iter()) + .map(|(field_var, field_layout)| { + add_type_help(env, *field_layout, *field_var, None, types) + }) + .collect(); + + // A recursive tag union with just one constructor + // Optimization: No need to store a tag ID (the payload is "unwrapped") + // e.g. `RoseTree a : [Tree a (List (RoseTree a))]` + let tag_union_type = RocTagUnion::RecursiveSingleTag { + name: name.clone(), + tag_name: tag_name.0.as_str().to_string(), + payload_fields, + }; + + let typ = RocType::TagUnion(tag_union_type); + let type_id = types.add_named(name, typ, layout); + + env.known_recursive_types.insert(layout, type_id); + + // Do an early return because we've already done everything we needed to do. + return type_id; + } + let mut tags: Vec<(String, Vec)> = union_tags .iter_from_subs(subs) .map(|(tag_name, payload_vars)| { @@ -981,11 +1028,6 @@ fn add_tag_union<'a>( }) .collect(); - let name = match opt_name { - Some(sym) => sym.as_str(env.interns).to_string(), - None => env.enum_names.get_name(var), - }; - // Sort tags alphabetically by tag name tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); @@ -1062,19 +1104,9 @@ fn add_tag_union<'a>( discriminant_offset, } } - // A recursive tag union with just one constructor - // Optimization: No need to store a tag ID (the payload is "unwrapped") - // e.g. `RoseTree a : [Tree a (List (RoseTree a))]` NonNullableUnwrapped(_) => { - debug_assert_eq!(1, tags.len()); - - let (tag_name, payload) = tags.pop().unwrap(); - - RocTagUnion::RecursiveSingleTag { - name: name.clone(), - tag_name, - payload: payload.unwrap(), - } + // This was already special-cased with an early return, above. + unreachable!() } // A recursive tag union that has an empty variant // Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison