Update Dict to no longer use Nat

This commit is contained in:
Richard Feldman 2024-01-21 14:50:38 -05:00
parent 77a10986d6
commit 16ddb16177
No known key found for this signature in database
GPG key ID: F1F21AA5B1D9E43B

View file

@ -34,7 +34,7 @@ interface Dict
Result.{ Result },
List,
Str,
Num.{ Nat, U64, F32, U32, U8, I8 },
Num.{ U64, F32, U32, U8, I8 },
Hash.{ Hasher, Hash },
Inspect.{ Inspect, Inspector, InspectFormatter },
]
@ -151,26 +151,25 @@ empty = \{} ->
## Return a dictionary with space allocated for a number of entries. This
## may provide a performance optimization if you know how many entries will be
## inserted.
withCapacity : Nat -> Dict * *
withCapacity : U64 -> Dict * *
withCapacity = \requested ->
empty {}
|> reserve requested
## Enlarge the dictionary for at least capacity additional elements
reserve : Dict k v, Nat -> Dict k v
reserve : Dict k v, U64 -> Dict k v
reserve = \@Dict { buckets, data, maxBucketCapacity: originalMaxBucketCapacity, maxLoadFactor, shifts }, requested ->
currentSize = List.len data
requestedSize = currentSize + requested
size = Num.min (Num.toU64 requestedSize) maxSize
requestedSize = Num.addWrap currentSize requested
size = Num.min requestedSize maxSize
requestedShifts = calcShiftsForSize size maxLoadFactor
if (List.isEmpty buckets) || requestedShifts > shifts then
(buckets0, maxBucketCapacity) = allocBucketsFromShift requestedShifts maxLoadFactor
buckets1 = fillBucketsFromData buckets0 data requestedShifts
sizeNat = Num.toNat size
@Dict {
buckets: buckets1,
data: List.reserve data (Num.subSaturated sizeNat currentSize),
data: List.reserve data (Num.subSaturated size currentSize),
maxBucketCapacity,
maxLoadFactor,
shifts: requestedShifts,
@ -186,7 +185,7 @@ releaseExcessCapacity = \@Dict { buckets, data, maxBucketCapacity: originalMaxBu
size = List.len data
# NOTE: If we want, we technically could increase the load factor here to potentially minimize size more.
minShifts = calcShiftsForSize (Num.toU64 size) maxLoadFactor
minShifts = calcShiftsForSize size maxLoadFactor
if minShifts < shifts then
(buckets0, maxBucketCapacity) = allocBucketsFromShift minShifts maxLoadFactor
buckets1 = fillBucketsFromData buckets0 data minShifts
@ -208,9 +207,9 @@ releaseExcessCapacity = \@Dict { buckets, data, maxBucketCapacity: originalMaxBu
##
## capacityOfDict = Dict.capacity foodDict
## ```
capacity : Dict * * -> Nat
capacity : Dict * * -> U64
capacity = \@Dict { maxBucketCapacity } ->
Num.toNat maxBucketCapacity
maxBucketCapacity
## Returns a dictionary containing the key and value provided as input.
## ```
@ -251,7 +250,7 @@ fromList = \data ->
## |> Dict.len
## |> Bool.isEq 3
## ```
len : Dict * * -> Nat
len : Dict * * -> U64
len = \@Dict { data } ->
List.len data
@ -372,14 +371,14 @@ keepIf : Dict k v, ((k, v) -> Bool) -> Dict k v
keepIf = \dict, predicate ->
keepIfHelp dict predicate 0 (Dict.len dict)
keepIfHelp : Dict k v, ((k, v) -> Bool), Nat, Nat -> Dict k v
keepIfHelp : Dict k v, ((k, v) -> Bool), U64, U64 -> Dict k v
keepIfHelp = \@Dict dict, predicate, index, length ->
if index < length then
(key, value) = listGetUnsafe dict.data index
if predicate (key, value) then
keepIfHelp (@Dict dict) predicate (index + 1) length
keepIfHelp (@Dict dict) predicate (index |> Num.addWrap 1) length
else
keepIfHelp (Dict.remove (@Dict dict) key) predicate index (length - 1)
keepIfHelp (Dict.remove (@Dict dict) key) predicate index (length |> Num.subWrap 1)
else
@Dict dict
@ -450,13 +449,13 @@ insert = \dict, key, value ->
insertHelper buckets data bucketIndex distAndFingerprint key value maxBucketCapacity maxLoadFactor shifts
insertHelper : List Bucket, List (k, v), Nat, U32, k, v, U64, F32, U8 -> Dict k v
insertHelper : List Bucket, List (k, v), U64, U32, k, v, U64, F32, U8 -> Dict k v
insertHelper = \buckets0, data0, bucketIndex0, distAndFingerprint0, key, value, maxBucketCapacity, maxLoadFactor, shifts ->
loaded = listGetUnsafe buckets0 (Num.toNat bucketIndex0)
loaded = listGetUnsafe buckets0 bucketIndex0
if distAndFingerprint0 == loaded.distAndFingerprint then
(foundKey, _) = listGetUnsafe data0 (Num.toNat loaded.dataIndex)
(foundKey, _) = listGetUnsafe data0 (Num.toU64 loaded.dataIndex)
if foundKey == key then
data1 = List.set data0 (Num.toNat loaded.dataIndex) (key, value)
data1 = List.set data0 (Num.toU64 loaded.dataIndex) (key, value)
@Dict { buckets: buckets0, data: data1, maxBucketCapacity, maxLoadFactor, shifts }
else
bucketIndex1 = nextBucketIndex bucketIndex0 (List.len buckets0)
@ -464,7 +463,7 @@ insertHelper = \buckets0, data0, bucketIndex0, distAndFingerprint0, key, value,
insertHelper buckets0 data0 bucketIndex1 distAndFingerprint1 key value maxBucketCapacity maxLoadFactor shifts
else if distAndFingerprint0 > loaded.distAndFingerprint then
data1 = List.append data0 (key, value)
dataIndex = (List.len data1) - 1
dataIndex = (List.len data1) |> Num.subWrap 1
buckets1 = placeAndShiftUp buckets0 { distAndFingerprint: distAndFingerprint0, dataIndex: Num.toU32 dataIndex } bucketIndex0
@Dict { buckets: buckets1, data: data1, maxBucketCapacity, maxLoadFactor, shifts }
else
@ -487,7 +486,7 @@ remove = \@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }, key
(bucketIndex0, distAndFingerprint0) = nextWhileLess buckets key shifts
(bucketIndex1, distAndFingerprint1) = removeHelper buckets bucketIndex0 distAndFingerprint0 data key
bucket = listGetUnsafe buckets (Num.toNat bucketIndex1)
bucket = listGetUnsafe buckets bucketIndex1
if distAndFingerprint1 != bucket.distAndFingerprint then
@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }
else
@ -495,10 +494,11 @@ remove = \@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }, key
else
@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }
removeHelper : List Bucket, U64, U32, List (k, *), k -> (U64, U32) where k implements Eq
removeHelper = \buckets, bucketIndex, distAndFingerprint, data, key ->
bucket = listGetUnsafe buckets (Num.toNat bucketIndex)
bucket = listGetUnsafe buckets bucketIndex
if distAndFingerprint == bucket.distAndFingerprint then
(foundKey, _) = listGetUnsafe data (Num.toNat bucket.dataIndex)
(foundKey, _) = listGetUnsafe data (Num.toU64 bucket.dataIndex)
if foundKey == key then
(bucketIndex, distAndFingerprint)
else
@ -529,7 +529,7 @@ update = \@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }, key
when alter (Present value) is
Present newValue ->
bucket = listGetUnsafe buckets bucketIndex
newData = List.set data (Num.toNat bucket.dataIndex) (key, newValue)
newData = List.set data (Num.toU64 bucket.dataIndex) (key, newValue)
@Dict { buckets, data: newData, maxBucketCapacity, maxLoadFactor, shifts }
Missing ->
@ -538,7 +538,7 @@ update = \@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }, key
Err KeyNotFound ->
when alter Missing is
Present newValue ->
if List.len data >= (Num.toNat maxBucketCapacity) then
if List.len data >= maxBucketCapacity then
# Need to reallocate let regular insert handle that.
insert (@Dict { buckets, data, maxBucketCapacity, maxLoadFactor, shifts }) key newValue
else
@ -721,20 +721,20 @@ emptyBucket = { distAndFingerprint: 0, dataIndex: 0 }
distInc = Num.shiftLeftBy 1u32 8 # skip 1 byte fingerprint
fingerprintMask = Num.subWrap distInc 1 # mask for 1 byte of fingerprint
defaultMaxLoadFactor = 0.8
initialShifts = 64 - 3 # 2^(64-shifts) number of buckets
initialShifts = 64 |> Num.subWrap 3 # 2^(64-shifts) number of buckets
maxSize = Num.shiftLeftBy 1u64 32
maxBucketCount = maxSize
incrementDist = \distAndFingerprint ->
distAndFingerprint + distInc
Num.addWrap distAndFingerprint distInc
incrementDistN = \distAndFingerprint, n ->
distAndFingerprint + (n * distInc)
Num.addWrap distAndFingerprint (Num.mulWrap n distInc)
decrementDist = \distAndFingerprint ->
distAndFingerprint - distInc
distAndFingerprint |> Num.subWrap distInc
find : Dict k v, k -> { bucketIndex : Nat, result : Result v [KeyNotFound] }
find : Dict k v, k -> { bucketIndex : U64, result : Result v [KeyNotFound] }
find = \@Dict { buckets, data, shifts }, key ->
hash = hashKey key
distAndFingerprint = distAndFingerprintFromHash hash
@ -749,13 +749,13 @@ find = \@Dict { buckets, data, shifts }, key ->
findManualUnrolls = 2
findFirstUnroll : List Bucket, Nat, U32, List (k, v), k -> { bucketIndex : Nat, result : Result v [KeyNotFound] } where k implements Eq
findFirstUnroll : List Bucket, U64, U32, List (k, v), k -> { bucketIndex : U64, result : Result v [KeyNotFound] } where k implements Eq
findFirstUnroll = \buckets, bucketIndex, distAndFingerprint, data, key ->
# TODO: once we have short circuit evaluation, use it here and other similar locations in this file.
# Avoid the nested if with else block inconvenience.
bucket = listGetUnsafe buckets bucketIndex
if distAndFingerprint == bucket.distAndFingerprint then
(foundKey, value) = listGetUnsafe data (Num.toNat bucket.dataIndex)
(foundKey, value) = listGetUnsafe data (Num.toU64 bucket.dataIndex)
if foundKey == key then
{ bucketIndex, result: Ok value }
else
@ -763,11 +763,11 @@ findFirstUnroll = \buckets, bucketIndex, distAndFingerprint, data, key ->
else
findSecondUnroll buckets (nextBucketIndex bucketIndex (List.len buckets)) (incrementDist distAndFingerprint) data key
findSecondUnroll : List Bucket, Nat, U32, List (k, v), k -> { bucketIndex : Nat, result : Result v [KeyNotFound] } where k implements Eq
findSecondUnroll : List Bucket, U64, U32, List (k, v), k -> { bucketIndex : U64, result : Result v [KeyNotFound] } where k implements Eq
findSecondUnroll = \buckets, bucketIndex, distAndFingerprint, data, key ->
bucket = listGetUnsafe buckets bucketIndex
if distAndFingerprint == bucket.distAndFingerprint then
(foundKey, value) = listGetUnsafe data (Num.toNat bucket.dataIndex)
(foundKey, value) = listGetUnsafe data (Num.toU64 bucket.dataIndex)
if foundKey == key then
{ bucketIndex, result: Ok value }
else
@ -775,11 +775,11 @@ findSecondUnroll = \buckets, bucketIndex, distAndFingerprint, data, key ->
else
findHelper buckets (nextBucketIndex bucketIndex (List.len buckets)) (incrementDist distAndFingerprint) data key
findHelper : List Bucket, Nat, U32, List (k, v), k -> { bucketIndex : Nat, result : Result v [KeyNotFound] } where k implements Eq
findHelper : List Bucket, U64, U32, List (k, v), k -> { bucketIndex : U64, result : Result v [KeyNotFound] } where k implements Eq
findHelper = \buckets, bucketIndex, distAndFingerprint, data, key ->
bucket = listGetUnsafe buckets bucketIndex
if distAndFingerprint == bucket.distAndFingerprint then
(foundKey, value) = listGetUnsafe data (Num.toNat bucket.dataIndex)
(foundKey, value) = listGetUnsafe data (Num.toU64 bucket.dataIndex)
if foundKey == key then
{ bucketIndex, result: Ok value }
else
@ -789,18 +789,19 @@ findHelper = \buckets, bucketIndex, distAndFingerprint, data, key ->
else
findHelper buckets (nextBucketIndex bucketIndex (List.len buckets)) (incrementDist distAndFingerprint) data key
removeBucket : Dict k v, Nat -> Dict k v
removeBucket : Dict k v, U64 -> Dict k v
removeBucket = \@Dict { buckets: buckets0, data: data0, maxBucketCapacity, maxLoadFactor, shifts }, bucketIndex0 ->
{ dataIndex: dataIndexToRemove } = listGetUnsafe buckets0 bucketIndex0
dataIndexToRemove = (listGetUnsafe buckets0 bucketIndex0).dataIndex
dataIndexToRemoveU64 = Num.toU64 dataIndexToRemove
(buckets1, bucketIndex1) = removeBucketHelper buckets0 bucketIndex0
buckets2 = List.set buckets1 bucketIndex1 emptyBucket
lastDataIndex = List.len data0 - 1
if (Num.toNat dataIndexToRemove) != lastDataIndex then
lastDataIndex = List.len data0 |> Num.subWrap 1
if dataIndexToRemoveU64 != lastDataIndex then
# Swap removed item to the end
data1 = List.swap data0 (Num.toNat dataIndexToRemove) lastDataIndex
(key, _) = listGetUnsafe data1 (Num.toNat dataIndexToRemove)
data1 = List.swap data0 dataIndexToRemoveU64 lastDataIndex
(key, _) = listGetUnsafe data1 dataIndexToRemoveU64
# Update the data index of the new value.
hash = hashKey key
@ -824,7 +825,7 @@ removeBucket = \@Dict { buckets: buckets0, data: data0, maxBucketCapacity, maxLo
shifts,
}
scanForIndex : List Bucket, Nat, U32 -> Nat
scanForIndex : List Bucket, U64, U32 -> U64
scanForIndex = \buckets, bucketIndex, dataIndex ->
bucket = listGetUnsafe buckets bucketIndex
if bucket.dataIndex != dataIndex then
@ -832,7 +833,7 @@ scanForIndex = \buckets, bucketIndex, dataIndex ->
else
bucketIndex
removeBucketHelper : List Bucket, Nat -> (List Bucket, Nat)
removeBucketHelper : List Bucket, U64 -> (List Bucket, U64)
removeBucketHelper = \buckets, bucketIndex ->
nextIndex = nextBucketIndex bucketIndex (List.len buckets)
nextBucket = listGetUnsafe buckets nextIndex
@ -846,7 +847,7 @@ removeBucketHelper = \buckets, bucketIndex ->
increaseSize : Dict k v -> Dict k v
increaseSize = \@Dict { data, maxBucketCapacity, maxLoadFactor, shifts } ->
if maxBucketCapacity != maxBucketCount then
newShifts = shifts - 1
newShifts = shifts |> Num.subWrap 1
(buckets0, newMaxBucketCapacity) = allocBucketsFromShift newShifts maxLoadFactor
buckets1 = fillBucketsFromData buckets0 data newShifts
@Dict {
@ -864,14 +865,14 @@ allocBucketsFromShift = \shifts, maxLoadFactor ->
bucketCount = calcNumBuckets shifts
if bucketCount == maxBucketCount then
# reached the maximum, make sure we can use each bucket
(List.repeat emptyBucket (Num.toNat maxBucketCount), maxBucketCount)
(List.repeat emptyBucket maxBucketCount, maxBucketCount)
else
maxBucketCapacity =
bucketCount
|> Num.toF32
|> Num.mul maxLoadFactor
|> Num.floor
(List.repeat emptyBucket (Num.toNat bucketCount), maxBucketCapacity)
(List.repeat emptyBucket bucketCount, maxBucketCapacity)
calcShiftsForSize : U64, F32 -> U8
calcShiftsForSize = \size, maxLoadFactor ->
@ -885,13 +886,13 @@ calcShiftsForSizeHelper = \shifts, size, maxLoadFactor ->
|> Num.mul maxLoadFactor
|> Num.floor
if shifts > 0 && maxBucketCapacity < size then
calcShiftsForSizeHelper (shifts - 1) size maxLoadFactor
calcShiftsForSizeHelper (shifts |> Num.subWrap 1) size maxLoadFactor
else
shifts
calcNumBuckets = \shifts ->
Num.min
(Num.shiftLeftBy 1 (64 - shifts))
(Num.shiftLeftBy 1 (64 |> Num.subWrap shifts))
maxBucketCount
fillBucketsFromData = \buckets0, data, shifts ->
@ -899,7 +900,7 @@ fillBucketsFromData = \buckets0, data, shifts ->
(bucketIndex, distAndFingerprint) = nextWhileLess buckets1 key shifts
placeAndShiftUp buckets1 { distAndFingerprint, dataIndex: Num.toU32 dataIndex } bucketIndex
nextWhileLess : List Bucket, k, U8 -> (Nat, U32) where k implements Hash & Eq
nextWhileLess : List Bucket, k, U8 -> (U64, U32) where k implements Hash & Eq
nextWhileLess = \buckets, key, shifts ->
hash = hashKey key
distAndFingerprint = distAndFingerprintFromHash hash
@ -908,22 +909,22 @@ nextWhileLess = \buckets, key, shifts ->
nextWhileLessHelper buckets bucketIndex distAndFingerprint
nextWhileLessHelper = \buckets, bucketIndex, distAndFingerprint ->
loaded = listGetUnsafe buckets (Num.toNat bucketIndex)
loaded = listGetUnsafe buckets bucketIndex
if distAndFingerprint < loaded.distAndFingerprint then
nextWhileLessHelper buckets (nextBucketIndex bucketIndex (List.len buckets)) (incrementDist distAndFingerprint)
else
(bucketIndex, distAndFingerprint)
placeAndShiftUp = \buckets0, bucket, bucketIndex ->
loaded = listGetUnsafe buckets0 (Num.toNat bucketIndex)
loaded = listGetUnsafe buckets0 bucketIndex
if loaded.distAndFingerprint != 0 then
buckets1 = List.set buckets0 (Num.toNat bucketIndex) bucket
buckets1 = List.set buckets0 bucketIndex bucket
placeAndShiftUp
buckets1
{ loaded & distAndFingerprint: incrementDist loaded.distAndFingerprint }
(nextBucketIndex bucketIndex (List.len buckets1))
else
List.set buckets0 (Num.toNat bucketIndex) bucket
List.set buckets0 bucketIndex bucket
nextBucketIndex = \bucketIndex, maxBuckets ->
# I just ported this impl directly.
@ -947,11 +948,10 @@ distAndFingerprintFromHash = \hash ->
|> Num.bitwiseAnd fingerprintMask
|> Num.bitwiseOr distInc
bucketIndexFromHash : U64, U8 -> Nat
bucketIndexFromHash : U64, U8 -> U64
bucketIndexFromHash = \hash, shifts ->
hash
|> Num.shiftRightZfBy shifts
|> Num.toNat
expect
val =
@ -1185,12 +1185,6 @@ expect
|> len
|> Bool.isEq 0
# Makes sure a Dict with Nat keys works
expect
empty {}
|> insert 7nat "Testing"
|> get 7
|> Bool.isEq (Ok "Testing")
# All BadKey's hash to the same location.
# This is needed to test some robinhood logic.
@ -1225,7 +1219,7 @@ expect
acc, k <- List.walk badKeys (Dict.empty {})
Dict.update acc k \val ->
when val is
Present p -> Present (p + 1)
Present p -> Present (p |> Num.addWrap 1)
Missing -> Present 0
allInsertedCorrectly =
@ -1236,7 +1230,7 @@ expect
# Note, there are a number of places we should probably use set and replace unsafe.
# unsafe primitive that does not perform a bounds check
listGetUnsafe : List a, Nat -> a
listGetUnsafe : List a, U64 -> a
# We have decided not to expose the standard roc hashing algorithm.
# This is to avoid external dependence and the need for versioning.
@ -1368,9 +1362,9 @@ addBytes = \@LowLevelHasher { initializedSeed, state }, list ->
else
hashBytesHelper48 initializedSeed initializedSeed initializedSeed list 0 length
combineState (@LowLevelHasher { initializedSeed, state }) { a: abs.a, b: abs.b, seed: abs.seed, length: Num.toU64 length }
combineState (@LowLevelHasher { initializedSeed, state }) { a: abs.a, b: abs.b, seed: abs.seed, length }
hashBytesHelper48 : U64, U64, U64, List U8, Nat, Nat -> { a : U64, b : U64, seed : U64 }
hashBytesHelper48 : U64, U64, U64, List U8, U64, U64 -> { a : U64, b : U64, seed : U64 }
hashBytesHelper48 = \seed, see1, see2, list, index, remaining ->
newSeed = wymix (Num.bitwiseXor (wyr8 list index) wyp1) (Num.bitwiseXor (wyr8 list (Num.addWrap index 8)) seed)
newSee1 = wymix (Num.bitwiseXor (wyr8 list (Num.addWrap index 16)) wyp2) (Num.bitwiseXor (wyr8 list (Num.addWrap index 24)) see1)
@ -1389,7 +1383,7 @@ hashBytesHelper48 = \seed, see1, see2, list, index, remaining ->
{ a: wyr8 list (Num.subWrap newRemaining 16 |> Num.addWrap newIndex), b: wyr8 list (Num.subWrap newRemaining 8 |> Num.addWrap newIndex), seed: finalSeed }
hashBytesHelper16 : U64, List U8, Nat, Nat -> { a : U64, b : U64, seed : U64 }
hashBytesHelper16 : U64, List U8, U64, U64 -> { a : U64, b : U64, seed : U64 }
hashBytesHelper16 = \seed, list, index, remaining ->
newSeed = wymix (Num.bitwiseXor (wyr8 list index) wyp1) (Num.bitwiseXor (wyr8 list (Num.addWrap index 8)) seed)
newRemaining = Num.subWrap remaining 16
@ -1426,7 +1420,7 @@ wymum = \a, b ->
{ lower, upper }
# Get the next 8 bytes as a U64
wyr8 : List U8, Nat -> U64
wyr8 : List U8, U64 -> U64
wyr8 = \list, index ->
# With seamless slices and Num.fromBytes, this should be possible to make faster and nicer.
# It would also deal with the fact that on big endian systems we want to invert the order here.
@ -1447,7 +1441,7 @@ wyr8 = \list, index ->
Num.bitwiseOr (Num.bitwiseOr a b) (Num.bitwiseOr c d)
# Get the next 4 bytes as a U64 with some shifting.
wyr4 : List U8, Nat -> U64
wyr4 : List U8, U64 -> U64
wyr4 = \list, index ->
p1 = listGetUnsafe list index |> Num.toU64
p2 = listGetUnsafe list (Num.addWrap index 1) |> Num.toU64
@ -1460,7 +1454,7 @@ wyr4 = \list, index ->
# Get the next K bytes with some shifting.
# K must be 3 or less.
wyr3 : List U8, Nat, Nat -> U64
wyr3 : List U8, U64, U64 -> U64
wyr3 = \list, index, k ->
# ((uint64_t)p[0])<<16)|(((uint64_t)p[k>>1])<<8)|p[k-1]
p1 = listGetUnsafe list index |> Num.toU64