mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
Issue #18844: Add random.weighted_choices()
This commit is contained in:
parent
63d98bcd4c
commit
e8f1e002c6
4 changed files with 118 additions and 1 deletions
|
@ -8,6 +8,7 @@
|
|||
---------
|
||||
pick random element
|
||||
pick random sample
|
||||
pick weighted random sample
|
||||
generate random permutation
|
||||
|
||||
distributions on the real line:
|
||||
|
@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
|
|||
from os import urandom as _urandom
|
||||
from _collections_abc import Set as _Set, Sequence as _Sequence
|
||||
from hashlib import sha512 as _sha512
|
||||
import itertools as _itertools
|
||||
import bisect as _bisect
|
||||
|
||||
__all__ = ["Random","seed","random","uniform","randint","choice","sample",
|
||||
"randrange","shuffle","normalvariate","lognormvariate",
|
||||
"expovariate","vonmisesvariate","gammavariate","triangular",
|
||||
"gauss","betavariate","paretovariate","weibullvariate",
|
||||
"getstate","setstate", "getrandbits",
|
||||
"getstate","setstate", "getrandbits", "weighted_choices",
|
||||
"SystemRandom"]
|
||||
|
||||
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
|
||||
|
@ -334,6 +337,28 @@ class Random(_random.Random):
|
|||
result[i] = population[j]
|
||||
return result
|
||||
|
||||
def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
|
||||
"""Return a k sized list of population elements chosen with replacement.
|
||||
|
||||
If the relative weights or cumulative weights are not specified,
|
||||
the selections are made with equal probability.
|
||||
|
||||
"""
|
||||
if cum_weights is None:
|
||||
if weights is None:
|
||||
choice = self.choice
|
||||
return [choice(population) for i in range(k)]
|
||||
else:
|
||||
cum_weights = list(_itertools.accumulate(weights))
|
||||
elif weights is not None:
|
||||
raise TypeError('Cannot specify both weights and cumulative_weights')
|
||||
if len(cum_weights) != len(population):
|
||||
raise ValueError('The number of weights does not match the population')
|
||||
bisect = _bisect.bisect
|
||||
random = self.random
|
||||
total = cum_weights[-1]
|
||||
return [population[bisect(cum_weights, random() * total)] for i in range(k)]
|
||||
|
||||
## -------------------- real-valued distributions -------------------
|
||||
|
||||
## -------------------- uniform distribution -------------------
|
||||
|
@ -724,6 +749,7 @@ choice = _inst.choice
|
|||
randrange = _inst.randrange
|
||||
sample = _inst.sample
|
||||
shuffle = _inst.shuffle
|
||||
weighted_choices = _inst.weighted_choices
|
||||
normalvariate = _inst.normalvariate
|
||||
lognormvariate = _inst.lognormvariate
|
||||
expovariate = _inst.expovariate
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue