Bindgen PartialEq and Eq instances for tag unions

This commit is contained in:
Richard Feldman 2022-05-11 08:56:48 -04:00
parent bf59119bf3
commit 2fb52af432
No known key found for this signature in database
GPG key ID: 7E4127D1E4241798
2 changed files with 76 additions and 2 deletions

View file

@ -30,7 +30,7 @@ pub fn write_types(types: &Types, buf: &mut String) -> fmt::Result {
// Empty tag unions can never come up at runtime,
// and so don't need declared types.
if !tags.is_empty() {
write_tag_union(name, tags, types, buf)?;
write_tag_union(name, id, tags, types, buf)?;
}
}
RocType::RecursiveTagUnion { .. } => {
@ -74,6 +74,7 @@ pub fn write_types(types: &Types, buf: &mut String) -> fmt::Result {
fn write_tag_union(
name: &str,
type_id: TypeId,
tags: &[(String, Option<TypeId>)],
types: &Types,
buf: &mut String,
@ -91,6 +92,7 @@ fn write_tag_union(
name: discriminant_name.clone(),
tags: tag_names.clone().cloned().collect(),
};
let typ = types.get(type_id);
write_enumeration(
&discriminant_name,
@ -271,7 +273,6 @@ fn write_tag_union(
writeln!(
buf,
// Don't use indoc because this must be indented once!
indoc!(
r#"
}}
@ -282,6 +283,60 @@ fn write_tag_union(
)?;
}
// The PartialEq impl for the tag union
{
write!(
buf,
indoc!(
r#"
impl PartialEq for {} {{
fn eq(&self, other: &Self) -> bool {{
if self.tag != other.tag {{
return false;
}}
unsafe {{
match self.tag {{
"#
),
name
)?;
write_impl_tags(
4,
tags.iter(),
&discriminant_name,
buf,
|tag_name, opt_payload_id| {
if opt_payload_id.is_some() {
format!("self.variant.{} == other.variant.{},", tag_name, tag_name)
} else {
// if the tags themselves had been unequal, we already would have
// early-returned with false, so this means the tags were equal
// and there's no payload; return true!
"true,".to_string()
}
},
)?;
writeln!(
buf,
// Don't use indoc because this must be indented once!
indoc!(
r#"
}}
}}
}}
}}
"#
),
)?;
}
if !typ.has_float(types) {
writeln!(buf, "impl Eq for {} {{}}\n", name)?;
}
Ok(())
}

View file

@ -263,6 +263,25 @@ fn tag_union_aliased() {
}
}
impl PartialEq for MyTagUnion {
fn eq(&self, other: &Self) -> bool {
if self.tag != other.tag {
return false;
}
unsafe {
match self.tag {
tag_MyTagUnion::Bar => self.variant.Bar == other.variant.Bar,
tag_MyTagUnion::Baz => true,
tag_MyTagUnion::Blah => self.variant.Blah == other.variant.Blah,
tag_MyTagUnion::Foo => self.variant.Foo == other.variant.Foo,
}
}
}
}
impl Eq for MyTagUnion {}
"#
)
);