diff --git a/bindgen/src/bindgen_rs.rs b/bindgen/src/bindgen_rs.rs index d60b3fe782..441c1d2ff0 100644 --- a/bindgen/src/bindgen_rs.rs +++ b/bindgen/src/bindgen_rs.rs @@ -30,7 +30,7 @@ pub fn write_types(types: &Types, buf: &mut String) -> fmt::Result { // Empty tag unions can never come up at runtime, // and so don't need declared types. if !tags.is_empty() { - write_tag_union(name, tags, types, buf)?; + write_tag_union(name, id, tags, types, buf)?; } } RocType::RecursiveTagUnion { .. } => { @@ -74,6 +74,7 @@ pub fn write_types(types: &Types, buf: &mut String) -> fmt::Result { fn write_tag_union( name: &str, + type_id: TypeId, tags: &[(String, Option)], types: &Types, buf: &mut String, @@ -91,6 +92,7 @@ fn write_tag_union( name: discriminant_name.clone(), tags: tag_names.clone().cloned().collect(), }; + let typ = types.get(type_id); write_enumeration( &discriminant_name, @@ -271,7 +273,6 @@ fn write_tag_union( writeln!( buf, - // Don't use indoc because this must be indented once! indoc!( r#" }} @@ -282,6 +283,60 @@ fn write_tag_union( )?; } + // The PartialEq impl for the tag union + { + write!( + buf, + indoc!( + r#" + impl PartialEq for {} {{ + fn eq(&self, other: &Self) -> bool {{ + if self.tag != other.tag {{ + return false; + }} + + unsafe {{ + match self.tag {{ + "# + ), + name + )?; + + write_impl_tags( + 4, + tags.iter(), + &discriminant_name, + buf, + |tag_name, opt_payload_id| { + if opt_payload_id.is_some() { + format!("self.variant.{} == other.variant.{},", tag_name, tag_name) + } else { + // if the tags themselves had been unequal, we already would have + // early-returned with false, so this means the tags were equal + // and there's no payload; return true! + "true,".to_string() + } + }, + )?; + + writeln!( + buf, + // Don't use indoc because this must be indented once! + indoc!( + r#" + }} + }} + }} + }} + "# + ), + )?; + } + + if !typ.has_float(types) { + writeln!(buf, "impl Eq for {} {{}}\n", name)?; + } + Ok(()) } diff --git a/bindgen/tests/gen_rs.rs b/bindgen/tests/gen_rs.rs index 7f086d586b..d369caa555 100644 --- a/bindgen/tests/gen_rs.rs +++ b/bindgen/tests/gen_rs.rs @@ -263,6 +263,25 @@ fn tag_union_aliased() { } } + impl PartialEq for MyTagUnion { + fn eq(&self, other: &Self) -> bool { + if self.tag != other.tag { + return false; + } + + unsafe { + match self.tag { + tag_MyTagUnion::Bar => self.variant.Bar == other.variant.Bar, + tag_MyTagUnion::Baz => true, + tag_MyTagUnion::Blah => self.variant.Blah == other.variant.Blah, + tag_MyTagUnion::Foo => self.variant.Foo == other.variant.Foo, + } + } + } + } + + impl Eq for MyTagUnion {} + "# ) );