diff --git a/bindgen/src/bindgen.rs b/bindgen/src/bindgen.rs index 22bb175fbd..cd12aaffdb 100644 --- a/bindgen/src/bindgen.rs +++ b/bindgen/src/bindgen.rs @@ -5,12 +5,12 @@ use crate::types::{TypeId, Types}; use crate::{enums::Enums, types::RocType}; use bumpalo::Bump; use roc_builtins::bitcode::{FloatWidth::*, IntWidth::*}; -use roc_module::ident::TagName; +use roc_module::ident::{Lowercase, TagName}; use roc_module::symbol::{Interns, Symbol}; use roc_mono::layout::{cmp_fields, ext_var_is_empty_tag_union, Builtin, Layout, LayoutCache}; use roc_types::subs::UnionTags; use roc_types::{ - subs::{Content, FlatType, RecordFields, Subs, Variable}, + subs::{Content, FlatType, Subs, Variable}, types::RecordField, }; @@ -50,7 +50,27 @@ pub fn add_type_help<'a>( todo!("TODO give a nice error message for a non-concrete type being passed to the host") } Content::Structure(FlatType::Record(fields, ext)) => { - add_struct(env, opt_name, fields, var, *ext, types) + let it = fields + .unsorted_iterator(subs, *ext) + .expect("something weird in content") + .flat_map(|(label, field)| { + match field { + RecordField::Required(field_var) | RecordField::Demanded(field_var) => { + Some((label.clone(), field_var)) + } + RecordField::Optional(_) => { + // drop optional fields + None + } + } + }); + + let name = match opt_name { + Some(sym) => sym.as_str(env.interns).to_string(), + None => env.struct_names.get_name(var), + }; + + add_struct(env, name, it, types) } Content::Structure(FlatType::TagUnion(tags, ext_var)) => { debug_assert!(ext_var_is_empty_tag_union(subs, *ext_var)); @@ -165,35 +185,27 @@ pub fn add_builtin_type<'a>( } } -fn add_struct( +fn add_struct>( env: &mut Env<'_>, - opt_name: Option, - record_fields: &RecordFields, - var: Variable, - ext: Variable, + name: String, + fields: I, types: &mut Types, ) -> TypeId { let subs = env.subs; - let mut sortables = bumpalo::collections::Vec::with_capacity_in(record_fields.len(), env.arena); - let it = record_fields - .unsorted_iterator(subs, ext) - .expect("something weird in content"); + let fields_iter = fields.into_iter(); + let mut sortables = bumpalo::collections::Vec::with_capacity_in( + fields_iter.size_hint().1.unwrap_or_default(), + env.arena, + ); - for (label, field) in it { - match field { - RecordField::Required(field_var) | RecordField::Demanded(field_var) => { - sortables.push(( - label, - field_var, - env.layout_cache - .from_var(env.arena, field_var, subs) - .unwrap(), - )); - } - RecordField::Optional(_) => { - // drop optional fields - } - }; + for (label, field_var) in fields_iter { + sortables.push(( + label, + field_var, + env.layout_cache + .from_var(env.arena, field_var, subs) + .unwrap(), + )); } sortables.sort_by(|(label1, _, layout1), (label2, _, layout2)| { @@ -211,61 +223,106 @@ fn add_struct( .map(|(label, field_var, field_layout)| { ( label.to_string(), - add_type_help(env, field_layout, field_var, opt_name, types), + add_type_help(env, field_layout, field_var, None, types), ) }) .collect(); - let name = match opt_name { - Some(sym) => sym.as_str(env.interns).to_string(), - None => env.struct_names.get_name(var), - }; - types.add(RocType::Struct { name, fields }) } fn add_tag_union( env: &mut Env<'_>, opt_name: Option, - tags: &UnionTags, + union_tags: &UnionTags, var: Variable, types: &mut Types, ) -> TypeId { let subs = env.subs; + let mut tags: Vec<(String, Vec)> = union_tags + .iter_from_subs(subs) + .map(|(tag_name, payload_vars)| { + let name_str = match tag_name { + TagName::Tag(uppercase) => uppercase.as_str().to_string(), + TagName::Closure(_) => unreachable!(), + }; + + (name_str, payload_vars.to_vec()) + }) + .collect(); + + if tags.len() == 1 { + let (tag_name, payload_vars) = tags.pop().unwrap(); + + // If there was a type alias name, use that. Otherwise use the tag name. + let name = match opt_name { + Some(sym) => sym.as_str(env.interns).to_string(), + None => tag_name, + }; + + return match payload_vars.len() { + 0 => { + // This is a single-tag union with no payload, e.g. `[ Foo ]` + // so just generate an empty record + types.add(RocType::Struct { + name, + fields: Vec::new(), + }) + } + 1 => { + // This is a single-tag union with 1 payload field, e.g.`[ Foo Str ]`. + // We'll just wrap that. + let var = *payload_vars.get(0).unwrap(); + let content = add_type(env, var, types); + + types.add(RocType::TransparentWrapper { name, content }) + } + _ => { + // This is a single-tag union with multiple payload field, e.g.`[ Foo Str U32 ]`. + // Generate a record. + let fields = payload_vars.iter().enumerate().map(|(index, payload_var)| { + let field_name = format!("f{}", index).into(); + + (field_name, *payload_var) + }); + + add_struct(env, name, fields, types) + } + }; + } + + let name = match opt_name { + Some(sym) => sym.as_str(env.interns).to_string(), + None => env.enum_names.get_name(var), + }; + + // Sort tags alphabetically by tag name + tags.sort_by(|(name1, _), (name2, _)| name1.cmp(name2)); + + let tags = tags + .into_iter() + .map(|(tag_name, payload_vars)| { + let payloads = payload_vars + .iter() + .map(|payload_var| add_type(env, *payload_var, types)) + .collect::>(); + + (tag_name, payloads) + }) + .collect(); let typ = match env.layout_cache.from_var(env.arena, var, subs).unwrap() { Layout::Struct { .. } => { - // a single-tag union with a payload - todo!(); + // a single-tag union with multiple payload values, e.g. [ Foo Str Str ] + unreachable!() } Layout::Union(_) => todo!(), Layout::Builtin(builtin) => match builtin { - Builtin::Int(int_width) => { - let tag_pairs = subs.tag_names[tags.tag_names().indices()] - .iter() - .map(|tag_name| { - let name_str = match tag_name { - TagName::Tag(uppercase) => uppercase.as_str().to_string(), - TagName::Closure(_) => unreachable!(), - }; - - // This is an enum, so there's no payload. - (name_str, Vec::new()) - }) - .collect(); - - let tag_bytes = int_width.stack_size().try_into().unwrap(); - let name = match opt_name { - Some(sym) => sym.as_str(env.interns).to_string(), - None => env.enum_names.get_name(var), - }; - - RocType::TagUnion { - tag_bytes, - name, - tags: tag_pairs, - } - } + Builtin::Int(int_width) => RocType::TagUnion { + tag_bytes: int_width.stack_size().try_into().unwrap(), + name, + tags, + }, Builtin::Bool => RocType::Bool, Builtin::Float(_) | Builtin::Decimal diff --git a/bindgen/src/bindgen_rs.rs b/bindgen/src/bindgen_rs.rs index 7c879e03e2..a4eb1006f7 100644 --- a/bindgen/src/bindgen_rs.rs +++ b/bindgen/src/bindgen_rs.rs @@ -70,6 +70,12 @@ pub fn write_types(types: &Types, buf: &mut String) -> fmt::Result { | RocType::RocSet(_) | RocType::RocList(_) | RocType::RocBox(_) => {} + RocType::TransparentWrapper { name, content } => { + write_deriving(id, types, buf)?; + write!(buf, "#[repr(transparent)]\npub struct {}(", name)?; + write_type_name(*content, types, buf)?; + buf.write_str(");\n")?; + } } } @@ -101,14 +107,10 @@ fn write_struct( ) -> fmt::Result { write_deriving(struct_id, types, buf)?; - buf.write_str("#[repr(C)]\npub struct ")?; - buf.write_str(name)?; - buf.write_str(" {\n")?; + writeln!(buf, "#[repr(C)]\npub struct {} {{", name)?; for (label, field_id) in fields { - buf.write_str(INDENT)?; - buf.write_str(label.as_str())?; - buf.write_str(": ")?; + write!(buf, "{}{}: ", INDENT, label.as_str())?; write_type_name(*field_id, types, buf)?; buf.write_str(",\n")?; } @@ -158,6 +160,7 @@ fn write_type_name(id: TypeId, types: &Types, buf: &mut String) -> fmt::Result { } RocType::Struct { name, .. } | RocType::TagUnion { name, .. } + | RocType::TransparentWrapper { name, .. } | RocType::RecursiveTagUnion { name, .. } => buf.write_str(name), } } diff --git a/bindgen/src/types.rs b/bindgen/src/types.rs index 4f263b2f14..8339e371e1 100644 --- a/bindgen/src/types.rs +++ b/bindgen/src/types.rs @@ -129,6 +129,11 @@ pub enum RocType { name: String, fields: Vec<(String, TypeId)>, }, + /// Either a single-tag union or a single-field record + TransparentWrapper { + name: String, + content: TypeId, + }, } impl RocType { @@ -162,6 +167,7 @@ impl RocType { RocType::Struct { fields, .. } => fields .iter() .any(|(_, id)| types.get(*id).has_pointer(types)), + RocType::TransparentWrapper { content, .. } => types.get(*content).has_pointer(types), } } @@ -194,6 +200,7 @@ impl RocType { RocType::Struct { fields, .. } => { fields.iter().any(|(_, id)| types.get(*id).has_float(types)) } + RocType::TransparentWrapper { content, .. } => types.get(*content).has_float(types), } } @@ -226,6 +233,7 @@ impl RocType { RocType::Struct { fields, .. } => fields .iter() .any(|(_, id)| types.get(*id).has_tag_union(types)), + RocType::TransparentWrapper { content, .. } => types.get(*content).has_tag_union(types), } } @@ -283,6 +291,9 @@ impl RocType { RocType::F32 => FloatWidth::F32.alignment_bytes(target_info) as usize, RocType::F64 => FloatWidth::F64.alignment_bytes(target_info) as usize, RocType::F128 => FloatWidth::F128.alignment_bytes(target_info) as usize, + RocType::TransparentWrapper { content, .. } => { + types.get(*content).alignment(types, target_info) + } } } } diff --git a/bindgen/tests/gen_rs.rs b/bindgen/tests/gen_rs.rs index f1e497f62f..5208d11246 100644 --- a/bindgen/tests/gen_rs.rs +++ b/bindgen/tests/gen_rs.rs @@ -152,15 +152,15 @@ fn nested_record_anonymous() { r#" #[derive(Clone, PartialEq, PartialOrd, Default, Debug)] #[repr(C)] - pub struct R2 { + pub struct R1 { y: roc_std::RocStr, z: roc_std::RocList, - x: R1, + x: R2, } #[derive(Clone, PartialEq, PartialOrd, Copy, Default, Debug)] #[repr(C)] - pub struct R1 { + pub struct R2 { b: f32, a: u16, } @@ -347,3 +347,56 @@ fn tag_union_enumeration() { ) ); } + +#[test] +fn single_tag_union_with_payloads() { + let module = indoc!( + r#" + UserId : [ Id U32 Str ] + + main : UserId + main = Id 42 "blah" + "# + ); + + assert_eq!( + generate_bindings(module) + .strip_prefix('\n') + .unwrap_or_default(), + indoc!( + r#" + #[derive(Clone, PartialEq, PartialOrd, Default, Eq, Ord, Hash, Debug)] + #[repr(C)] + pub struct UserId { + f1: roc_std::RocStr, + f0: u32, + } + "# + ) + ); +} + +#[test] +fn single_tag_union_with_one_payload_field() { + let module = indoc!( + r#" + UserId : [ Id Str ] + + main : UserId + main = Id "blah" + "# + ); + + assert_eq!( + generate_bindings(module) + .strip_prefix('\n') + .unwrap_or_default(), + indoc!( + r#" + #[derive(Clone, PartialEq, PartialOrd, Default, Eq, Ord, Hash, Debug)] + #[repr(transparent)] + pub struct UserId(roc_std::RocStr); + "# + ) + ); +} diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 97a17c5dc0..3af78043f6 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -2252,7 +2252,8 @@ fn register_tag_arguments<'a>( VariableSubsSlice::default() } else { let new_variables = VariableSubsSlice::reserve_into_subs(subs, arguments.len()); - let it = (new_variables.indices()).zip(arguments); + let it = new_variables.indices().zip(arguments); + for (target_index, argument) in it { let var = RegisterVariable::with_stack(subs, rank, pools, arena, argument, stack); subs.variables[target_index] = var; diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index 93b6ba97e5..9efc792704 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -2396,6 +2396,17 @@ impl UnionTags { .zip(self.variables().into_iter()) } + /// Iterator over (TagName, &[Variable]) pairs obtained by + /// looking up slices in the given Subs + pub fn iter_from_subs<'a>( + &'a self, + subs: &'a Subs, + ) -> impl Iterator + ExactSizeIterator { + self.iter_all().map(move |(name_index, payload_index)| { + (&subs[name_index], subs.get_subs_slice(subs[payload_index])) + }) + } + #[inline(always)] pub fn unsorted_iterator<'a>( &'a self,