mirror of
https://github.com/roc-lang/roc.git
synced 2025-08-03 19:58:18 +00:00
re-enable non-recursive tag union tests
This commit is contained in:
parent
995e14747b
commit
bf4ac1cbf6
6 changed files with 216 additions and 48 deletions
|
@ -319,10 +319,175 @@ generateEnumTagsDebug = \name ->
|
|||
\accum, tagName ->
|
||||
Str.concat accum "\(indent)\(indent)\(indent)Self::\(tagName) => f.write_str(\"\(name)::\(tagName)\"),\n"
|
||||
|
||||
deriveCloneTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
deriveCloneTagUnion = \buf, tagUnionType, tags ->
|
||||
clones =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => union_\(tagUnionType) {
|
||||
\(tagName): self.payload.\(tagName).clone(),
|
||||
},
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Clone for \(tagUnionType) {
|
||||
fn clone(&self) -> Self {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
let payload = unsafe {
|
||||
match self.discriminant {\(clones)
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
discriminant: self.discriminant,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveDebugTagUnion : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName, payload } ->
|
||||
type = when payload is
|
||||
Some id -> typeName types id
|
||||
None -> "()"
|
||||
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => {
|
||||
let field: &\(type) = &self.payload.\(tagName);
|
||||
f.debug_tuple("\(tagUnionType)::\(tagName)").field(field).finish()
|
||||
},
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl core::fmt::Debug for \(tagUnionType) {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveEqTagUnion : Str, Str -> Str
|
||||
deriveEqTagUnion = \buf, tagUnionType ->
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Eq for \(tagUnionType) {}
|
||||
"""
|
||||
|
||||
|
||||
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialEqTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName) == other.payload.\(tagName),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl PartialEq for \(tagUnionType) {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
if self.discriminant != other.discriminant {
|
||||
return false;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveOrdTagUnion : Str, Str -> Str
|
||||
deriveOrdTagUnion = \buf, tagUnionType ->
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl Ord for \(tagUnionType) {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap()
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
derivePartialOrdTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).partial_cmp(&other.payload.\(tagName)),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl PartialOrd for \(tagUnionType) {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
use std::cmp::Ordering::*;
|
||||
|
||||
match self.discriminant.cmp(&other.discriminant) {
|
||||
Less => Option::Some(Less),
|
||||
Greater => Option::Some(Greater),
|
||||
Equal => unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
deriveHashTagUnion = \buf, tagUnionType, tags ->
|
||||
checks =
|
||||
List.walk tags "" \accum, { name: tagName } ->
|
||||
"""
|
||||
\(accum)
|
||||
\(tagName) => self.payload.\(tagName).hash(state),
|
||||
"""
|
||||
|
||||
"""
|
||||
\(buf)
|
||||
|
||||
impl core::hash::Hash for \(tagUnionType) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
use discriminant_\(tagUnionType)::*;
|
||||
|
||||
unsafe {
|
||||
match self.discriminant {\(checks)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
generateConstructorFunctions = \buf, types, tagUnionType, tags ->
|
||||
buf
|
||||
|> Str.concat "\n\nimpl \(tagUnionType) {\n"
|
||||
|> Str.concat "\n\nimpl \(tagUnionType) {"
|
||||
|> \b -> List.walk tags b \accum, r -> generateConstructorFunction accum types tagUnionType r.name r.payload
|
||||
|> Str.concat "\n}\n\n"
|
||||
|
||||
|
@ -369,7 +534,7 @@ generateConstructorFunction = \buf, types, tagUnionType, name, optPayload ->
|
|||
generateDestructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
|
||||
generateDestructorFunctions = \buf, types, tagUnionType, tags ->
|
||||
buf
|
||||
|> Str.concat "\n\nimpl \(tagUnionType) {\n"
|
||||
|> Str.concat "\n\nimpl \(tagUnionType) {"
|
||||
|> \b -> List.walk tags b \accum, r -> generateDestructorFunction accum types tagUnionType r.name r.payload
|
||||
|> Str.concat "\n}\n\n"
|
||||
|
||||
|
@ -418,9 +583,19 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|
|||
tagNames = List.map tags \{ name: n } -> n
|
||||
selfMut = "self"
|
||||
|
||||
max = \a, b -> if a >= b then a else b
|
||||
|
||||
# TODO: this value can be different than the alignment of `id`
|
||||
align =
|
||||
List.walk tags 1 \accum, { payload } ->
|
||||
when payload is
|
||||
Some payloadId -> max accum (Types.alignment types payloadId)
|
||||
None -> accum
|
||||
|> Num.toStr
|
||||
|
||||
buf
|
||||
|> generateDiscriminant types discriminantName tagNames discriminantSize
|
||||
|> Str.concat "#[repr(C)]\npub union \(unionName) {\n"
|
||||
|> Str.concat "#[repr(C, align(\(align)))]\npub union \(unionName) {\n"
|
||||
|> \b -> List.walk tags b (generateUnionField types)
|
||||
|> generateTagUnionSizer types id tags
|
||||
|> Str.concat
|
||||
|
@ -449,7 +624,6 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|
|||
|
||||
|
||||
"""
|
||||
|> Str.concat "// TODO: NonRecursive TagUnion constructor impls\n\n"
|
||||
|> Str.concat
|
||||
"""
|
||||
#[repr(C)]
|
||||
|
@ -458,6 +632,13 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|
|||
discriminant: discriminant_\(escapedName),
|
||||
}
|
||||
"""
|
||||
|> deriveCloneTagUnion escapedName tags
|
||||
|> deriveDebugTagUnion types escapedName tags
|
||||
|> deriveEqTagUnion escapedName
|
||||
|> derivePartialEqTagUnion escapedName tags
|
||||
|> deriveOrdTagUnion escapedName
|
||||
|> derivePartialOrdTagUnion escapedName tags
|
||||
|> deriveHashTagUnion escapedName tags
|
||||
|> generateDestructorFunctions types escapedName tags
|
||||
|> generateConstructorFunctions types escapedName tags
|
||||
|> \b ->
|
||||
|
@ -526,7 +707,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc
|
|||
|> writeTagImpls tags discriminantName indents \name, payload ->
|
||||
when payload is
|
||||
Some id if cannotDeriveCopy types (Types.shape types id) ->
|
||||
"unsafe {{ core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) }},"
|
||||
"unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) },"
|
||||
|
||||
_ ->
|
||||
# If it had no payload, or if the payload had no pointers,
|
||||
|
|
|
@ -13,5 +13,5 @@ platform "test-platform"
|
|||
# that all variants have.
|
||||
NonRecursive : [Foo Str, Bar U128, Blah I32, Baz]
|
||||
|
||||
mainForHost : NonRecursive
|
||||
mainForHost = main
|
||||
mainForHost : {} -> NonRecursive
|
||||
mainForHost = \{} -> main
|
||||
|
|
|
@ -12,13 +12,7 @@ pub extern "C" fn rust_main() -> i32 {
|
|||
use std::cmp::Ordering;
|
||||
use std::collections::hash_set::HashSet;
|
||||
|
||||
let tag_union = unsafe {
|
||||
let mut ret: core::mem::MaybeUninit<NonRecursive> = core::mem::MaybeUninit::uninit();
|
||||
|
||||
roc_main(ret.as_mut_ptr());
|
||||
|
||||
ret.assume_init()
|
||||
};
|
||||
let tag_union = test_glue::mainForHost(());
|
||||
|
||||
// Verify that it has all the expected traits.
|
||||
|
||||
|
@ -29,14 +23,14 @@ pub extern "C" fn rust_main() -> i32 {
|
|||
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
|
||||
|
||||
println!(
|
||||
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
|
||||
tag_union,
|
||||
NonRecursive::Foo("small str".into()),
|
||||
NonRecursive::Foo("A long enough string to not be small".into()),
|
||||
NonRecursive::Bar(123.into()),
|
||||
NonRecursive::Baz,
|
||||
NonRecursive::Blah(456),
|
||||
); // Debug
|
||||
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
|
||||
tag_union,
|
||||
NonRecursive::Foo("small str".into()),
|
||||
NonRecursive::Foo("A long enough string to not be small".into()),
|
||||
NonRecursive::Bar(123),
|
||||
NonRecursive::Baz(),
|
||||
NonRecursive::Blah(456),
|
||||
); // Debug
|
||||
|
||||
let mut set = HashSet::new();
|
||||
|
||||
|
|
|
@ -11,5 +11,5 @@ platform "test-platform"
|
|||
# to store the discriminant. We have to generate glue code accordingly!
|
||||
NonRecursive : [Foo Str, Bar I64, Blah I32, Baz]
|
||||
|
||||
mainForHost : NonRecursive
|
||||
mainForHost = main
|
||||
mainForHost : {} -> NonRecursive
|
||||
mainForHost = \{} -> main
|
||||
|
|
|
@ -10,14 +10,7 @@ pub extern "C" fn rust_main() -> i32 {
|
|||
use std::cmp::Ordering;
|
||||
use std::collections::hash_set::HashSet;
|
||||
|
||||
let tag_union = unsafe {
|
||||
let mut ret: core::mem::MaybeUninit<test_glue::NonRecursive> =
|
||||
core::mem::MaybeUninit::uninit();
|
||||
|
||||
roc_main(ret.as_mut_ptr());
|
||||
|
||||
ret.assume_init()
|
||||
};
|
||||
let tag_union = test_glue::mainForHost(());
|
||||
|
||||
// Verify that it has all the expected traits.
|
||||
|
||||
|
@ -32,7 +25,7 @@ pub extern "C" fn rust_main() -> i32 {
|
|||
tag_union,
|
||||
test_glue::NonRecursive::Foo("small str".into()),
|
||||
test_glue::NonRecursive::Bar(123),
|
||||
test_glue::NonRecursive::Baz,
|
||||
test_glue::NonRecursive::Baz(),
|
||||
test_glue::NonRecursive::Blah(456),
|
||||
); // Debug
|
||||
|
||||
|
|
|
@ -77,21 +77,21 @@ mod glue_cli_run {
|
|||
single_tag_union:"single-tag-union" => indoc!(r#"
|
||||
tag_union was: SingleTagUnion::OneTag
|
||||
"#),
|
||||
// union_with_padding:"union-with-padding" => indoc!(r#"
|
||||
// tag_union was: NonRecursive::Foo("This is a test")
|
||||
// `Foo "small str"` is: NonRecursive::Foo("small str")
|
||||
// `Foo "A long enough string to not be small"` is: NonRecursive::Foo("A long enough string to not be small")
|
||||
// `Bar 123` is: NonRecursive::Bar(123)
|
||||
// `Baz` is: NonRecursive::Baz
|
||||
// `Blah 456` is: NonRecursive::Blah(456)
|
||||
// "#),
|
||||
// union_without_padding:"union-without-padding" => indoc!(r#"
|
||||
// tag_union was: NonRecursive::Foo("This is a test")
|
||||
// `Foo "small str"` is: NonRecursive::Foo("small str")
|
||||
// `Bar 123` is: NonRecursive::Bar(123)
|
||||
// `Baz` is: NonRecursive::Baz
|
||||
// `Blah 456` is: NonRecursive::Blah(456)
|
||||
// "#),
|
||||
union_with_padding:"union-with-padding" => indoc!(r#"
|
||||
tag_union was: NonRecursive::Foo("This is a test")
|
||||
`Foo "small str"` is: NonRecursive::Foo("small str")
|
||||
`Foo "A long enough string to not be small"` is: NonRecursive::Foo("A long enough string to not be small")
|
||||
`Bar 123` is: NonRecursive::Bar(123)
|
||||
`Baz` is: NonRecursive::Baz(())
|
||||
`Blah 456` is: NonRecursive::Blah(456)
|
||||
"#),
|
||||
union_without_padding:"union-without-padding" => indoc!(r#"
|
||||
tag_union was: NonRecursive::Foo("This is a test")
|
||||
`Foo "small str"` is: NonRecursive::Foo("small str")
|
||||
`Bar 123` is: NonRecursive::Bar(123)
|
||||
`Baz` is: NonRecursive::Baz(())
|
||||
`Blah 456` is: NonRecursive::Blah(456)
|
||||
"#),
|
||||
// nullable_wrapped:"nullable-wrapped" => indoc!(r#"
|
||||
// tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
|
||||
// `More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue