Don't special-case bindgen for single-tag unions

This commit is contained in:
Richard Feldman 2022-05-29 00:03:30 -04:00
parent 1f9040460a
commit 729c849b51
No known key found for this signature in database
GPG key ID: 7E4127D1E4241798
2 changed files with 115 additions and 225 deletions

View file

@ -313,33 +313,6 @@ fn add_tag_union(
}) })
.collect(); .collect();
if tags.len() == 1 {
// This is a single-tag union.
let (tag_name, payload_vars) = tags.pop().unwrap();
// If there was a type alias name, use that. Otherwise use the tag name.
let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(),
None => tag_name,
};
let fields = payload_vars
.iter()
.enumerate()
.map(|(index, payload_var)| (index, *payload_var));
add_struct(env, name, fields, types, |name, fields| {
RocType::TagUnionPayload { name, fields }
})
} else {
// This is a multi-tag union.
// This is a placeholder so that we can get a TypeId for future recursion IDs.
// At the end, we will replace this with the real tag union type.
let type_id = types.add(RocType::Struct {
name: String::new(),
fields: Vec::new(),
});
let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap(); let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap();
let name = match opt_name { let name = match opt_name {
Some(sym) => sym.as_str(env.interns).to_string(), Some(sym) => sym.as_str(env.interns).to_string(),
@ -376,8 +349,7 @@ fn add_tag_union(
// create a RocType for the payload and save it // create a RocType for the payload and save it
let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant" let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant"
let fields = payload_vars.iter().copied().enumerate(); let fields = payload_vars.iter().copied().enumerate();
let struct_id = let struct_id = add_struct(env, struct_name, fields, types, |name, fields| {
add_struct(env, struct_name, fields, types, |name, fields| {
RocType::TagUnionPayload { name, fields } RocType::TagUnionPayload { name, fields }
}); });
@ -448,35 +420,31 @@ fn add_tag_union(
} }
} }
} }
Layout::Builtin(builtin) => match builtin { Layout::Builtin(Builtin::Int(_)) => RocType::TagUnion(RocTagUnion::Enumeration {
Builtin::Int(_) => RocType::TagUnion(RocTagUnion::Enumeration {
name, name,
tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(), tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(),
}), }),
Builtin::Bool => RocType::Bool, Layout::Builtin(_)
Builtin::Float(_) | Layout::Struct { .. }
| Builtin::Decimal
| Builtin::Str
| Builtin::Dict(_, _)
| Builtin::Set(_)
| Builtin::List(_) => unreachable!(),
},
Layout::Struct { .. }
| Layout::Boxed(_) | Layout::Boxed(_)
| Layout::LambdaSet(_) | Layout::LambdaSet(_)
| Layout::RecursivePointer => { | Layout::RecursivePointer => {
unreachable!() // These must be single-tag unions. Bindgen ordinary nonrecursive
// tag unions for them, and let Rust do the unwrapping.
//
// This should be a very rare use case, and it's not worth overcomplicating
// the rest of bindgen to make it do something different.
RocType::TagUnion(RocTagUnion::NonRecursive { name, tags })
} }
}; };
types.replace(type_id, typ); let type_id = types.add(typ);
if is_recursive { if is_recursive {
env.known_recursive_types.insert(var, type_id); env.known_recursive_types.insert(var, type_id);
} }
type_id type_id
}
} }
fn is_recursive_tag_union(layout: &Layout) -> bool { fn is_recursive_tag_union(layout: &Layout) -> bool {

View file

@ -233,82 +233,4 @@ mod test_gen_rs {
) )
); );
} }
#[test]
fn single_tag_union_with_payloads() {
let module = indoc!(
r#"
UserId : [Id U32 Str]
main : UserId
main = Id 42 "blah"
"#
);
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f1: roc_std::RocStr,
pub f0: u32,
}
#[cfg(any(
target_arch = "x86",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f0: u32,
pub f1: roc_std::RocStr,
}
"#
)
);
}
#[test]
fn single_tag_union_with_one_payload_field() {
let module = indoc!(
r#"
UserId : [Id Str]
main : UserId
main = Id "blah"
"#
);
assert_eq!(
generate_bindings(module)
.strip_prefix('\n')
.unwrap_or_default(),
indoc!(
r#"
#[cfg(any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "arm",
target_arch = "wasm32"
))]
#[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)]
#[repr(C)]
struct UserId {
pub f0: roc_std::RocStr,
}
"#
)
);
}
} }