Situationally gen Eq/Ord/Hash glue for tag unions

This commit is contained in:
Richard Feldman 2023-08-10 12:50:07 -04:00
parent e4ba92c1a6
commit aed1c26e72
No known key found for this signature in database
GPG key ID: F1F21AA5B1D9E43B

View file

@ -426,16 +426,20 @@ deriveDebugTagUnion = \buf, types, tagUnionType, tags ->
} }
""" """
deriveEqTagUnion : Str, Str -> Str deriveEqTagUnion : Str, Types, Shape, Str -> Str
deriveEqTagUnion = \buf, tagUnionType -> deriveEqTagUnion = \buf, types, shape, tagUnionType ->
if canSupportEqHashOrd types shape then
""" """
\(buf) \(buf)
impl Eq for \(tagUnionType) {} impl Eq for \(tagUnionType) {}
""" """
else
buf
derivePartialEqTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str derivePartialEqTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialEqTagUnion = \buf, tagUnionType, tags -> derivePartialEqTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportPartialEqOrd types shape then
checks = checks =
List.walk tags "" \accum, { name: tagName } -> List.walk tags "" \accum, { name: tagName } ->
""" """
@ -461,9 +465,12 @@ derivePartialEqTagUnion = \buf, tagUnionType, tags ->
} }
} }
""" """
else
buf
deriveOrdTagUnion : Str, Str -> Str deriveOrdTagUnion : Str, Types, Shape, Str -> Str
deriveOrdTagUnion = \buf, tagUnionType -> deriveOrdTagUnion = \buf, types, shape, tagUnionType ->
if canSupportEqHashOrd types shape then
""" """
\(buf) \(buf)
@ -473,9 +480,12 @@ deriveOrdTagUnion = \buf, tagUnionType ->
} }
} }
""" """
else
buf
derivePartialOrdTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str derivePartialOrdTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
derivePartialOrdTagUnion = \buf, tagUnionType, tags -> derivePartialOrdTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportPartialEqOrd types shape then
checks = checks =
List.walk tags "" \accum, { name: tagName } -> List.walk tags "" \accum, { name: tagName } ->
""" """
@ -503,9 +513,12 @@ derivePartialOrdTagUnion = \buf, tagUnionType, tags ->
} }
} }
""" """
else
buf
deriveHashTagUnion : Str, Str, List { name : Str, payload : [Some TypeId, None] } -> Str deriveHashTagUnion : Str, Types, Shape, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
deriveHashTagUnion = \buf, tagUnionType, tags -> deriveHashTagUnion = \buf, types, shape, tagUnionType, tags ->
if canSupportEqHashOrd types shape then
checks = checks =
List.walk tags "" \accum, { name: tagName } -> List.walk tags "" \accum, { name: tagName } ->
""" """
@ -527,6 +540,8 @@ deriveHashTagUnion = \buf, tagUnionType, tags ->
} }
} }
""" """
else
buf
generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str generateConstructorFunctions : Str, Types, Str, List { name : Str, payload : [Some TypeId, None] } -> Str
generateConstructorFunctions = \buf, types, tagUnionType, tags -> generateConstructorFunctions = \buf, types, tagUnionType, tags ->
@ -646,6 +661,7 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
sizeOfSelf = Num.toStr (Types.size types id) sizeOfSelf = Num.toStr (Types.size types id)
alignOfSelf = Num.toStr (Types.alignment types id) alignOfSelf = Num.toStr (Types.alignment types id)
shape = Types.shape types id
# TODO: this value can be different than the alignment of `id` # TODO: this value can be different than the alignment of `id`
align = align =
@ -701,16 +717,16 @@ generateNonRecursiveTagUnion = \buf, types, id, name, tags, discriminantSize, di
""" """
|> deriveCloneTagUnion escapedName tags |> deriveCloneTagUnion escapedName tags
|> deriveDebugTagUnion types escapedName tags |> deriveDebugTagUnion types escapedName tags
|> deriveEqTagUnion escapedName |> deriveEqTagUnion types shape escapedName
|> derivePartialEqTagUnion escapedName tags |> derivePartialEqTagUnion types shape escapedName tags
|> deriveOrdTagUnion escapedName |> deriveOrdTagUnion types shape escapedName
|> derivePartialOrdTagUnion escapedName tags |> derivePartialOrdTagUnion types shape escapedName tags
|> deriveHashTagUnion escapedName tags |> deriveHashTagUnion types shape escapedName tags
|> generateDestructorFunctions types escapedName tags |> generateDestructorFunctions types escapedName tags
|> generateConstructorFunctions types escapedName tags |> generateConstructorFunctions types escapedName tags
|> \b -> |> \b ->
type = Types.shape types id type = Types.shape types id
if cannotDeriveCopy types type then if cannotSupportCopy types type then
# A custom drop impl is only needed when we can't derive copy. # A custom drop impl is only needed when we can't derive copy.
b b
|> Str.concat |> Str.concat
@ -942,7 +958,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> Str.joinWith "\n" |> Str.joinWith "\n"
partialEqImpl = partialEqImpl =
if canDerivePartialEq types (Types.shape types id) then if canSupportPartialEqOrd types (Types.shape types id) then
""" """
impl PartialEq for \(escapedName) { impl PartialEq for \(escapedName) {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
@ -1027,7 +1043,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
hashImpl = hashImpl =
if canDerivePartialEq types (Types.shape types id) then if canSupportPartialEqOrd types (Types.shape types id) then
""" """
impl core::hash::Hash for \(escapedName) { impl core::hash::Hash for \(escapedName) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) { fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
@ -1067,7 +1083,7 @@ generateRecursiveTagUnion = \buf, types, id, tagUnionName, tags, discriminantSiz
|> Str.joinWith "\n" |> Str.joinWith "\n"
partialOrdImpl = partialOrdImpl =
if canDerivePartialEq types (Types.shape types id) then if canSupportPartialEqOrd types (Types.shape types id) then
""" """
impl PartialOrd for \(escapedName) { impl PartialOrd for \(escapedName) {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
@ -1198,7 +1214,7 @@ generateTagUnionDropPayload = \buf, types, selfMut, tags, discriminantName, disc
buf buf
|> writeTagImpls tags discriminantName indents \name, payload -> |> writeTagImpls tags discriminantName indents \name, payload ->
when payload is when payload is
Some id if cannotDeriveCopy types (Types.shape types id) -> Some id if cannotSupportCopy types (Types.shape types id) ->
"unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) }," "unsafe { core::mem::ManuallyDrop::drop(&mut \(selfMut).payload.\(name)) },"
_ -> _ ->
@ -1272,7 +1288,7 @@ generateUnionField = \types ->
type = Types.shape types id type = Types.shape types id
fullTypeStr = fullTypeStr =
if cannotDeriveCopy types type then if cannotSupportCopy types type then
# types with pointers need ManuallyDrop # types with pointers need ManuallyDrop
# because rust unions don't (and can't) # because rust unions don't (and can't)
# know how to drop them automatically! # know how to drop them automatically!
@ -1673,54 +1689,58 @@ generateDeriveStr = \buf, types, type, includeDebug ->
buf buf
|> Str.concat "#[derive(Clone, " |> Str.concat "#[derive(Clone, "
|> condWrite (!(cannotDeriveCopy types type)) "Copy, " |> condWrite (!(cannotSupportCopy types type)) "Copy, "
|> condWrite (!(cannotDeriveDefault types type)) "Default, " |> condWrite (!(cannotSupportDefault types type)) "Default, "
|> condWrite deriveDebug "Debug, " |> condWrite deriveDebug "Debug, "
|> condWrite (canDerivePartialEq types type) "PartialEq, PartialOrd, " |> condWrite (canSupportPartialEqOrd types type) "PartialEq, PartialOrd, "
|> condWrite (!(hasFloat types type) && (canDerivePartialEq types type)) "Eq, Ord, Hash, " |> condWrite (canSupportEqHashOrd types type) "Eq, Ord, Hash, "
|> Str.concat ")]\n" |> Str.concat ")]\n"
canDerivePartialEq : Types, Shape -> Bool canSupportEqHashOrd : Types, Shape -> Bool
canDerivePartialEq = \types, type -> canSupportEqHashOrd = \types, type ->
!(hasFloat types type) && (canSupportPartialEqOrd types type)
canSupportPartialEqOrd : Types, Shape -> Bool
canSupportPartialEqOrd = \types, type ->
when type is when type is
Function rocFn -> Function rocFn ->
runtimeRepresentation = Types.shape types rocFn.lambdaSet runtimeRepresentation = Types.shape types rocFn.lambdaSet
canDerivePartialEq types runtimeRepresentation canSupportPartialEqOrd types runtimeRepresentation
Unsized -> Bool.false Unsized -> Bool.false
Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true
RocStr -> Bool.true RocStr -> Bool.true
RocList inner | RocSet inner | RocBox inner -> RocList inner | RocSet inner | RocBox inner ->
innerType = Types.shape types inner innerType = Types.shape types inner
canDerivePartialEq types innerType canSupportPartialEqOrd types innerType
RocDict k v -> RocDict k v ->
kType = Types.shape types k kType = Types.shape types k
vType = Types.shape types v vType = Types.shape types v
canDerivePartialEq types kType && canDerivePartialEq types vType canSupportPartialEqOrd types kType && canSupportPartialEqOrd types vType
TagUnion (Recursive { tags }) -> TagUnion (Recursive { tags }) ->
List.all tags \{ payload } -> List.all tags \{ payload } ->
when payload is when payload is
None -> Bool.true None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id) Some id -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (NullableWrapped { tags }) -> TagUnion (NullableWrapped { tags }) ->
List.all tags \{ payload } -> List.all tags \{ payload } ->
when payload is when payload is
None -> Bool.true None -> Bool.true
Some id -> canDerivePartialEq types (Types.shape types id) Some id -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (NonNullableUnwrapped { payload }) -> TagUnion (NonNullableUnwrapped { payload }) ->
canDerivePartialEq types (Types.shape types payload) canSupportPartialEqOrd types (Types.shape types payload)
TagUnion (NullableUnwrapped { nonNullPayload }) -> TagUnion (NullableUnwrapped { nonNullPayload }) ->
canDerivePartialEq types (Types.shape types nonNullPayload) canSupportPartialEqOrd types (Types.shape types nonNullPayload)
RecursivePointer _ -> Bool.true RecursivePointer _ -> Bool.true
TagUnion (SingleTagStruct { payload: HasNoClosure fields }) -> TagUnion (SingleTagStruct { payload: HasNoClosure fields }) ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
TagUnion (SingleTagStruct { payload: HasClosure _ }) -> TagUnion (SingleTagStruct { payload: HasClosure _ }) ->
Bool.false Bool.false
@ -1728,23 +1748,23 @@ canDerivePartialEq = \types, type ->
TagUnion (NonRecursive { tags }) -> TagUnion (NonRecursive { tags }) ->
List.all tags \{ payload } -> List.all tags \{ payload } ->
when payload is when payload is
Some id -> canDerivePartialEq types (Types.shape types id) Some id -> canSupportPartialEqOrd types (Types.shape types id)
None -> Bool.true None -> Bool.true
RocResult okId errId -> RocResult okId errId ->
okShape = Types.shape types okId okShape = Types.shape types okId
errShape = Types.shape types errId errShape = Types.shape types errId
canDerivePartialEq types okShape && canDerivePartialEq types errShape canSupportPartialEqOrd types okShape && canSupportPartialEqOrd types errShape
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } -> Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } -> Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) List.all fields \{ id } -> canSupportPartialEqOrd types (Types.shape types id)
cannotDeriveCopy : Types, Shape -> Bool cannotSupportCopy : Types, Shape -> Bool
cannotDeriveCopy = \types, type -> cannotSupportCopy = \types, type ->
!(canDeriveCopy types type) !(canDeriveCopy types type)
canDeriveCopy : Types, Shape -> Bool canDeriveCopy : Types, Shape -> Bool
@ -1780,22 +1800,22 @@ canDeriveCopy = \types, type ->
Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } -> Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } ->
List.all fields \{ id } -> canDeriveCopy types (Types.shape types id) List.all fields \{ id } -> canDeriveCopy types (Types.shape types id)
cannotDeriveDefault = \types, type -> cannotSupportDefault = \types, type ->
when type is when type is
Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true Unit | Unsized | EmptyTagUnion | TagUnion _ | RocResult _ _ | RecursivePointer _ | Function _ -> Bool.true
RocStr | Bool | Num _ -> Bool.false RocStr | Bool | Num _ -> Bool.false
RocList id | RocSet id | RocBox id -> RocList id | RocSet id | RocBox id ->
cannotDeriveDefault types (Types.shape types id) cannotSupportDefault types (Types.shape types id)
TagUnionPayload { fields: HasClosure _ } -> Bool.true TagUnionPayload { fields: HasClosure _ } -> Bool.true
RocDict keyId valId -> RocDict keyId valId ->
cannotDeriveCopy types (Types.shape types keyId) cannotSupportCopy types (Types.shape types keyId)
|| cannotDeriveCopy types (Types.shape types valId) || cannotSupportCopy types (Types.shape types valId)
Struct { fields: HasClosure _ } -> Bool.true Struct { fields: HasClosure _ } -> Bool.true
Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } -> Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } ->
List.any fields \{ id } -> cannotDeriveDefault types (Types.shape types id) List.any fields \{ id } -> cannotSupportDefault types (Types.shape types id)
hasFloat = \types, type -> hasFloat = \types, type ->
hasFloatHelp types type (Set.empty {}) hasFloatHelp types type (Set.empty {})