Merge pull request #5257 from roc-lang/glue-recursive-tag-unions

Glue recursive tag unions
This commit is contained in:
Folkert de Vries 2023-04-06 18:53:41 +02:00 committed by GitHub
commit 3b28e897c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 724 additions and 123 deletions

View file

@ -4586,4 +4586,10 @@ mod test {
let target_info = TargetInfo::default_x86_64();
assert_eq!(Layout::VOID_NAKED.stack_size(&interner, target_info), 0);
}
#[test]
fn align_u128_in_tag_union() {
let interner = STLayoutInterner::with_capacity(4, TargetInfo::default_x86_64());
assert_eq!(interner.alignment_bytes(Layout::U128), 16);
}
}

View file

@ -61,7 +61,7 @@ convertTypesToFile = \types ->
generateSingleTagStruct buf types name tagName payload
TagUnion (NonNullableUnwrapped { name, tagName, payload }) ->
generateRecursiveTagUnion buf types id name [{ name: tagName, payload: Some payload }] 0 0 None
generateNonNullableUnwrapped buf types name tagName payload 0 0 None
Function rocFn ->
if rocFn.isToplevel then
@ -354,9 +354,10 @@ deriveDebugTagUnion : Str, Types, Str, List { name : Str, payload : [Some TypeId
deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
checks =
List.walk tags "" \accum, { name: tagName, payload } ->
type = when payload is
Some id -> typeName types id
None -> "()"
type =
when payload is
Some id -> typeName types id
None -> "()"
"""
\(accum)
@ -389,7 +390,6 @@ deriveEqTagUnion = \buf, tagUnionType ->
impl Eq for \(tagUnionType) {}
"""
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialEqTagUnion = \buf, tagUnionType, tags ->
checks =
@ -487,9 +487,9 @@ deriveHashTagUnion = \buf, tagUnionType, tags ->
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateConstructorFunctions = \buf, types, tagUnionType, tags ->
buf
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateConstructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateConstructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
generateConstructorFunction : Str, Types, Str, Str, [Some TypeId, None] -> Str
generateConstructorFunction = \buf, types, tagUnionType, name, optPayload ->
@ -534,9 +534,9 @@ generateConstructorFunction = \buf, types, tagUnionType, name, optPayload ->
generateDestructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateDestructorFunctions = \buf, types, tagUnionType, tags ->
buf
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateDestructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
|> Str.concat "\n\nimpl \(tagUnionType) {"
|> \b -> List.walk tags b \accum, r -> generateDestructorFunction accum types tagUnionType r.name r.payload
|> Str.concat "\n}\n\n"
generateDestructorFunction : Str, Types, Str, Str, [Some TypeId, None] -> Str
generateDestructorFunction = \buf, types, tagUnionType, name, optPayload ->
@ -557,7 +557,6 @@ generateDestructorFunction = \buf, types, tagUnionType, name, optPayload ->
take =
if canDeriveCopy types shape then
"unsafe { self.payload.\(name) }"
else
"unsafe { core::mem::ManuallyDrop::take(&mut self.payload.\(name)) }"
@ -585,6 +584,25 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
max = \a, b -> if a >= b then a else b
alignOfUnion =
List.walk tags 1 \accum, { payload } ->
when payload is
Some payloadId -> max accum (Types.alignment types payloadId)
None -> accum
alignOfUnionStr = Num.toStr alignOfUnion
sizeOfUnionStr =
List.walk tags 1 \accum, { payload } ->
when payload is
Some payloadId -> max accum (Types.size types payloadId)
None -> accum
|> nextMultipleOf alignOfUnion
|> Num.toStr
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
# TODO: this value can be different than the alignment of `id`
align =
List.walk tags 1 \accum, { payload } ->
@ -597,11 +615,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
|> generateDiscriminant types discriminantName tagNames discriminantSize
|> Str.concat "#[repr(C, align(\(align)))]\npub union \(unionName) {\n"
|> \b -> List.walk tags b (generateUnionField types)
|> generateTagUnionSizer types id tags
|> Str.concat
"""
}
const _SIZE_CHECK_\(unionName): () = assert!(core::mem::size_of::<\(unionName)>() == \(sizeOfUnionStr));
const _ALIGN_CHECK_\(unionName): () = assert!(core::mem::align_of::<\(unionName)>() == \(alignOfUnionStr));
const _SIZE_CHECK_\(escapedName): () = assert!(core::mem::size_of::<\(escapedName)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(escapedName): () = assert!(core::mem::align_of::<\(escapedName)>() == \(alignOfSelf));
impl \(escapedName) {
\(discriminantDocComment)
pub fn discriminant(&self) -> \(discriminantName) {
@ -664,22 +687,414 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
else
b
generateRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, _discriminantOffset, _nullTagIndex ->
generateNonNullableUnwrapped = \buf, types, name, tagName, payload, discriminantSize, _discriminantOffset, _nullTagIndex ->
escapedName = escapeKW name
discriminantName = "discriminant_\(escapedName)"
payloadFields =
when Types.shape types payload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
payloadFieldNames =
commaSeparated "" payloadFields \_, i ->
n = Num.toStr i
"f\(n)"
constructorArguments =
commaSeparated "" payloadFields \id, i ->
n = Num.toStr i
type = typeName types id
"f\(n): \(type)"
debugFields =
payloadFields
|> List.mapWithIndex \_, i ->
n = Num.toStr i
".field(&node.f\(n))"
|> Str.joinWith ""
buf1 = buf |> generateDiscriminant types discriminantName [tagName] discriminantSize
"""
\(buf1)
#[repr(transparent)]
#[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct \(escapedName)(roc_std::RocBox<\(name)_\(tagName)>);
impl \(escapedName) {
pub fn \(tagName)(\(constructorArguments)) -> Self {
let payload = \(name)_\(tagName) { \(payloadFieldNames) };
Self(roc_std::RocBox::new(payload))
}
}
impl core::fmt::Debug for \(escapedName) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let node = &self.0;
f.debug_tuple("\(escapedName)::\(tagName)")\(debugFields).finish()
}
}
"""
generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSize, _discriminantOffset, nullTagIndex ->
escapedName = escapeKW tagUnionName
discriminantName = "discriminant_\(escapedName)"
tagNames = List.map tags \{ name: n } -> n
# self = "(&*self.union_pointer())"
# selfMut = "(&mut *self.union_pointer())"
# other = "(&*other.union_pointer())"
unionName = "union_\(escapedName)"
discriminants =
tagNames
|> Str.joinWith ", "
|> \b -> "[ \(b) ]"
nullTagId =
when nullTagIndex is
Some index ->
n = Num.toStr index
"discriminants[\(n)]"
None ->
"""
unreachable!("this pointer cannot be NULL")
"""
isFunction = \{ name: tagName, payload: optPayload }, index ->
payloadFields =
when optPayload is
Some payload ->
when Types.shape types payload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
None ->
[]
payloadFieldNames =
commaSeparated "" payloadFields \_, i ->
n = Num.toStr i
"f\(n)"
constructorArguments =
commaSeparated "" payloadFields \payloadId, i ->
n = Num.toStr i
type = typeName types payloadId
"f\(n): \(type)"
fixManuallyDrop =
when optPayload is
Some payload ->
shape = Types.shape types payload
if canDeriveCopy types shape then
"payload"
else
"core::mem::ManuallyDrop::new(payload)"
None ->
"payload"
if Some (Num.intCast index) == nullTagIndex then
"""
pub fn is_\(tagName)(&self) -> bool {
matches!(self.discriminant(), discriminant_\(escapedName)::\(tagName))
}
pub fn \(tagName)(\(constructorArguments)) -> Self {
Self(std::ptr::null_mut())
}
"""
else
"""
pub fn is_\(tagName)(&self) -> bool {
matches!(self.discriminant(), discriminant_\(escapedName)::\(tagName))
}
pub fn \(tagName)(\(constructorArguments)) -> Self {
let tag_id = discriminant_\(escapedName)::\(tagName);
let payload = \(escapedName)_\(tagName) { \(payloadFieldNames) } ;
let union_payload = union_\(escapedName) { \(tagName): \(fixManuallyDrop) };
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(union_payload)) };
Self((ptr as usize | tag_id as usize) as *mut _)
}
"""
constructors =
tags
|> List.mapWithIndex isFunction
|> Str.joinWith "\n\n"
cloneCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => Self::\(tagName)(),
"""
else
"""
\(tagName) => {
let tag_id = discriminant_\(escapedName)::\(tagName);
let payload_union = unsafe { self.ptr_read_union() };
let payload = union_\(escapedName) {
\(tagName): unsafe { payload_union.\(tagName).clone() },
};
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(payload)) };
Self((ptr as usize | tag_id as usize) as *mut _)
},
"""
cloneCases =
tags
|> List.mapWithIndex cloneCase
|> Str.joinWith "\n"
partialEqCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => true,
"""
else
"""
\(tagName) => {
let payload_union1 = unsafe { self.ptr_read_union() };
let payload_union2 = unsafe { other.ptr_read_union() };
unsafe {
payload_union1.\(tagName) == payload_union2.\(tagName)
}
},
"""
partialEqCases =
tags
|> List.mapWithIndex partialEqCase
|> Str.joinWith "\n"
debugCase = \{ name: tagName, payload: optPayload }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => f.debug_tuple("\(escapedName)::\(tagName)").finish(),
"""
else
payloadFields =
when optPayload is
Some payload ->
when Types.shape types payload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
None ->
[]
debugFields =
payloadFields
|> List.mapWithIndex \_, i ->
n = Num.toStr i
".field(&payload_union.\(tagName).f\(n))"
|> Str.joinWith ""
"""
\(tagName) => {
let payload_union = unsafe { self.ptr_read_union() };
unsafe {
f.debug_tuple("\(escapedName)::\(tagName)")\(debugFields).finish()
}
},
"""
debugCases =
tags
|> List.mapWithIndex debugCase
|> Str.joinWith "\n"
hashCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => {}
"""
else
"""
\(tagName) => {
let payload_union = unsafe { self.ptr_read_union() };
unsafe { payload_union.\(tagName).hash(state) };
},
"""
hashCases =
tags
|> List.mapWithIndex hashCase
|> Str.joinWith "\n"
partialOrdCase = \{ name: tagName }, index ->
if Some (Num.intCast index) == nullTagIndex then
"""
\(tagName) => std::cmp::Ordering::Equal,
"""
else
"""
\(tagName) => {
let payload_union1 = unsafe { self.ptr_read_union() };
let payload_union2 = unsafe { other.ptr_read_union() };
unsafe {
payload_union1.\(tagName).cmp(&payload_union2.\(tagName))
}
},
"""
partialOrdCases =
tags
|> List.mapWithIndex partialOrdCase
|> Str.joinWith "\n"
sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id)
buf
|> generateDiscriminant types discriminantName tagNames discriminantSize
|> Str.concat
"""
#[repr(transparent)]
pub struct \(escapedName) {
pointer: *mut \(unionName),
pub struct \(escapedName)(*mut \(unionName));
const _SIZE_CHECK_\(escapedName): () = assert!(core::mem::size_of::<\(escapedName)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(escapedName): () = assert!(core::mem::align_of::<\(escapedName)>() == \(alignOfSelf));
impl \(escapedName) {
fn discriminant(&self) -> discriminant_\(escapedName) {
let discriminants = {
use \(discriminantName)::*;
\(discriminants)
};
if self.0.is_null() {
\(nullTagId)
} else {
match std::mem::size_of::<usize>() {
4 => discriminants[self.0 as usize & 0b011],
8 => discriminants[self.0 as usize & 0b111],
_ => unreachable!(),
}
}
}
unsafe fn ptr_read_union(&self) -> core::mem::ManuallyDrop<union_\(escapedName)> {
debug_assert!(!self.0.is_null());
let mask = match std::mem::size_of::<usize>() {
4 => !0b011usize,
8 => !0b111usize,
_ => unreachable!(),
};
let ptr = ((self.0 as usize) & mask) as *mut union_\(escapedName);
core::mem::ManuallyDrop::new(unsafe { std::ptr::read(ptr) })
}
\(constructors)
}
impl Clone for \(escapedName) {
fn clone(&self) -> Self {
use discriminant_\(escapedName)::*;
let discriminant = self.discriminant();
match discriminant {
\(cloneCases)
}
}
}
impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool {
use discriminant_\(escapedName)::*;
if self.discriminant() != other.discriminant() {
return false;
}
match self.discriminant() {
\(partialEqCases)
}
}
}
impl Eq for \(escapedName) {}
impl core::fmt::Debug for \(escapedName) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use discriminant_\(escapedName)::*;
match self.discriminant() {
\(debugCases)
}
}
}
impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
use discriminant_\(escapedName)::*;
self.discriminant().hash(state);
match self.discriminant() {
\(hashCases)
}
}
}
impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(<Self as Ord>::cmp(self, other))
}
}
impl Ord for \(escapedName) {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use discriminant_\(escapedName)::*;
use std::cmp::Ordering::*;
match self.discriminant().cmp(&other.discriminant()) {
Less => Less,
Greater => Greater,
Equal => unsafe {
match self.discriminant() {
\(partialOrdCases)
}
},
}
}
}
#[repr(C)]
@ -688,7 +1103,6 @@ generateRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, _disc
|> \b -> List.walk tags b (generateUnionField types)
|> generateTagUnionSizer types id tags
|> Str.concat "}\n\n"
|> Str.concat "// TODO: Recursive TagUnion impls\n\n"
generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, discriminantSize, indents ->
if discriminantSize == 0 then
@ -794,8 +1208,169 @@ generateUnionField = \types ->
# use unit as the payload
Str.concat accum "\(indent)\(escapedFieldName): (),\n"
generateNullableUnwrapped = \buf, _types, _id, _name, _nullTag, _nonNullTag, _nonNullPayload, _whichTagIsNull ->
Str.concat buf "// TODO: TagUnion NullableUnwrapped\n\n"
commaSeparated : Str, List a, (a, Nat -> Str) -> Str
commaSeparated = \buf, items, step ->
length = List.len items
List.walk items { buf, count: 0 } \accum, item ->
if accum.count + 1 == length then
{ buf: Str.concat accum.buf (step item accum.count), count: length }
else
{ buf: Str.concat accum.buf (step item accum.count) |> Str.concat ", ", count: accum.count + 1 }
|> .buf
generateNullableUnwrapped : Str, Types, TypeId, Str, Str, Str, TypeId, [FirstTagIsNull, SecondTagIsNull] -> Str
generateNullableUnwrapped = \buf, types, tagUnionid, name, nullTag, nonNullTag, nonNullPayload, whichTagIsNull ->
payloadFields =
when Types.shape types nonNullPayload is
TagUnionPayload { fields } ->
when fields is
HasNoClosure xs -> List.map xs .id
HasClosure xs -> List.map xs .id
_ ->
[]
payloadFieldNames =
commaSeparated "" payloadFields \_, i ->
n = Num.toStr i
"f\(n)"
constructorArguments =
commaSeparated "" payloadFields \id, i ->
n = Num.toStr i
type = typeName types id
"f\(n): \(type)"
debugFields =
payloadFields
|> List.mapWithIndex \_, i ->
n = Num.toStr i
".field(&node.f\(n))"
|> Str.joinWith ""
discriminant =
when whichTagIsNull is
FirstTagIsNull ->
"""
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum discriminant_\(name) {
\(nullTag) = 0,
\(nonNullTag) = 1,
}
"""
SecondTagIsNull ->
"""
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum discriminant_\(name) {
\(nonNullTag) = 0,
\(nullTag) = 1,
}
"""
sizeOfSelf = Num.toStr (Types.size types tagUnionid)
alignOfSelf = Num.toStr (Types.alignment types tagUnionid)
"""
\(buf)
#[derive(PartialOrd, Ord)]
#[repr(C)]
pub struct \(name)(*mut \(name)_\(nonNullTag));
\(discriminant)
const _SIZE_CHECK_\(name): () = assert!(core::mem::size_of::<\(name)>() == \(sizeOfSelf));
const _ALIGN_CHECK_\(name): () = assert!(core::mem::align_of::<\(name)>() == \(alignOfSelf));
impl \(name) {
pub fn \(nullTag)() -> Self {
Self(core::ptr::null_mut())
}
pub fn \(nonNullTag)(\(constructorArguments)) -> Self {
let payload = \(name)_\(nonNullTag) { \(payloadFieldNames) };
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(payload)) };
Self(ptr)
}
pub fn discriminant(&self) -> discriminant_\(name) {
if self.is_\(nullTag)() {
discriminant_\(name)::\(nullTag)
} else {
discriminant_\(name)::\(nonNullTag)
}
}
pub fn is_\(nullTag)(&self) -> bool {
self.0.is_null()
}
pub fn is_\(nonNullTag)(&self) -> bool {
!self.0.is_null()
}
}
impl core::fmt::Debug for \(name) {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if self.is_\(nullTag)() {
f.debug_tuple("\(name)::\(nullTag)").finish()
} else {
let node = core::mem::ManuallyDrop::new(unsafe { std::ptr::read(self.0) });
f.debug_tuple("\(name)::\(nonNullTag)")\(debugFields).finish()
}
}
}
impl Clone for \(name) {
fn clone(&self) -> Self {
if self.is_\(nullTag)() {
Self::\(nullTag)()
} else {
use std::ops::Deref;
let node_ref = core::mem::ManuallyDrop::new(unsafe { std::ptr::read(self.0) });
let payload : \(name)_\(nonNullTag) = (node_ref.deref()).clone();
let ptr = unsafe { roc_std::RocBox::leak(roc_std::RocBox::new(payload)) };
Self(ptr)
}
}
}
impl PartialEq for \(name) {
fn eq(&self, other: &Self) -> bool {
if self.discriminant() != other.discriminant() {
return false;
}
if self.is_\(nullTag)() {
return true;
}
let payload1 = core::mem::ManuallyDrop::new(unsafe { std::ptr::read(self.0) });
let payload2 = core::mem::ManuallyDrop::new(unsafe { std::ptr::read(other.0) });
payload1 == payload2
}
}
impl Eq for \(name) {}
impl core::hash::Hash for \(name) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.discriminant().hash(state);
if self.is_\(nonNullTag)() {
let payload = core::mem::ManuallyDrop::new(unsafe { std::ptr::read(self.0) });
payload.hash(state);
}
}
}
"""
generateSingleTagStruct = \buf, types, name, tagName, payload ->
# Store single-tag unions as structs rather than enums,
@ -1013,9 +1588,10 @@ generateDeriveStr = \buf, types, type, includeDebug ->
else
b
deriveDebug = when includeDebug is
IncludeDebug -> Bool.true
ExcludeDebug -> Bool.false
deriveDebug =
when includeDebug is
IncludeDebug -> Bool.true
ExcludeDebug -> Bool.false
buf
|> Str.concat "#[derive(Clone, "
@ -1034,8 +1610,7 @@ canDerivePartialEq = \types, type ->
canDerivePartialEq types runtimeRepresentation
Unsized -> Bool.false
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
RocStr -> Bool.true
RocList inner | RocSet inner | RocBox inner ->
innerType = Types.shape types inner
@ -1047,7 +1622,25 @@ canDerivePartialEq = \types, type ->
canDerivePartialEq types kType && canDerivePartialEq types vType
TagUnion (NullableUnwrapped _) | TagUnion (NullableWrapped _) | TagUnion (Recursive _) | TagUnion (NonNullableUnwrapped _) | RecursivePointer _ -> crash "TODO"
TagUnion (Recursive { tags }) ->
List.all tags \{ payload } ->
when payload is
None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id)
TagUnion (NullableWrapped { tags }) ->
List.all tags \{ payload } ->
when payload is
None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id)
TagUnion (NonNullableUnwrapped { payload }) ->
canDerivePartialEq types (Types.shape types payload)
TagUnion (NullableUnwrapped { nonNullPayload }) ->
canDerivePartialEq types (Types.shape types nonNullPayload)
RecursivePointer _ -> Bool.true
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
@ -1061,8 +1654,10 @@ canDerivePartialEq = \types, type ->
None -> Bool.true
RocResult okId errId ->
canDerivePartialEq types (Types.shape types okId)
&& canDerivePartialEq types (Types.shape types errId)
okShape = Types.shape types okId
errShape = Types.shape types errId
canDerivePartialEq types okShape && canDerivePartialEq types errShape
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id)
@ -1083,8 +1678,7 @@ canDeriveCopy = \types, type ->
# unsized values are heap-allocated
Unsized -> Bool.false
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
RocStr | RocList _ | RocDict _ _ | RocSet _ | RocBox _ | TagUnion (NullableUnwrapped _) | TagUnion (NullableWrapped _) | TagUnion (Recursive _) | TagUnion (NonNullableUnwrapped _) | RecursivePointer _ -> Bool.false
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
List.all fields \{ id } -> canDeriveCopy types (Types.shape types id)
@ -1111,7 +1705,7 @@ canDeriveCopy = \types, type ->
cannotDeriveDefault = \types, type ->
when type is
Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true
RocStr | Bool | Num _ | TagUnionPayload { fields: HasClosure _ } -> Bool.false
RocStr | Bool | Num _ | TagUnionPayload { fields: HasClosure _ } -> Bool.false
RocList id | RocSet id | RocBox id ->
cannotDeriveDefault types (Types.shape types id)
@ -1120,7 +1714,6 @@ cannotDeriveDefault = \types, type ->
|| cannotDeriveCopy types (Types.shape types valId)
Struct { fields: HasClosure _ } -> Bool.true
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.any fields \{ id } -> cannotDeriveDefault types (Types.shape types id)
@ -1391,3 +1984,8 @@ escapeKW = \input ->
"r#\(input)"
else
input
nextMultipleOf = \lhs, rhs ->
when lhs % rhs is
0 -> lhs
r -> lhs + (rhs - r)

View file

@ -72,16 +72,20 @@ impl Types {
pub fn with_capacity(cap: usize, target_info: TargetInfo) -> Self {
let mut types = Vec::with_capacity(cap);
let mut sizes = Vec::with_capacity(cap);
let mut aligns = Vec::with_capacity(cap);
types.push(RocType::Unit);
sizes.push(1);
aligns.push(1);
Self {
target: target_info,
types,
sizes,
aligns,
types_by_name: FnvHashMap::with_capacity_and_hasher(10, Default::default()),
entry_points: Vec::new(),
sizes: Vec::new(),
aligns: Vec::new(),
deps: VecMap::with_capacity(cap),
}
}
@ -542,14 +546,19 @@ impl Types {
}
}
debug_assert_eq!(self.types.len(), self.sizes.len());
debug_assert_eq!(self.types.len(), self.aligns.len());
let id = TypeId(self.types.len());
assert!(id.0 <= TypeId::MAX.0);
let size = interner.stack_size(layout);
let align = interner.alignment_bytes(layout);
self.types.push(typ);
self.sizes
.push(interner.stack_size_without_alignment(layout));
self.aligns.push(interner.alignment_bytes(layout));
self.sizes.push(size);
self.aligns.push(align);
id
}

View file

@ -22,5 +22,5 @@ Job : [
Rbt : { default : Job }
mainForHost : Rbt
mainForHost = main
mainForHost : {} -> Rbt
mainForHost = \{} -> main

View file

@ -1,29 +1,19 @@
mod test_glue;
use indoc::indoc;
use test_glue::Rbt;
extern "C" {
#[link_name = "roc__mainForHost_1_exposed_generic"]
fn roc_main(_: *mut Rbt);
}
// use test_glue::Rbt;
#[no_mangle]
pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<Rbt> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
@ -57,7 +47,7 @@ use std::os::raw::c_char;
#[no_mangle]
pub unsafe extern "C" fn roc_alloc(size: usize, _alignment: u32) -> *mut c_void {
return libc::malloc(size);
libc::malloc(size)
}
#[no_mangle]

View file

@ -7,5 +7,5 @@ platform "test-platform"
Expr : [String Str, Concat Expr Expr]
mainForHost : Expr
mainForHost = main
mainForHost : {} -> Expr
mainForHost = \{} -> main

View file

@ -13,13 +13,7 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<Expr> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.

View file

@ -7,5 +7,5 @@ platform "test-platform"
StrConsList : [Nil, Cons Str StrConsList]
mainForHost : StrConsList
mainForHost = main
mainForHost : {} -> StrConsList
mainForHost = \{} -> main

View file

@ -13,17 +13,12 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<StrConsList> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
@ -38,8 +33,8 @@ pub extern "C" fn rust_main() -> i32 {
"#
),
tag_union,
StrConsList::Cons("small str".into(), StrConsList::Nil),
StrConsList::Nil,
StrConsList::Cons("small str".into(), StrConsList::Nil()),
StrConsList::Nil(),
); // Debug
let mut set = HashSet::new();

View file

@ -7,5 +7,5 @@ platform "test-platform"
StrFingerTree : [Empty, Single Str, More Str StrFingerTree]
mainForHost : StrFingerTree
mainForHost = main
mainForHost : {} -> StrFingerTree
mainForHost = \{} -> main

View file

@ -14,26 +14,20 @@ pub extern "C" fn rust_main() -> i32 {
use std::cmp::Ordering;
use std::collections::hash_set::HashSet;
let tag_union = unsafe {
let mut ret: core::mem::MaybeUninit<StrFingerTree> = core::mem::MaybeUninit::uninit();
roc_main(ret.as_mut_ptr());
ret.assume_init()
};
let tag_union = test_glue::mainForHost(());
// Eq
assert!(StrFingerTree::Empty == StrFingerTree::Empty);
assert!(StrFingerTree::Empty != tag_union);
assert!(StrFingerTree::Empty() == StrFingerTree::Empty());
assert!(StrFingerTree::Empty() != tag_union);
assert!(
StrFingerTree::Single(RocStr::from("foo")) == StrFingerTree::Single(RocStr::from("foo"))
);
assert!(StrFingerTree::Single(RocStr::from("foo")) != StrFingerTree::Empty);
assert!(StrFingerTree::Single(RocStr::from("foo")) != StrFingerTree::Empty());
// Verify that it has all the expected traits.
assert!(tag_union == tag_union); // PartialEq
assert!(tag_union.clone() == tag_union.clone()); // Clone
assert!(StrFingerTree::Empty.clone() == StrFingerTree::Empty); // Clone
assert!(StrFingerTree::Empty().clone() == StrFingerTree::Empty()); // Clone
assert!(tag_union.partial_cmp(&tag_union) == Some(Ordering::Equal)); // PartialOrd
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
@ -53,9 +47,9 @@ pub extern "C" fn rust_main() -> i32 {
"small str".into(),
StrFingerTree::Single("other str".into()),
),
StrFingerTree::More("small str".into(), StrFingerTree::Empty),
StrFingerTree::More("small str".into(), StrFingerTree::Empty()),
StrFingerTree::Single("small str".into()),
StrFingerTree::Empty,
StrFingerTree::Empty(),
); // Debug
let mut set = HashSet::new();

View file

@ -23,14 +23,14 @@ pub extern "C" fn rust_main() -> i32 {
assert!(tag_union.cmp(&tag_union) == Ordering::Equal); // Ord
println!(
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
tag_union,
NonRecursive::Foo("small str".into()),
NonRecursive::Foo("A long enough string to not be small".into()),
NonRecursive::Bar(123),
NonRecursive::Baz(),
NonRecursive::Blah(456),
); // Debug
"tag_union was: {:?}\n`Foo \"small str\"` is: {:?}\n`Foo \"A long enough string to not be small\"` is: {:?}\n`Bar 123` is: {:?}\n`Baz` is: {:?}\n`Blah 456` is: {:?}",
tag_union,
NonRecursive::Foo("small str".into()),
NonRecursive::Foo("A long enough string to not be small".into()),
NonRecursive::Bar(123),
NonRecursive::Baz(),
NonRecursive::Blah(456),
); // Debug
let mut set = HashSet::new();

View file

@ -92,33 +92,33 @@ mod glue_cli_run {
`Baz` is: NonRecursive::Baz(())
`Blah 456` is: NonRecursive::Blah(456)
"#),
// nullable_wrapped:"nullable-wrapped" => indoc!(r#"
// tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
// `More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))
// `More "small str" Empty` is: StrFingerTree::More("small str", StrFingerTree::Empty)
// `Single "small str"` is: StrFingerTree::Single("small str")
// `Empty` is: StrFingerTree::Empty
// "#),
// nullable_unwrapped:"nullable-unwrapped" => indoc!(r#"
// tag_union was: StrConsList::Cons("World!", StrConsList::Cons("Hello ", StrConsList::Nil))
// `Cons "small str" Nil` is: StrConsList::Cons("small str", StrConsList::Nil)
// `Nil` is: StrConsList::Nil
// "#),
// nonnullable_unwrapped:"nonnullable-unwrapped" => indoc!(r#"
// tag_union was: StrRoseTree::Tree(ManuallyDrop { value: StrRoseTree_Tree { f0: "root", f1: [StrRoseTree::Tree(ManuallyDrop { value: StrRoseTree_Tree { f0: "leaf1", f1: [] } }), StrRoseTree::Tree(ManuallyDrop { value: StrRoseTree_Tree { f0: "leaf2", f1: [] } })] } })
// Tree "foo" [] is: StrRoseTree::Tree(ManuallyDrop { value: StrRoseTree_Tree { f0: "foo", f1: [] } })
// "#),
// basic_recursive_union:"basic-recursive-union" => indoc!(r#"
// tag_union was: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
// `Concat (String "Hello, ") (String "World!")` is: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
// `String "this is a test"` is: Expr::String("this is a test")
// "#),
// advanced_recursive_union:"advanced-recursive-union" => indoc!(r#"
// rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { tool: Tool::SystemTool(R4 { name: "test", num: 42 }) }), inputFiles: ["foo"] }) }
// "#),
// list_recursive_union:"list-recursive-union" => indoc!(r#"
// rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { args: [], tool: Tool::SystemTool(R3 { name: "test" }) }), inputFiles: ["foo"], job: [] }) }
// "#),
nullable_wrapped:"nullable-wrapped" => indoc!(r#"
tag_union was: StrFingerTree::More("foo", StrFingerTree::More("bar", StrFingerTree::Empty))
`More "small str" (Single "other str")` is: StrFingerTree::More("small str", StrFingerTree::Single("other str"))
`More "small str" Empty` is: StrFingerTree::More("small str", StrFingerTree::Empty)
`Single "small str"` is: StrFingerTree::Single("small str")
`Empty` is: StrFingerTree::Empty
"#),
nullable_unwrapped:"nullable-unwrapped" => indoc!(r#"
tag_union was: StrConsList::Cons("World!", StrConsList::Cons("Hello ", StrConsList::Nil))
`Cons "small str" Nil` is: StrConsList::Cons("small str", StrConsList::Nil)
`Nil` is: StrConsList::Nil
"#),
nonnullable_unwrapped:"nonnullable-unwrapped" => indoc!(r#"
tag_union was: StrRoseTree::Tree("root", [StrRoseTree::Tree("leaf1", []), StrRoseTree::Tree("leaf2", [])])
Tree "foo" [] is: StrRoseTree::Tree("foo", [])
"#),
basic_recursive_union:"basic-recursive-union" => indoc!(r#"
tag_union was: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`Concat (String "Hello, ") (String "World!")` is: Expr::Concat(Expr::String("Hello, "), Expr::String("World!"))
`String "this is a test"` is: Expr::String("this is a test")
"#),
advanced_recursive_union:"advanced-recursive-union" => indoc!(r#"
rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { tool: Tool::SystemTool(R4 { name: "test", num: 42 }) }), inputFiles: ["foo"] }) }
"#),
list_recursive_union:"list-recursive-union" => indoc!(r#"
rbt was: Rbt { default: Job::Job(R1 { command: Command::Command(R2 { args: [], tool: Tool::SystemTool(R3 { name: "test" }) }), inputFiles: ["foo"], job: [] }) }
"#),
multiple_modules:"multiple-modules" => indoc!(r#"
combined was: Combined { s1: DepStr1::S("hello"), s2: DepStr2::R("world") }
"#),

View file

@ -33,7 +33,7 @@ impl<T> RocBox<T> {
let contents = unsafe {
let contents_ptr = ptr.cast::<u8>().add(alignment).cast::<T>();
*contents_ptr = contents;
core::ptr::write(contents_ptr, contents);
// We already verified that the original alloc pointer was non-null,
// and this one is the alloc pointer with `alignment` bytes added to it,
@ -44,6 +44,15 @@ impl<T> RocBox<T> {
Self { contents }
}
/// # Safety
///
/// The box must be unique in order to leak it safely
pub unsafe fn leak(self) -> *mut T {
let ptr = self.contents.as_ptr() as *mut T;
core::mem::forget(self);
ptr
}
#[inline(always)]
fn alloc_alignment() -> usize {
mem::align_of::<T>().max(mem::align_of::<Storage>())
@ -110,6 +119,12 @@ where
}
}
impl<T: core::hash::Hash> core::hash::Hash for RocBox<T> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.contents.hash(state)
}
}
impl<T> Debug for RocBox<T>
where
T: Debug,