diff --git a/crates/glue/src/RocType.roc b/crates/glue/src/RocType.roc index d3e8bcfbb4..cab235484d 100644 --- a/crates/glue/src/RocType.roc +++ b/crates/glue/src/RocType.roc @@ -77,11 +77,11 @@ RocType : [ EmptyTagUnion, Struct { name: Str, - fields: List { name: Str, type: TypeId } + fields: List { name: Str, id: TypeId } }, TagUnionPayload { name: Str, - fields: List { discriminant: Nat, type: TypeId }, + fields: List { discriminant: Nat, id: TypeId }, }, ## A recursive pointer, e.g. in StrConsList : [Nil, Cons Str StrConsList], ## this would be the field of Cons containing the (recursive) StrConsList type, diff --git a/crates/glue/src/RustGlue.roc b/crates/glue/src/RustGlue.roc index 6ee34e566e..ba9556b75c 100644 --- a/crates/glue/src/RustGlue.roc +++ b/crates/glue/src/RustGlue.roc @@ -82,14 +82,14 @@ generateStruct = \buf, types, id, name, fields, visibility -> structType = getType types id buf - |> addDeriveStr structType types IncludeDebug + |> addDeriveStr types structType IncludeDebug |> Str.concat "#[repr(\(repr))]\n\(pub) struct \(escapedName) {\n" |> \b -> List.walk fields b (generateStructFields types) |> Str.concat "}\n\n" generateStructFields = \types -> - \accum, { name: fieldName, type: fieldType } -> - typeStr = typeName types fieldType + \accum, { name: fieldName, id } -> + typeStr = typeName types id escapedFieldName = escapeKW fieldName Str.concat accum "\(indent)pub \(escapedFieldName): \(typeStr),\n" @@ -97,12 +97,13 @@ generateStructFields = \types -> nameTagUnionPayloadFields = \fields -> # Tag union payloads have numbered fields, so we prefix them # with an "f" because Rust doesn't allow struct fields to be numbers. - List.map fields \{ discriminant, type } -> + List.map fields \{ discriminant, id } -> discStr = Num.toStr discriminant - { name: "f\(discStr)", type } + { name: "f\(discStr)", id } -addDeriveStr = \buf, _type, _types, includeDebug -> +addDeriveStr = \buf, types, type, includeDebug -> + # TODO: full derive impl porting. buf |> Str.concat "#[derive(Clone, " |> \b -> @@ -111,9 +112,35 @@ addDeriveStr = \buf, _type, _types, includeDebug -> Str.concat b "Debug, " ExcludeDebug -> - b # TODO: full derive impl porting. + b + |> \b -> + if !(cannotDeriveCopy types type) then + Str.concat b "Copy, " + else + b |> Str.concat "PartialEq, PartialOrd)]\n" +cannotDeriveCopy = \types, type -> + when type is + Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) | Function _ -> Bool.false + RocStr | RocList _ | RocDict _ _ | RocSet _ | RocBox _ | TagUnion ( NullableUnwrapped _ ) | TagUnion ( NullableWrapped _ ) | TagUnion ( Recursive _ ) | TagUnion ( NonNullableUnwrapped _) | RecursivePointer _ -> Bool.true + TagUnion (SingleTagStruct { payloadFields }) -> + List.any payloadFields \id -> cannotDeriveCopy types (getType types id) + TagUnion (NonRecursive {tags}) -> + List.any tags \{payload} -> + when payload is + Some id -> cannotDeriveCopy types (getType types id) + None -> Bool.false + RocResult okId errId -> + cannotDeriveCopy types (getType types okId) + || cannotDeriveCopy types (getType types errId) + Struct { fields} -> + List.any fields \{ id } -> cannotDeriveCopy types (getType types id) + TagUnionPayload { fields} -> + List.any fields \{ id } -> cannotDeriveCopy types (getType types id) + _ -> crash "ugh" + + typeName = \types, id -> when getType types id is Unit -> "()"