Don't special-case bindgen for single-tag unions

This commit is contained in:
Richard Feldman 2022-05-29 00:03:30 -04:00
parent 1f9040460a
commit 729c849b51
No known key found for this signature in database
GPG key ID: 7E4127D1E4241798
2 changed files with 115 additions and 225 deletions

View file

@ -313,170 +313,138 @@ fn add_tag_union(
}) })
.collect(); .collect();
if tags.len() == 1 { let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap();
// This is a single-tag union. let name = match opt_name {
let (tag_name, payload_vars) = tags.pop().unwrap(); Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
// If there was a type alias name, use that. Otherwise use the tag name. // Sort tags alphabetically by tag name
let name = match opt_name { tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2));
Some(sym) => sym.as_str(env.interns).to_string(),
None => tag_name,
};
let fields = payload_vars let is_recursive = is_recursive_tag_union(&layout);
.iter()
.enumerate()
.map(|(index, payload_var)| (index, *payload_var));
add_struct(env, name, fields, types, |name, fields| { let mut tags: Vec<_> = tags
RocType::TagUnionPayload { name, fields } .into_iter()
.map(|(tag_name, payload_vars)| {
match struct_fields_needed(env, payload_vars.iter().copied()) {
0 => {
// no payload
(tag_name, None)
}
1 if !is_recursive => {
// this isn't recursive and there's 1 payload item, so it doesn't
// need its own struct - e.g. for `[Foo Str, Bar Str]` both of them
// can have payloads of plain old Str, no struct wrapper needed.
let payload_var = payload_vars.get(0).unwrap();
let layout = env
.layout_cache
.from_var(env.arena, *payload_var, env.subs)
.expect("Something weird ended up in the content");
let payload_id = add_type_help(env, layout, *payload_var, None, types);
(tag_name, Some(payload_id))
}
_ => {
// create a RocType for the payload and save it
let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant"
let fields = payload_vars.iter().copied().enumerate();
let struct_id = add_struct(env, struct_name, fields, types, |name, fields| {
RocType::TagUnionPayload { name, fields }
});
(tag_name, Some(struct_id))
}
}
}) })
} else { .collect();
// This is a multi-tag union.
// This is a placeholder so that we can get a TypeId for future recursion IDs. let typ = match layout {
// At the end, we will replace this with the real tag union type. Layout::Union(union_layout) => {
let type_id = types.add(RocType::Struct { use roc_mono::layout::UnionLayout::*;
name: String::new(),
fields: Vec::new(),
});
let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap();
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => env.enum_names.get_name(var),
};
// Sort tags alphabetically by tag name match union_layout {
tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); // A non-recursive tag union
// e.g. `Result ok err : [Ok ok, Err err]`
let is_recursive = is_recursive_tag_union(&layout); NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }),
// A recursive tag union (general case)
let mut tags: Vec<_> = tags // e.g. `Expr : [Sym Str, Add Expr Expr]`
.into_iter() Recursive(_) => RocType::TagUnion(RocTagUnion::Recursive { name, tags }),
.map(|(tag_name, payload_vars)| { // A recursive tag union with just one constructor
match struct_fields_needed(env, payload_vars.iter().copied()) { // Optimization: No need to store a tag ID (the payload is "unwrapped")
0 => { // e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
// no payload NonNullableUnwrapped(_) => {
(tag_name, None) todo!()
}
1 if !is_recursive => {
// this isn't recursive and there's 1 payload item, so it doesn't
// need its own struct - e.g. for `[Foo Str, Bar Str]` both of them
// can have payloads of plain old Str, no struct wrapper needed.
let payload_var = payload_vars.get(0).unwrap();
let layout = env
.layout_cache
.from_var(env.arena, *payload_var, env.subs)
.expect("Something weird ended up in the content");
let payload_id = add_type_help(env, layout, *payload_var, None, types);
(tag_name, Some(payload_id))
}
_ => {
// create a RocType for the payload and save it
let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant"
let fields = payload_vars.iter().copied().enumerate();
let struct_id =
add_struct(env, struct_name, fields, types, |name, fields| {
RocType::TagUnionPayload { name, fields }
});
(tag_name, Some(struct_id))
}
} }
}) // A recursive tag union that has an empty variant
.collect(); // Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison
// It has more than one other variant, so they need tag IDs (payloads are "wrapped")
// e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]`
// see also: https://youtu.be/ip92VMpf_-A?t=164
NullableWrapped { .. } => {
todo!()
}
// A recursive tag union with only two variants, where one is empty.
// Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant.
// e.g. `ConsList a : [Nil, Cons a (ConsList a)]`
NullableUnwrapped {
nullable_id: null_represents_first_tag,
other_fields: _, // TODO use this!
} => {
// NullableUnwrapped tag unions should always have exactly 2 tags.
debug_assert_eq!(tags.len(), 2);
let typ = match layout { let null_tag;
Layout::Union(union_layout) => { let non_null;
use roc_mono::layout::UnionLayout::*;
match union_layout { if null_represents_first_tag {
// A non-recursive tag union // If nullable_id is true, then the null tag is second, which means
// e.g. `Result ok err : [Ok ok, Err err]` // pop() will return it because it's at the end of the vec.
NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }), null_tag = tags.pop().unwrap().0;
// A recursive tag union (general case) non_null = tags.pop().unwrap();
// e.g. `Expr : [Sym Str, Add Expr Expr]` } else {
Recursive(_) => RocType::TagUnion(RocTagUnion::Recursive { name, tags }), // The null tag is first, which means the tag with the payload is second.
// A recursive tag union with just one constructor non_null = tags.pop().unwrap();
// Optimization: No need to store a tag ID (the payload is "unwrapped") null_tag = tags.pop().unwrap().0;
// e.g. `RoseTree a : [Tree a (List (RoseTree a))]`
NonNullableUnwrapped(_) => {
todo!()
} }
// A recursive tag union that has an empty variant
// Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison
// It has more than one other variant, so they need tag IDs (payloads are "wrapped")
// e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]`
// see also: https://youtu.be/ip92VMpf_-A?t=164
NullableWrapped { .. } => {
todo!()
}
// A recursive tag union with only two variants, where one is empty.
// Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant.
// e.g. `ConsList a : [Nil, Cons a (ConsList a)]`
NullableUnwrapped {
nullable_id: null_represents_first_tag,
other_fields: _, // TODO use this!
} => {
// NullableUnwrapped tag unions should always have exactly 2 tags.
debug_assert_eq!(tags.len(), 2);
let null_tag; let (non_null_tag, non_null_payload) = non_null;
let non_null;
if null_represents_first_tag { RocType::TagUnion(RocTagUnion::NullableUnwrapped {
// If nullable_id is true, then the null tag is second, which means name,
// pop() will return it because it's at the end of the vec. null_tag,
null_tag = tags.pop().unwrap().0; non_null_tag,
non_null = tags.pop().unwrap(); non_null_payload: non_null_payload.unwrap(),
} else { null_represents_first_tag,
// The null tag is first, which means the tag with the payload is second. })
non_null = tags.pop().unwrap();
null_tag = tags.pop().unwrap().0;
}
let (non_null_tag, non_null_payload) = non_null;
RocType::TagUnion(RocTagUnion::NullableUnwrapped {
name,
null_tag,
non_null_tag,
non_null_payload: non_null_payload.unwrap(),
null_represents_first_tag,
})
}
} }
} }
Layout::Builtin(builtin) => match builtin {
Builtin::Int(_) => RocType::TagUnion(RocTagUnion::Enumeration {
name,
tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(),
}),
Builtin::Bool => RocType::Bool,
Builtin::Float(_)
| Builtin::Decimal
| Builtin::Str
| Builtin::Dict(_, _)
| Builtin::Set(_)
| Builtin::List(_) => unreachable!(),
},
Layout::Struct { .. }
| Layout::Boxed(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => {
unreachable!()
}
};
types.replace(type_id, typ);
if is_recursive {
env.known_recursive_types.insert(var, type_id);
} }
Layout::Builtin(Builtin::Int(_)) => RocType::TagUnion(RocTagUnion::Enumeration {
name,
tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(),
}),
Layout::Builtin(_)
| Layout::Struct { .. }
| Layout::Boxed(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => {
// These must be single-tag unions. Bindgen ordinary nonrecursive
// tag unions for them, and let Rust do the unwrapping.
//
// This should be a very rare use case, and it's not worth overcomplicating
// the rest of bindgen to make it do something different.
RocType::TagUnion(RocTagUnion::NonRecursive { name, tags })
}
};
type_id let type_id = types.add(typ);
if is_recursive {
env.known_recursive_types.insert(var, type_id);
} }
type_id
} }
fn is_recursive_tag_union(layout: &Layout) -> bool { fn is_recursive_tag_union(layout: &Layout) -> bool {

View file

@ -233,82 +233,4 @@ mod test_gen_rs {
) )
); );
} }
// Checks the bindings generated for a single-tag union that carries two payloads
// (`UserId : [Id U32 Str]`): the expected output is a plain `#[repr(C)]` struct named
// after the type alias (`UserId`), with one `f<index>` field per payload.
#[test]
fn single_tag_union_with_payloads() {
// Roc source handed to the generator: the alias, plus a `main` value of that type.
let module = indoc!(
r#"
UserId : [Id U32 Str]
main : UserId
main = Id 42 "blah"
"#
);
// The generated text is compared after removing one leading '\n'.
// NOTE(review): `strip_prefix` returns `None` when the output does not start with
// '\n', so `unwrap_or_default()` would then compare the empty string against the
// expectation instead of the real output - confirm that is intended.
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
// Two cfg-gated definitions are expected. They differ only in field order:
// `f1` (RocStr) before `f0` (u32) on x86_64/aarch64, and `f0` before `f1` on
// x86/arm/wasm32 - presumably a consequence of target-dependent layout; verify.
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f1: roc_std::RocStr,
pub f0: u32,
}
#[cfg(any(
target_arch = "x86",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f0: u32,
pub f1: roc_std::RocStr,
}
"#
)
);
}
// Checks the bindings generated for a single-tag union with exactly one payload
// (`UserId : [Id Str]`): the expected output is a `#[repr(C)]` struct named after the
// type alias (`UserId`) holding a single field `f0`.
#[test]
fn single_tag_union_with_one_payload_field() {
// Roc source handed to the generator: the alias, plus a `main` value of that type.
let module = indoc!(
r#"
UserId : [Id Str]
main : UserId
main = Id "blah"
"#
);
// The generated text is compared after removing one leading '\n'.
// NOTE(review): `strip_prefix` returns `None` when the output does not start with
// '\n', so `unwrap_or_default()` would then compare the empty string against the
// expectation instead of the real output - confirm that is intended.
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
// A single definition is expected, gated on all five listed target
// architectures at once (x86_64, x86, aarch64, arm, wasm32).
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f0: roc_std::RocStr,
}
"#
)
);
}
} }