diff --git a/bindgen/src/bindgen.rs b/bindgen/src/bindgen.rs index 23438779e1..52a4273480 100644 --- a/bindgen/src/bindgen.rs +++ b/bindgen/src/bindgen.rs @@ -313,170 +313,138 @@ fn add_tag_union( }) .collect(); - if tags.len() == 1 { - // This is a single-tag union. - let (tag_name, payload_vars) = tags.pop().unwrap(); + let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap(); + let name = match opt_name { + Some(sym) => sym.as_str(env.interns).to_string(), + None => env.enum_names.get_name(var), + }; - // If there was a type alias name, use that. Otherwise use the tag name. - let name = match opt_name { - Some(sym) => sym.as_str(env.interns).to_string(), - None => tag_name, - }; + // Sort tags alphabetically by tag name + tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); - let fields = payload_vars - .iter() - .enumerate() - .map(|(index, payload_var)| (index, *payload_var)); + let is_recursive = is_recursive_tag_union(&layout); - add_struct(env, name, fields, types, |name, fields| { - RocType::TagUnionPayload { name, fields } + let mut tags: Vec<_> = tags + .into_iter() + .map(|(tag_name, payload_vars)| { + match struct_fields_needed(env, payload_vars.iter().copied()) { + 0 => { + // no payload + (tag_name, None) + } + 1 if !is_recursive => { + // this isn't recursive and there's 1 payload item, so it doesn't + // need its own struct - e.g. for `[Foo Str, Bar Str]` both of them + // can have payloads of plain old Str, no struct wrapper needed. + let payload_var = payload_vars.get(0).unwrap(); + let layout = env + .layout_cache + .from_var(env.arena, *payload_var, env.subs) + .expect("Something weird ended up in the content"); + let payload_id = add_type_help(env, layout, *payload_var, None, types); + + (tag_name, Some(payload_id)) + } + _ => { + // create a RocType for the payload and save it + let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant" + let fields = payload_vars.iter().copied().enumerate(); + let struct_id = add_struct(env, struct_name, fields, types, |name, fields| { + RocType::TagUnionPayload { name, fields } + }); + + (tag_name, Some(struct_id)) + } + } }) - } else { - // This is a multi-tag union. + .collect(); - // This is a placeholder so that we can get a TypeId for future recursion IDs. - // At the end, we will replace this with the real tag union type. - let type_id = types.add(RocType::Struct { - name: String::new(), - fields: Vec::new(), - }); - let layout = env.layout_cache.from_var(env.arena, var, subs).unwrap(); - let name = match opt_name { - Some(sym) => sym.as_str(env.interns).to_string(), - None => env.enum_names.get_name(var), - }; + let typ = match layout { + Layout::Union(union_layout) => { + use roc_mono::layout::UnionLayout::*; - // Sort tags alphabetically by tag name - tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); - - let is_recursive = is_recursive_tag_union(&layout); - - let mut tags: Vec<_> = tags - .into_iter() - .map(|(tag_name, payload_vars)| { - match struct_fields_needed(env, payload_vars.iter().copied()) { - 0 => { - // no payload - (tag_name, None) - } - 1 if !is_recursive => { - // this isn't recursive and there's 1 payload item, so it doesn't - // need its own struct - e.g. for `[Foo Str, Bar Str]` both of them - // can have payloads of plain old Str, no struct wrapper needed. - let payload_var = payload_vars.get(0).unwrap(); - let layout = env - .layout_cache - .from_var(env.arena, *payload_var, env.subs) - .expect("Something weird ended up in the content"); - let payload_id = add_type_help(env, layout, *payload_var, None, types); - - (tag_name, Some(payload_id)) - } - _ => { - // create a RocType for the payload and save it - let struct_name = format!("{}_{}", name, tag_name); // e.g. "MyUnion_MyVariant" - let fields = payload_vars.iter().copied().enumerate(); - let struct_id = - add_struct(env, struct_name, fields, types, |name, fields| { - RocType::TagUnionPayload { name, fields } - }); - - (tag_name, Some(struct_id)) - } + match union_layout { + // A non-recursive tag union + // e.g. `Result ok err : [Ok ok, Err err]` + NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }), + // A recursive tag union (general case) + // e.g. `Expr : [Sym Str, Add Expr Expr]` + Recursive(_) => RocType::TagUnion(RocTagUnion::Recursive { name, tags }), + // A recursive tag union with just one constructor + // Optimization: No need to store a tag ID (the payload is "unwrapped") + // e.g. `RoseTree a : [Tree a (List (RoseTree a))]` + NonNullableUnwrapped(_) => { + todo!() } - }) - .collect(); + // A recursive tag union that has an empty variant + // Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison + // It has more than one other variant, so they need tag IDs (payloads are "wrapped") + // e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]` + // see also: https://youtu.be/ip92VMpf_-A?t=164 + NullableWrapped { .. } => { + todo!() + } + // A recursive tag union with only two variants, where one is empty. + // Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant. + // e.g. `ConsList a : [Nil, Cons a (ConsList a)]` + NullableUnwrapped { + nullable_id: null_represents_first_tag, + other_fields: _, // TODO use this! + } => { + // NullableUnwrapped tag unions should always have exactly 2 tags. + debug_assert_eq!(tags.len(), 2); - let typ = match layout { - Layout::Union(union_layout) => { - use roc_mono::layout::UnionLayout::*; + let null_tag; + let non_null; - match union_layout { - // A non-recursive tag union - // e.g. `Result ok err : [Ok ok, Err err]` - NonRecursive(_) => RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }), - // A recursive tag union (general case) - // e.g. `Expr : [Sym Str, Add Expr Expr]` - Recursive(_) => RocType::TagUnion(RocTagUnion::Recursive { name, tags }), - // A recursive tag union with just one constructor - // Optimization: No need to store a tag ID (the payload is "unwrapped") - // e.g. `RoseTree a : [Tree a (List (RoseTree a))]` - NonNullableUnwrapped(_) => { - todo!() + if null_represents_first_tag { + // If nullable_id is true, then the null tag is second, which means + // pop() will return it because it's at the end of the vec. + null_tag = tags.pop().unwrap().0; + non_null = tags.pop().unwrap(); + } else { + // The null tag is first, which means the tag with the payload is second. + non_null = tags.pop().unwrap(); + null_tag = tags.pop().unwrap().0; } - // A recursive tag union that has an empty variant - // Optimization: Represent the empty variant as null pointer => no memory usage & fast comparison - // It has more than one other variant, so they need tag IDs (payloads are "wrapped") - // e.g. `FingerTree a : [Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a)]` - // see also: https://youtu.be/ip92VMpf_-A?t=164 - NullableWrapped { .. } => { - todo!() - } - // A recursive tag union with only two variants, where one is empty. - // Optimizations: Use null for the empty variant AND don't store a tag ID for the other variant. - // e.g. `ConsList a : [Nil, Cons a (ConsList a)]` - NullableUnwrapped { - nullable_id: null_represents_first_tag, - other_fields: _, // TODO use this! - } => { - // NullableUnwrapped tag unions should always have exactly 2 tags. - debug_assert_eq!(tags.len(), 2); - let null_tag; - let non_null; + let (non_null_tag, non_null_payload) = non_null; - if null_represents_first_tag { - // If nullable_id is true, then the null tag is second, which means - // pop() will return it because it's at the end of the vec. - null_tag = tags.pop().unwrap().0; - non_null = tags.pop().unwrap(); - } else { - // The null tag is first, which means the tag with the payload is second. - non_null = tags.pop().unwrap(); - null_tag = tags.pop().unwrap().0; - } - - let (non_null_tag, non_null_payload) = non_null; - - RocType::TagUnion(RocTagUnion::NullableUnwrapped { - name, - null_tag, - non_null_tag, - non_null_payload: non_null_payload.unwrap(), - null_represents_first_tag, - }) - } + RocType::TagUnion(RocTagUnion::NullableUnwrapped { + name, + null_tag, + non_null_tag, + non_null_payload: non_null_payload.unwrap(), + null_represents_first_tag, + }) } } - Layout::Builtin(builtin) => match builtin { - Builtin::Int(_) => RocType::TagUnion(RocTagUnion::Enumeration { - name, - tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(), - }), - Builtin::Bool => RocType::Bool, - Builtin::Float(_) - | Builtin::Decimal - | Builtin::Str - | Builtin::Dict(_, _) - | Builtin::Set(_) - | Builtin::List(_) => unreachable!(), - }, - Layout::Struct { .. } - | Layout::Boxed(_) - | Layout::LambdaSet(_) - | Layout::RecursivePointer => { - unreachable!() - } - }; - - types.replace(type_id, typ); - - if is_recursive { - env.known_recursive_types.insert(var, type_id); } + Layout::Builtin(Builtin::Int(_)) => RocType::TagUnion(RocTagUnion::Enumeration { + name, + tags: tags.into_iter().map(|(tag_name, _)| tag_name).collect(), + }), + Layout::Builtin(_) + | Layout::Struct { .. } + | Layout::Boxed(_) + | Layout::LambdaSet(_) + | Layout::RecursivePointer => { + // These must be single-tag unions. Bindgen ordinary nonrecursive + // tag unions for them, and let Rust do the unwrapping. + // + // This should be a very rare use case, and it's not worth overcomplicating + // the rest of bindgen to make it do something different. + RocType::TagUnion(RocTagUnion::NonRecursive { name, tags }) + } + }; - type_id + let type_id = types.add(typ); + + if is_recursive { + env.known_recursive_types.insert(var, type_id); } + + type_id } fn is_recursive_tag_union(layout: &Layout) -> bool { diff --git a/bindgen/tests/gen_rs.rs b/bindgen/tests/gen_rs.rs index 10c510d980..d558a8e926 100644 --- a/bindgen/tests/gen_rs.rs +++ b/bindgen/tests/gen_rs.rs @@ -233,82 +233,4 @@ mod test_gen_rs { ) ); } - - #[test] - fn single_tag_union_with_payloads() { - let module = indoc!( - r#" - UserId : [Id U32 Str] - - main : UserId - main = Id 42 "blah" - "# - ); - - assert_eq!( - generate_bindings(module) - .strip_prefix('\n') - .unwrap_or_default(), - indoc!( - r#" - #[cfg(any( - target_arch = "x86_64", - target_arch = "aarch64" - ))] - #[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)] - #[repr(C)] - struct UserId { - pub f1: roc_std::RocStr, - pub f0: u32, - } - - #[cfg(any( - target_arch = "x86", - target_arch = "arm", - target_arch = "wasm32" - ))] - #[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)] - #[repr(C)] - struct UserId { - pub f0: u32, - pub f1: roc_std::RocStr, - } - "# - ) - ); - } - - #[test] - fn single_tag_union_with_one_payload_field() { - let module = indoc!( - r#" - UserId : [Id Str] - - main : UserId - main = Id "blah" - "# - ); - - assert_eq!( - generate_bindings(module) - .strip_prefix('\n') - .unwrap_or_default(), - indoc!( - r#" - #[cfg(any( - target_arch = "x86_64", - target_arch = "x86", - target_arch = "aarch64", - target_arch = "arm", - target_arch = "wasm32" - ))] - #[derive(Clone, Debug, Default, Eq, Ord, Hash, PartialEq, PartialOrd)] - #[repr(C)] - struct UserId { - pub f0: roc_std::RocStr, - } - "# - ) - ); - } }