re-enable non-recursive tag union tests

This commit is contained in:
Folkert 2023-04-04 19:22:38 +02:00
parent 995e14747b
commit bf4ac1cbf6
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
6 changed files with 216 additions and 48 deletions

View file

@ -319,10 +319,175 @@ generateEnumTagsDebug = \name ->
\accum, tagName ->
Str.concat accum "\(indent)\(indent)\(indent)Self::\(tagName) => f.write_str(\"\(name)::\(tagName)\"),\n"
deriveCloneTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveCloneTagUnion = \buf, tagUnionType, tags ->
clones =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => union_\(tagUnionType) {
\(tagName): self.payload.\(tagName).clone(),
},
"""
"""
\(buf)
impl Clone for \(tagUnionType) {
fn clone(&self) -> Self {
use discriminant_\(tagUnionType)::*;
let payload = unsafe {
match self.discriminant {\(clones)
}
};
Self {
discriminant: self.discriminant,
payload,
}
}
}
"""
deriveDebugTagUnion : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName, payload } ->
type = when payload is
Some id -> typeName types id
None -> "()"
"""
\(accum)
\(tagName) => {
let field: &\(type) = &self.payload.\(tagName);
f.debug_tuple("\(tagUnionType)::\(tagName)").field(field).finish()
},
"""
"""
\(buf)
impl core::fmt::Debug for \(tagUnionType) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use discriminant_\(tagUnionType)::*;
unsafe {
match self.discriminant {\(checks)
}
}
}
}
"""
deriveEqTagUnion : Str, Str -> Str
deriveEqTagUnion = \buf, tagUnionType ->
"""
\(buf)
impl Eq for \(tagUnionType) {}
"""
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialEqTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
"""
"""
\(buf)
impl PartialEq for \(tagUnionType) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(tagUnionType)::*;
if self.discriminant != other.discriminant {
return false;
}
unsafe {
match self.discriminant {\(checks)
}
}
}
}
"""
deriveOrdTagUnion : Str, Str -> Str
deriveOrdTagUnion = \buf, tagUnionType ->
"""
\(buf)
impl Ord for \(tagUnionType) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap()
}
}
"""
derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialOrdTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
"""
"""
\(buf)
impl PartialOrd for \(tagUnionType) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use discriminant_\(tagUnionType)::*;
use std::cmp::Ordering::*;
match self.discriminant.cmp(&other.discriminant) {
Less => Option::Some(Less),
Greater => Option::Some(Greater),
Equal => unsafe {
match self.discriminant {\(checks)
}
},
}
}
}
"""
deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveHashTagUnion = \buf, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName } ->
"""
\(accum)
\(tagName) => self.payload.\(tagName).hash(state),
"""
"""
\(buf)
impl core::hash::Hash for \(tagUnionType) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(tagUnionType)::*;
unsafe {
match self.discriminant {\(checks)
}
}
}
}
"""
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateConstructorFunctions = \buf, types, tagUnionType, tags ->
buf
|> Str.concat "\n\nimpl \(tagUnionType) {\n"
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateConstructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
@ -369,7 +534,7 @@ generateConstructorFunction = \buf, types, tagUnionType, name, optPayload ->
generateDestructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateDestructorFunctions = \buf, types, tagUnionType, tags ->
buf
|> Str.concat "\n\nimpl \(tagUnionType) {\n"
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateDestructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
@ -418,9 +583,19 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
tagNames = List.map tags \{ name: n } -> n
selfMut = "self"
max = \a, b -> if a >= b then a else b
# TODO: this value can be different than the alignment of `id`
align =
List.walk tags 1 \accum, { payload } ->
when payload is
Some payloadId -> max accum (Types.alignment types payloadId)
None -> accum
|> Num.toStr
buf
|> generateDiscriminant types discriminantName tagNames discriminantSize
|> Str.concat "#[repr(C)]\npub union \(unionName) {\n"
|> Str.concat "#[repr(C, align(\(align)))]\npub union \(unionName) {\n"
|> \b -> List.walk tags b (generateUnionField types)
|> generateTagUnionSizer types id tags
|> Str.concat
@ -449,7 +624,6 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
"""
|> Str.concat "// TODO: NonRecursive TagUnion constructor impls\n\n"
|> Str.concat
"""
#[repr(C)]
@ -458,6 +632,13 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
discriminant: discriminant_\(escapedName),
}
"""
|> deriveCloneTagUnion escapedName tags
|> deriveDebugTagUnion types escapedName tags
|> deriveEqTagUnion escapedName
|> derivePartialEqTagUnion escapedName tags
|> deriveOrdTagUnion escapedName
|> derivePartialOrdTagUnion escapedName tags
|> deriveHashTagUnion escapedName tags
|> generateDestructorFunctions types escapedName tags
|> generateConstructorFunctions types escapedName tags
|> \b ->
@ -526,7 +707,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc
|> writeTagImpls tags discriminantName indents \name, payload ->
when payload is
Some id if cannotDeriveCopy types (Types.shape types id) ->
"unsafe {{ core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) }},"
"unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) },"
_ ->
# If it had no payload, or if the payload had no pointers,

View file

@ -13,5 +13,5 @@ platform "test-platform"
# that all variants have.
NonRecursive : [Foo Str, Bar U128, Blah I32, Baz]
mainForHost : NonRecursive
mainForHost = main
mainForHost : {} -> NonRecursive
mainForHost = \{} -> main

View file

@ -12,13 +12,7 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<NonRecursive> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.
@ -29,14 +23,14 @@ pub extern "C" fn rust_main() -> i32 {
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
println!(
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
tag_union,
NonRecursive::Foo("small str".into()),
NonRecursive::Foo("A long enough string to not be small".into()),
NonRecursive::Bar(123.into()),
NonRecursive::Baz,
NonRecursive::Blah(456),
); // Debug
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
tag_union,
NonRecursive::Foo("small str".into()),
NonRecursive::Foo("A long enough string to not be small".into()),
NonRecursive::Bar(123),
NonRecursive::Baz(),
NonRecursive::Blah(456),
); // Debug
let mut set = HashSet::new();

View file

@ -11,5 +11,5 @@ platform "test-platform"
# to store the discriminant. We have to generate glue code accordingly!
NonRecursive : [Foo Str, Bar I64, Blah I32, Baz]
mainForHost : NonRecursive
mainForHost = main
mainForHost : {} -> NonRecursive
mainForHost = \{} -> main

View file

@ -10,14 +10,7 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<test_glue::NonRecursive> =
core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.
@ -32,7 +25,7 @@ pub extern "C" fn rust_main() -> i32 {
tag_union,
test_glue::NonRecursive::Foo("small str".into()),
test_glue::NonRecursive::Bar(123),
test_glue::NonRecursive::Baz,
test_glue::NonRecursive::Baz(),
test_glue::NonRecursive::Blah(456),
); // Debug

View file

@ -77,21 +77,21 @@ mod glue_cli_run {
single_tag_union:"single-tag-union" => indoc!(r#"
tag_union was: SingleTagUnion::OneTag
"#),
// union_with_padding:"union-with-padding" => indoc!(r#"
// tag_union was: NonRecursive::Foo("This is a test")
// `Foo "small str"` is: NonRecursive::Foo("small str")
// `Foo "A long enough string to not be small"` is: NonRecursive::Foo("A long enough string to not be small")
// `Bar 123` is: NonRecursive::Bar(123)
// `Baz` is: NonRecursive::Baz
// `Blah 456` is: NonRecursive::Blah(456)
// "#),
// union_without_padding:"union-without-padding" => indoc!(r#"
// tag_union was: NonRecursive::Foo("This is a test")
// `Foo "small str"` is: NonRecursive::Foo("small str")
// `Bar 123` is: NonRecursive::Bar(123)
// `Baz` is: NonRecursive::Baz
// `Blah 456` is: NonRecursive::Blah(456)
// "#),
union_with_padding:"union-with-padding" => indoc!(r#"
tag_union was: NonRecursive::Foo("This is a test")
`Foo "small str"` is: NonRecursive::Foo("small str")
`Foo "A long enough string to not be small"` is: NonRecursive::Foo("A long enough string to not be small")
`Bar 123` is: NonRecursive::Bar(123)
`Baz` is: NonRecursive::Baz(())
`Blah 456` is: NonRecursive::Blah(456)
"#),
union_without_padding:"union-without-padding" => indoc!(r#"
tag_union was: NonRecursive::Foo("This is a test")
`Foo "small str"` is: NonRecursive::Foo("small str")
`Bar 123` is: NonRecursive::Bar(123)
`Baz` is: NonRecursive::Baz(())
`Blah 456` is: NonRecursive::Blah(456)
"#),
// nullable_wrapped:"nullable-wrapped" => indoc!(r#"
// tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
// `More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))