make all glue tests run!

This commit is contained in:
Folkert 2023-04-05 23:19:17 +02:00
parent 84d61a0a64
commit 92c2931678
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
9 changed files with 431 additions and 73 deletions

View file

@ -4586,4 +4586,10 @@ mod test {
let target_info = TargetInfo::default_x86_64();
assert_eq!(Layout::VOID_NAKED.stack_size(&interner, target_info), 0);
}
#[test]
fn align_u128_in_tag_union() {
let interner = STLayoutInterner::with_capacity(4, TargetInfo::default_x86_64());
assert_eq!(interner.alignment_bytes(Layout::U128), 16);
}
}

View file

@ -584,6 +584,25 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
max = \a, b -> if a >= b then a else b
alignOfUnion =
List.walk tags 1 \accum, { payload } ->
when payload is
Some payloadId -> max accum (Types.alignment types payloadId)
None -> accum
alignOfUnionStr = Num.toStr alignOfUnion
sizeOfUnionStr =
List.walk tags 1 \accum, { payload } ->
when payload is
Some payloadId -> max accum (Types.size types payloadId)
None -> accum
|> nextMultipleOf alignOfUnion
|> Num.toStr
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
# TODO: this value can be different than the alignment of `id`
align =
List.walk tags 1 \accum, { payload } ->
@ -596,11 +615,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|> generateDiscriminant types discriminantName tagNames discriminantSize
|> Str.concat "#[repr(C, align(\(align)))]\npub union \(unionName) {\n"
|> \b -> List.walk tags b (generateUnionField types)
|> generateTagUnionSizer types id tags
|> Str.concat
"""
}
const _SIZE_CHECK_\(unionName): () = assert!(core::mem::size_of::<\(unionName)>() == \(sizeOfUnionStr));
const _ALIGN_CHECK_\(unionName): () = assert!(core::mem::align_of::<\(unionName)>() == \(alignOfUnionStr));
const _SIZE_CHECK_\(escapedName): () = assert!(core::mem::size_of::<\(escapedName)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(escapedName): () = assert!(core::mem::align_of::<\(escapedName)>() == \(alignOfSelf));
impl \(escapedName) {
\(discriminantDocComment)
pub fn discriminant(&self) -> \(discriminantName) {
@ -720,8 +744,8 @@ generateNonNullableUnwrapped = \buf, types, name, tagName, payload, discriminant
}
"""
generateRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, _discriminantOffset, _nullTagIndex ->
escapedName = escapeKW name
generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSize, _discriminantOffset, nullTagIndex ->
escapedName = escapeKW tagUnionName
discriminantName = "discriminant_\(escapedName)"
tagNames = List.map tags \{ name: n } -> n
# self = "(&*self.union_pointer())"
@ -729,13 +753,348 @@ generateRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, _disc
# other = "(&*other.union_pointer())"
unionName = "union_\(escapedName)"
discriminants =
tagNames
|> Str.joinWith ", "
|> \b -> "[ \(b) ]"
nullTagId =
when nullTagIndex is
Some index ->
n = Num.toStr index
"discriminants[\(n)]"
None ->
"""
unreachable!("this pointer cannot be NULL")
"""
isFunction = \{ name: tagName, payload: optPayload }, index ->
payloadFields =
when optPayload is
Some payload ->
when Types.shape types payload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
None ->
[]
payloadFieldNames =
commaSeparated "" payloadFields \_, i ->
n = Num.toStr i
"f\(n)"
constructorArguments =
commaSeparated "" payloadFields \payloadId, i ->
n = Num.toStr i
type = typeName types payloadId
"f\(n): \(type)"
fixManuallyDrop =
when optPayload is
Some payload ->
shape = Types.shape types payload
if canDeriveCopy types shape then
"payload"
else
"core::mem::ManuallyDrop::new(payload)"
None ->
"payload"
if Some (Num.intCast index) == nullTagIndex then
"""
pub fn is_\(tagName)(&self) -> bool {
matches!(self.discriminant(), discriminant_\(escapedName)::\(tagName))
}
pub fn \(tagName)(\(constructorArguments)) -> Self {
Self(std::ptr::null_mut())
}
"""
else
"""
pub fn is_\(tagName)(&self) -> bool {
matches!(self.discriminant(), discriminant_\(escapedName)::\(tagName))
}
pub fn \(tagName)(\(constructorArguments)) -> Self {
let tag_id = discriminant_\(escapedName)::\(tagName);
let payload = \(escapedName)_\(tagName) { \(payloadFieldNames) } ;
let union_payload = union_\(escapedName) { \(tagName): \(fixManuallyDrop) };
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(union_payload)) };
Self((ptr as usize | tag_id as usize) as *mut _)
}
"""
constructors =
tags
|> List.mapWithIndex isFunction
|> Str.joinWith "\n\n"
cloneCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => Self::\(tagName)(),
"""
else
"""
\(tagName) => {
let tag_id = discriminant_\(escapedName)::\(tagName);
let payload_union = unsafe { self.ptr_read_union() };
let payload = union_\(escapedName) {
\(tagName): unsafe { payload_union.\(tagName).clone() },
};
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(payload)) };
Self((ptr as usize | tag_id as usize) as *mut _)
},
"""
cloneCases =
tags
|> List.mapWithIndex cloneCase
|> Str.joinWith "\n"
partialEqCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => true,
"""
else
"""
\(tagName) => {
let payload_union1 = unsafe { self.ptr_read_union() };
let payload_union2 = unsafe { other.ptr_read_union() };
unsafe {
payload_union1.\(tagName) == payload_union2.\(tagName)
}
},
"""
partialEqCases =
tags
|> List.mapWithIndex partialEqCase
|> Str.joinWith "\n"
debugCase = \{ name: tagName, payload: optPayload }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => f.debug_tuple("\(escapedName)::\(tagName)").finish(),
"""
else
payloadFields =
when optPayload is
Some payload ->
when Types.shape types payload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
None ->
[]
debugFields =
payloadFields
|> List.mapWithIndex \_, i ->
n = Num.toStr i
".field(&payload_union.\(tagName).f\(n))"
|> Str.joinWith ""
"""
\(tagName) => {
let payload_union = unsafe { self.ptr_read_union() };
unsafe {
f.debug_tuple("\(escapedName)::\(tagName)")\(debugFields).finish()
}
},
"""
debugCases =
tags
|> List.mapWithIndex debugCase
|> Str.joinWith "\n"
hashCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => {}
"""
else
"""
\(tagName) => {
let payload_union = unsafe { self.ptr_read_union() };
unsafe { payload_union.\(tagName).hash(state) };
},
"""
hashCases =
tags
|> List.mapWithIndex hashCase
|> Str.joinWith "\n"
partialOrdCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => std::cmp::Ordering::Equal,
"""
else
"""
\(tagName) => {
let payload_union1 = unsafe { self.ptr_read_union() };
let payload_union2 = unsafe { other.ptr_read_union() };
unsafe {
payload_union1.\(tagName).cmp(&payload_union2.\(tagName))
}
},
"""
partialOrdCases =
tags
|> List.mapWithIndex partialOrdCase
|> Str.joinWith "\n"
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
buf
|> generateDiscriminant types discriminantName tagNames discriminantSize
|> Str.concat
"""
#[repr(transparent)]
pub struct \(escapedName) {
pointer: roc_std::RocBox<\(unionName)>,
pub struct \(escapedName)(*mut \(unionName));
const _SIZE_CHECK_\(escapedName): () = assert!(core::mem::size_of::<\(escapedName)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(escapedName): () = assert!(core::mem::align_of::<\(escapedName)>() == \(alignOfSelf));
impl \(escapedName) {
fn discriminant(&self) -> discriminant_\(escapedName) {
let discriminants = {
use \(discriminantName)::*;
\(discriminants)
};
if self.0.is_null() {
\(nullTagId)
} else {
match std::mem::size_of::<usize>() {
4 => discriminants[self.0 as usize & 0b011],
8 => discriminants[self.0 as usize & 0b111],
_ => unreachable!(),
}
}
}
unsafe fn ptr_read_union(&self) -> core::mem::ManuallyDrop<union_\(escapedName)> {
debug_assert!(!self.0.is_null());
let mask = match std::mem::size_of::<usize>() {
4 => !0b011usize,
8 => !0b111usize,
_ => unreachable!(),
};
let ptr = ((self.0 as usize) & mask) as *mut union_\(escapedName);
core::mem::ManuallyDrop::new(unsafe { std::ptr::read(ptr) })
}
\(constructors)
}
impl Clone for \(escapedName) {
fn clone(&self) -> Self {
use discriminant_\(escapedName)::*;
let discriminant = self.discriminant();
match discriminant {
\(cloneCases)
}
}
}
impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(escapedName)::*;
if self.discriminant() != other.discriminant() {
return false;
}
match self.discriminant() {
\(partialEqCases)
}
}
}
impl Eq for \(escapedName) {}
impl core::fmt::Debug for \(escapedName) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use discriminant_\(escapedName)::*;
match self.discriminant() {
\(debugCases)
}
}
}
impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(escapedName)::*;
self.discriminant().hash(state);
match self.discriminant() {
\(hashCases)
}
}
}
impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(<Self as Ord>::cmp(self, other))
}
}
impl Ord for \(escapedName) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use discriminant_\(escapedName)::*;
use std::cmp::Ordering::*;
match self.discriminant().cmp(&other.discriminant()) {
Less => Less,
Greater => Greater,
Equal => unsafe {
match self.discriminant() {
\(partialOrdCases)
}
},
}
}
}
#[repr(C)]
@ -744,7 +1103,6 @@ generateRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, _disc
|> \b -> List.walk tags b (generateUnionField types)
|> generateTagUnionSizer types id tags
|> Str.concat "}\n\n"
|> Str.concat "// TODO: Recursive TagUnion impls\n\n"
generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, discriminantSize, indents ->
if discriminantSize == 0 then
@ -861,7 +1219,7 @@ commaSeparated = \buf, items, step ->
|> .buf
generateNullableUnwrapped : Str, Types, TypeId, Str, Str, Str, TypeId, [FirstTagIsNull, SecondTagIsNull] -> Str
generateNullableUnwrapped = \buf, types, _id, name, nullTag, nonNullTag, nonNullPayload, whichTagIsNull ->
generateNullableUnwrapped = \buf, types, tagUnionid, name, nullTag, nonNullTag, nonNullPayload, whichTagIsNull ->
payloadFields =
when Types.shape types nonNullPayload is
TagUnionPayload { fields } ->
@ -910,6 +1268,9 @@ generateNullableUnwrapped = \buf, types, _id, name, nullTag, nonNullTag, nonNull
}
"""
sizeOfSelf = Num.toStr (Types.size types tagUnionid)
alignOfSelf = Num.toStr (Types.alignment types tagUnionid)
"""
\(buf)
@ -919,6 +1280,9 @@ generateNullableUnwrapped = \buf, types, _id, name, nullTag, nonNullTag, nonNull
\(discriminant)
const _SIZE_CHECK_\(name): () = assert!(core::mem::size_of::<\(name)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(name): () = assert!(core::mem::align_of::<\(name)>() == \(alignOfSelf));
impl \(name) {
pub fn \(nullTag)() -> Self {
Self(core::ptr::null_mut())
@ -1620,3 +1984,8 @@ escapeKW = \input ->
"r#\(input)"
else
input
nextMultipleOf = \lhs, rhs ->
when lhs % rhs is
0 -> lhs
r -> lhs + (rhs - r)

View file

@ -72,16 +72,20 @@ impl Types {
pub fn with_capacity(cap: usize, target_info: TargetInfo) -> Self {
let mut types = Vec::with_capacity(cap);
let mut sizes = Vec::with_capacity(cap);
let mut aligns = Vec::with_capacity(cap);
types.push(RocType::Unit);
sizes.push(1);
aligns.push(1);
Self {
target: target_info,
types,
sizes,
aligns,
types_by_name: FnvHashMap::with_capacity_and_hasher(10, Default::default()),
entry_points: Vec::new(),
sizes: Vec::new(),
aligns: Vec::new(),
deps: VecMap::with_capacity(cap),
}
}
@ -542,14 +546,19 @@ impl Types {
}
}
debug_assert_eq!(self.types.len(), self.sizes.len());
debug_assert_eq!(self.types.len(), self.aligns.len());
let id = TypeId(self.types.len());
assert!(id.0 <= TypeId::MAX.0);
let size = interner.stack_size(layout);
let align = interner.alignment_bytes(layout);
self.types.push(typ);
self.sizes
.push(interner.stack_size_without_alignment(layout));
self.aligns.push(interner.alignment_bytes(layout));
self.sizes.push(size);
self.aligns.push(align);
id
}
@ -660,11 +669,7 @@ impl From<&Types> for roc_type::Types {
deps,
entrypoints,
sizes: types.sizes.as_slice().into(),
types: types
.types
.iter()
.map(|t| roc_type::RocType::from(t))
.collect(),
types: types.types.iter().map(roc_type::RocType::from).collect(),
typesByName: types_by_name,
target: types.target.into(),
}

View file

@ -22,5 +22,5 @@ Job : [
Rbt : { default : Job }
mainForHost : Rbt
mainForHost = main
mainForHost : {} -> Rbt
mainForHost = \{} -> main

View file

@ -1,29 +1,19 @@
mod test_glue;
use indoc::indoc;
use test_glue::Rbt;
extern "C" {
#[link_name = "roc__mainForHost_1_exposed_generic"]
fn roc_main(_: *mut Rbt);
}
// use test_glue::Rbt;
#[no_mangle]
pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<Rbt> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
@ -57,7 +47,7 @@ use std::os::raw::c_char;
#[no_mangle]
pub unsafe extern "C" fn roc_alloc(size: usize, _alignment: u32) -> *mut c_void {
return libc::malloc(size);
libc::malloc(size)
}
#[no_mangle]

View file

@ -7,5 +7,5 @@ platform "test-platform"
Expr : [String Str, Concat Expr Expr]
mainForHost : Expr
mainForHost = main
mainForHost : {} -> Expr
mainForHost = \{} -> main

View file

@ -13,13 +13,7 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<Expr> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.

View file

@ -14,26 +14,20 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<StrFingerTree> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Eq
assert!(StrFingerTree::Empty == StrFingerTree::Empty);
assert!(StrFingerTree::Empty != tag_union);
assert!(StrFingerTree::Empty() == StrFingerTree::Empty());
assert!(StrFingerTree::Empty() != tag_union);
assert!(
StrFingerTree::Single(RocStr::from("foo")) == StrFingerTree::Single(RocStr::from("foo"))
);
assert!(StrFingerTree::Single(RocStr::from("foo")) != StrFingerTree::Empty);
assert!(StrFingerTree::Single(RocStr::from("foo")) != StrFingerTree::Empty());
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(StrFingerTree::Empty.clone() == StrFingerTree::Empty); // Clone
assert!(StrFingerTree::Empty().clone() == StrFingerTree::Empty()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
@ -53,9 +47,9 @@ pub extern "C" fn rust_main() -> i32 {
"small str".into(),
StrFingerTree::Single("other str".into()),
),
StrFingerTree::More("small str".into(), StrFingerTree::Empty),
StrFingerTree::More("small str".into(), StrFingerTree::Empty()),
StrFingerTree::Single("small str".into()),
StrFingerTree::Empty,
StrFingerTree::Empty(),
); // Debug
let mut set = HashSet::new();

View file

@ -92,13 +92,13 @@ mod glue_cli_run {
`Baz` is: NonRecursive::Baz(())
`Blah 456` is: NonRecursive::Blah(456)
"#),
// nullable_wrapped:"nullable-wrapped" => indoc!(r#"
// tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
// `More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))
// `More "small str" Empty` is: StrFingerTree::More("small str", StrFingerTree::Empty)
// `Single "small str"` is: StrFingerTree::Single("small str")
// `Empty` is: StrFingerTree::Empty
// "#),
nullable_wrapped:"nullable-wrapped" => indoc!(r#"
tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
`More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))
`More "small str" Empty` is: StrFingerTree::More("small str", StrFingerTree::Empty)
`Single "small str"` is: StrFingerTree::Single("small str")
`Empty` is: StrFingerTree::Empty
"#),
nullable_unwrapped:"nullable-unwrapped" => indoc!(r#"
tag_union was: StrConsList::Cons("World!", StrConsList::Cons("Hello ", StrConsList::Nil))
`Cons "small str" Nil` is: StrConsList::Cons("small str", StrConsList::Nil)
@ -108,17 +108,17 @@ mod glue_cli_run {
tag_union was: StrRoseTree::Tree("root", [StrRoseTree::Tree("leaf1", []), StrRoseTree::Tree("leaf2", [])])
Tree "foo" [] is: StrRoseTree::Tree("foo", [])
"#),
// basic_recursive_union:"basic-recursive-union" => indoc!(r#"
// tag_union was: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
// `Concat (String "Hello, ") (String "World!")` is: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
// `String "this is a test"` is: Expr::String("this is a test")
// "#),
// advanced_recursive_union:"advanced-recursive-union" => indoc!(r#"
// rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { tool: Tool::SystemTool(R4 { name: "test", num: 42 }) }), inputFiles: ["foo"] }) }
// "#),
// list_recursive_union:"list-recursive-union" => indoc!(r#"
// rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { args: [], tool: Tool::SystemTool(R3 { name: "test" }) }), inputFiles: ["foo"], job: [] }) }
// "#),
basic_recursive_union:"basic-recursive-union" => indoc!(r#"
tag_union was: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`Concat (String "Hello, ") (String "World!")` is: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`String "this is a test"` is: Expr::String("this is a test")
"#),
advanced_recursive_union:"advanced-recursive-union" => indoc!(r#"
rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { tool: Tool::SystemTool(R4 { name: "test", num: 42 }) }), inputFiles: ["foo"] }) }
"#),
list_recursive_union:"list-recursive-union" => indoc!(r#"
rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { args: [], tool: Tool::SystemTool(R3 { name: "test" }) }), inputFiles: ["foo"], job: [] }) }
"#),
multiple_modules:"multiple-modules" => indoc!(r#"
combined was: Combined { s1: DepStr1::S("hello"), s2: DepStr2::R("world") }
"#),