Issue #18844: Add random.weighted_choices()

This commit is contained in:
Raymond Hettinger 2016-09-06 17:15:29 -07:00
parent 63d98bcd4c
commit e8f1e002c6
4 changed files with 118 additions and 1 deletions

View file

@ -8,6 +8,7 @@
---------
pick random element
pick random sample
pick weighted random sample
generate random permutation
distributions on the real line:
@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
from os import urandom as _urandom
from _collections_abc import Set as _Set, Sequence as _Sequence
from hashlib import sha512 as _sha512
import itertools as _itertools
import bisect as _bisect
__all__ = ["Random","seed","random","uniform","randint","choice","sample",
"randrange","shuffle","normalvariate","lognormvariate",
"expovariate","vonmisesvariate","gammavariate","triangular",
"gauss","betavariate","paretovariate","weibullvariate",
"getstate","setstate", "getrandbits",
"getstate","setstate", "getrandbits", "weighted_choices",
"SystemRandom"]
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
@ -334,6 +337,28 @@ class Random(_random.Random):
result[i] = population[j]
return result
def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
"""Return a k sized list of population elements chosen with replacement.
If the relative weights or cumulative weights are not specified,
the selections are made with equal probability.
"""
if cum_weights is None:
if weights is None:
choice = self.choice
return [choice(population) for i in range(k)]
else:
cum_weights = list(_itertools.accumulate(weights))
elif weights is not None:
raise TypeError('Cannot specify both weights and cumulative_weights')
if len(cum_weights) != len(population):
raise ValueError('The number of weights does not match the population')
bisect = _bisect.bisect
random = self.random
total = cum_weights[-1]
return [population[bisect(cum_weights, random() * total)] for i in range(k)]
## -------------------- real-valued distributions -------------------
## -------------------- uniform distribution -------------------
@ -724,6 +749,7 @@ choice = _inst.choice
randrange = _inst.randrange
sample = _inst.sample
shuffle = _inst.shuffle
weighted_choices = _inst.weighted_choices
normalvariate = _inst.normalvariate
lognormvariate = _inst.lognormvariate
expovariate = _inst.expovariate