Revise glue for single-tag recursive unions

This commit is contained in:
Richard Feldman 2022-08-07 21:38:22 -04:00
parent 22b219c3c9
commit 4ac79772df
No known key found for this signature in database
GPG key ID: 7E4127D1E4241798
2 changed files with 96 additions and 63 deletions

View file

@ -248,9 +248,12 @@ fn add_type(target_info: TargetInfo, id: TypeId, types: &Types, impls: &mut Impl
RocTagUnion::RecursiveSingleTag { RocTagUnion::RecursiveSingleTag {
name, name,
tag_name, tag_name,
payload: payload_id, payload_fields,
} => {
// TODO
todo!();
} }
| RocTagUnion::NonRecursiveSingleTag { RocTagUnion::NonRecursiveSingleTag {
name, name,
tag_name, tag_name,
payload: Some(payload_id), payload: Some(payload_id),
@ -269,13 +272,21 @@ fn add_type(target_info: TargetInfo, id: TypeId, types: &Types, impls: &mut Impl
types, types,
false, false,
); );
let mut buf = format!("#[repr(C)]\n{derive}\npub struct {name} {{");
write_tag_union_field(tag_name, &Some(*payload_id), types, &mut buf); add_decl(
impls,
buf.push_str("}\n"); None,
target_info,
add_decl(impls, None, target_info, buf); format!(
r#"#[repr(C)]
{derive}
pub struct {name} {{
{tag_name}: *mut {}
}}
"#,
type_name(*payload_id, types)
),
);
} }
} }
RocTagUnion::NonRecursiveSingleTag { RocTagUnion::NonRecursiveSingleTag {
@ -486,7 +497,27 @@ pub struct {name} {{
let mut buf = format!("#[repr(C)]\n{pub_str}union {decl_union_name} {{\n"); let mut buf = format!("#[repr(C)]\n{pub_str}union {decl_union_name} {{\n");
for (tag_name, opt_payload_id) in tags { for (tag_name, opt_payload_id) in tags {
write_tag_union_field(tag_name, opt_payload_id, types, &mut buf); // If there's no payload, we don't need a discriminant for it.
if let Some(payload_id) = opt_payload_id {
let payload_type = types.get_type(*payload_id);
write!(buf, "{INDENT}{tag_name}: ").unwrap();
if cannot_derive_copy(payload_type, types) {
// types with pointers need ManuallyDrop
// because rust unions don't (and can't)
// know how to drop them automatically!
writeln!(
buf,
"core::mem::ManuallyDrop<{}>,",
type_name(*payload_id, types)
)
.unwrap();
} else {
buf.push_str(&type_name(*payload_id, types));
buf.push_str(",\n");
}
}
} }
if tags.len() > 1 { if tags.len() > 1 {
@ -1278,36 +1309,6 @@ pub struct {name} {{
} }
} }
#[inline(always)]
fn write_tag_union_field(
tag_name: &str,
opt_payload_id: &Option<TypeId>,
types: &Types,
buf: &mut String,
) {
// If there's no payload, we don't need a discriminant for it.
if let Some(payload_id) = opt_payload_id {
let payload_type = types.get_type(*payload_id);
write!(buf, "{INDENT}{tag_name}: ").unwrap();
if cannot_derive_copy(payload_type, types) {
// types with pointers need ManuallyDrop
// because rust unions don't (and can't)
// know how to drop them automatically!
writeln!(
buf,
"core::mem::ManuallyDrop<{}>,",
type_name(*payload_id, types)
)
.unwrap();
} else {
buf.push_str(&type_name(*payload_id, types));
buf.push_str(",\n");
}
}
}
fn write_impl_tags< fn write_impl_tags<
'a, 'a,
I: IntoIterator<Item = &'a (String, Option<TypeId>)>, I: IntoIterator<Item = &'a (String, Option<TypeId>)>,
@ -2145,6 +2146,9 @@ fn has_float_help(roc_type: &RocType, types: &Types, do_not_recurse: &[TypeId])
RocType::TagUnionPayload { fields, .. } => fields RocType::TagUnionPayload { fields, .. } => fields
.iter() .iter()
.any(|(_, type_id)| has_float_help(types.get_type(*type_id), types, do_not_recurse)), .any(|(_, type_id)| has_float_help(types.get_type(*type_id), types, do_not_recurse)),
RocType::TagUnion(RocTagUnion::RecursiveSingleTag { payload_fields, .. }) => payload_fields
.iter()
.any(|type_id| has_float_help(types.get_type(*type_id), types, do_not_recurse)),
RocType::TagUnion(RocTagUnion::Recursive { tags, .. }) RocType::TagUnion(RocTagUnion::Recursive { tags, .. })
| RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => { | RocType::TagUnion(RocTagUnion::NonRecursive { tags, .. }) => {
tags.iter().any(|(_, payloads)| { tags.iter().any(|(_, payloads)| {
@ -2164,9 +2168,6 @@ fn has_float_help(roc_type: &RocType, types: &Types, do_not_recurse: &[TypeId])
non_null_payload: content, non_null_payload: content,
.. ..
}) })
| RocType::TagUnion(RocTagUnion::RecursiveSingleTag {
payload: content, ..
})
| RocType::RecursivePointer(content) => { | RocType::RecursivePointer(content) => {
if do_not_recurse.contains(content) { if do_not_recurse.contains(content) {
false false

View file

@ -147,12 +147,17 @@ impl Types {
} }
( (
RecursiveSingleTag { RecursiveSingleTag {
payload: content_a, .. name: _,
tag_name: tag_name_a,
payload_fields: payload_fields_a,
}, },
RecursiveSingleTag { RecursiveSingleTag {
payload: content_b, .. name: _,
tag_name: tag_name_b,
payload_fields: payload_fields_b,
}, },
) => content_a == content_b, ) => tag_name_a == tag_name_b && payload_fields_a == payload_fields_b,
( (
NullableWrapped { tags: tags_a, .. }, NullableWrapped { tags: tags_a, .. },
NullableWrapped { tags: tags_b, .. }, NullableWrapped { tags: tags_b, .. },
@ -543,7 +548,7 @@ pub enum RocTagUnion {
RecursiveSingleTag { RecursiveSingleTag {
name: String, name: String,
tag_name: String, tag_name: String,
payload: TypeId, payload_fields: Vec<TypeId>,
}, },
/// A recursive tag union that has an empty variant /// A recursive tag union that has an empty variant
@ -972,6 +977,48 @@ fn add_tag_union<'a>(
layout: Layout<'a>, layout: Layout<'a>,
) -> TypeId { ) -> TypeId {
let subs = env.subs; let subs = env.subs;
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
// This one needs an early return. That's because the the outermost representation
// of an unwrapped, non-nullable tag union is a struct, not a pointer. We must not
// create structs for its payloads like we do with other tag union types,
// because this one *is* the struct!
if let Layout::Union(UnionLayout::NonNullableUnwrapped(payload_field_layouts)) = layout {
let mut iter = union_tags.iter_from_subs(subs);
let (tag_name, payload_vars) = iter.next().unwrap();
// NonNullableUnwrapped should always have exactly 1 payload.
debug_assert_eq!(iter.next(), None);
let payload_fields: Vec<TypeId> = payload_vars
.into_iter()
.zip(payload_field_layouts.into_iter())
.map(|(field_var, field_layout)| {
add_type_help(env, *field_layout, *field_var, None, types)
})
.collect();
// A recursive tag union with just one constructor
// Optimization: No need to store a tag ID (the payload is "unwrapped")
// e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
let tag_union_type = RocTagUnion::RecursiveSingleTag {
name: name.clone(),
tag_name: tag_name.0.as_str().to_string(),
payload_fields,
};
let typ = RocType::TagUnion(tag_union_type);
let type_id = types.add_named(name, typ, layout);
env.known_recursive_types.insert(layout, type_id);
// Do an early return because we've already done everything we needed to do.
return type_id;
}
let mut tags: Vec<(String, Vec<Variable>)> = union_tags let mut tags: Vec<(String, Vec<Variable>)> = union_tags
.iter_from_subs(subs) .iter_from_subs(subs)
.map(|(tag_name, payload_vars)| { .map(|(tag_name, payload_vars)| {
@ -981,11 +1028,6 @@ fn add_tag_union<'a>(
}) })
.collect(); .collect();
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
// Sort tags alphabetically by tag name // Sort tags alphabetically by tag name
tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
@ -1062,19 +1104,9 @@ fn add_tag_union<'a>(
discriminant_offset, discriminant_offset,
} }
} }
// A recursive tag union with just one constructor
// Optimization: No need to store a tag ID (the payload is "unwrapped")
// e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
NonNullableUnwrapped(_) => { NonNullableUnwrapped(_) => {
debug_assert_eq!(1, tags.len()); // This was already special-cased with an early return, above.
unreachable!()
let (tag_name, payload) = tags.pop().unwrap();
RocTagUnion::RecursiveSingleTag {
name: name.clone(),
tag_name,
payload: payload.unwrap(),
}
} }
// A recursive tag union that has an empty variant // A recursive tag union that has an empty variant
// Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison // Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison