mirror of
https://github.com/python/cpython.git
synced 2025-09-26 10:19:53 +00:00
Issue #18844: Add random.weighted_choices()
This commit is contained in:
parent
63d98bcd4c
commit
e8f1e002c6
4 changed files with 118 additions and 1 deletions
|
@ -124,6 +124,27 @@ Functions for sequences:
|
||||||
Return a random element from the non-empty sequence *seq*. If *seq* is empty,
|
Return a random element from the non-empty sequence *seq*. If *seq* is empty,
|
||||||
raises :exc:`IndexError`.
|
raises :exc:`IndexError`.
|
||||||
|
|
||||||
|
.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None)
|
||||||
|
|
||||||
|
Return a *k* sized list of elements chosen from the *population* with replacement.
|
||||||
|
If the *population* is empty, raises :exc:`IndexError`.
|
||||||
|
|
||||||
|
If a *weights* sequence is specified, selections are made according to the
|
||||||
|
relative weights. Alternatively, if a *cum_weights* sequence is given, the
|
||||||
|
selections are made according to the cumulative weights. For example, the
|
||||||
|
relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative
|
||||||
|
weights ``[10, 15, 45, 50]``. Internally, the relative weights are
|
||||||
|
converted to cumulative weights before making selections, so supplying the
|
||||||
|
cumulative weights saves work.
|
||||||
|
|
||||||
|
If neither *weights* nor *cum_weights* are specified, selections are made
|
||||||
|
with equal probability. If a weights sequence is supplied, it must be
|
||||||
|
the same length as the *population* sequence. It is a :exc:`TypeError`
|
||||||
|
to specify both *weights* and *cum_weights*.
|
||||||
|
|
||||||
|
The *weights* or *cum_weights* can use any numeric type that interoperates
|
||||||
|
with the :class:`float` values returned by :func:`random` (that includes
|
||||||
|
integers, floats, and fractions but excludes decimals).
|
||||||
|
|
||||||
.. function:: shuffle(x[, random])
|
.. function:: shuffle(x[, random])
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
---------
|
---------
|
||||||
pick random element
|
pick random element
|
||||||
pick random sample
|
pick random sample
|
||||||
|
pick weighted random sample
|
||||||
generate random permutation
|
generate random permutation
|
||||||
|
|
||||||
distributions on the real line:
|
distributions on the real line:
|
||||||
|
@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
|
||||||
from os import urandom as _urandom
|
from os import urandom as _urandom
|
||||||
from _collections_abc import Set as _Set, Sequence as _Sequence
|
from _collections_abc import Set as _Set, Sequence as _Sequence
|
||||||
from hashlib import sha512 as _sha512
|
from hashlib import sha512 as _sha512
|
||||||
|
import itertools as _itertools
|
||||||
|
import bisect as _bisect
|
||||||
|
|
||||||
__all__ = ["Random","seed","random","uniform","randint","choice","sample",
|
__all__ = ["Random","seed","random","uniform","randint","choice","sample",
|
||||||
"randrange","shuffle","normalvariate","lognormvariate",
|
"randrange","shuffle","normalvariate","lognormvariate",
|
||||||
"expovariate","vonmisesvariate","gammavariate","triangular",
|
"expovariate","vonmisesvariate","gammavariate","triangular",
|
||||||
"gauss","betavariate","paretovariate","weibullvariate",
|
"gauss","betavariate","paretovariate","weibullvariate",
|
||||||
"getstate","setstate", "getrandbits",
|
"getstate","setstate", "getrandbits", "weighted_choices",
|
||||||
"SystemRandom"]
|
"SystemRandom"]
|
||||||
|
|
||||||
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
|
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
|
||||||
|
@ -334,6 +337,28 @@ class Random(_random.Random):
|
||||||
result[i] = population[j]
|
result[i] = population[j]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
|
||||||
|
"""Return a k sized list of population elements chosen with replacement.
|
||||||
|
|
||||||
|
If the relative weights or cumulative weights are not specified,
|
||||||
|
the selections are made with equal probability.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if cum_weights is None:
|
||||||
|
if weights is None:
|
||||||
|
choice = self.choice
|
||||||
|
return [choice(population) for i in range(k)]
|
||||||
|
else:
|
||||||
|
cum_weights = list(_itertools.accumulate(weights))
|
||||||
|
elif weights is not None:
|
||||||
|
raise TypeError('Cannot specify both weights and cumulative_weights')
|
||||||
|
if len(cum_weights) != len(population):
|
||||||
|
raise ValueError('The number of weights does not match the population')
|
||||||
|
bisect = _bisect.bisect
|
||||||
|
random = self.random
|
||||||
|
total = cum_weights[-1]
|
||||||
|
return [population[bisect(cum_weights, random() * total)] for i in range(k)]
|
||||||
|
|
||||||
## -------------------- real-valued distributions -------------------
|
## -------------------- real-valued distributions -------------------
|
||||||
|
|
||||||
## -------------------- uniform distribution -------------------
|
## -------------------- uniform distribution -------------------
|
||||||
|
@ -724,6 +749,7 @@ choice = _inst.choice
|
||||||
randrange = _inst.randrange
|
randrange = _inst.randrange
|
||||||
sample = _inst.sample
|
sample = _inst.sample
|
||||||
shuffle = _inst.shuffle
|
shuffle = _inst.shuffle
|
||||||
|
weighted_choices = _inst.weighted_choices
|
||||||
normalvariate = _inst.normalvariate
|
normalvariate = _inst.normalvariate
|
||||||
lognormvariate = _inst.lognormvariate
|
lognormvariate = _inst.lognormvariate
|
||||||
expovariate = _inst.expovariate
|
expovariate = _inst.expovariate
|
||||||
|
|
|
@ -7,6 +7,7 @@ import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from math import log, exp, pi, fsum, sin
|
from math import log, exp, pi, fsum, sin
|
||||||
from test import support
|
from test import support
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
class TestBasicOps:
|
class TestBasicOps:
|
||||||
# Superclass with tests common to all generators.
|
# Superclass with tests common to all generators.
|
||||||
|
@ -141,6 +142,73 @@ class TestBasicOps:
|
||||||
def test_sample_on_dicts(self):
|
def test_sample_on_dicts(self):
|
||||||
self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
|
self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
|
||||||
|
|
||||||
|
def test_weighted_choices(self):
|
||||||
|
weighted_choices = self.gen.weighted_choices
|
||||||
|
data = ['red', 'green', 'blue', 'yellow']
|
||||||
|
str_data = 'abcd'
|
||||||
|
range_data = range(4)
|
||||||
|
set_data = set(range(4))
|
||||||
|
|
||||||
|
# basic functionality
|
||||||
|
for sample in [
|
||||||
|
weighted_choices(5, data),
|
||||||
|
weighted_choices(5, data, range(4)),
|
||||||
|
weighted_choices(k=5, population=data, weights=range(4)),
|
||||||
|
weighted_choices(k=5, population=data, cum_weights=range(4)),
|
||||||
|
]:
|
||||||
|
self.assertEqual(len(sample), 5)
|
||||||
|
self.assertEqual(type(sample), list)
|
||||||
|
self.assertTrue(set(sample) <= set(data))
|
||||||
|
|
||||||
|
# test argument handling
|
||||||
|
with self.assertRaises(TypeError): # missing arguments
|
||||||
|
weighted_choices(2)
|
||||||
|
|
||||||
|
self.assertEqual(weighted_choices(0, data), []) # k == 0
|
||||||
|
self.assertEqual(weighted_choices(-1, data), []) # negative k behaves like ``[0] * -1``
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(2.5, data) # k is a float
|
||||||
|
|
||||||
|
self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data)) # population is a string sequence
|
||||||
|
self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data)) # population is a range
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(2.5, set_data) # population is not a sequence
|
||||||
|
|
||||||
|
self.assertTrue(set(weighted_choices(5, data, None)) <= set(data)) # weights is None
|
||||||
|
self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
weighted_choices(5, data, [1,2]) # len(weights) != len(population)
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
weighted_choices(5, data, [0]*4) # weights sum to zero
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(5, data, 10) # non-iterable weights
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(5, data, [None]*4) # non-numeric weights
|
||||||
|
for weights in [
|
||||||
|
[15, 10, 25, 30], # integer weights
|
||||||
|
[15.1, 10.2, 25.2, 30.3], # float weights
|
||||||
|
[Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights
|
||||||
|
[True, False, True, False] # booleans (include / exclude)
|
||||||
|
]:
|
||||||
|
self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
weighted_choices(5, data, cum_weights=[1,2]) # len(weights) != len(population)
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
weighted_choices(5, data, cum_weights=[0]*4) # cum_weights sum to zero
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(5, data, cum_weights=10) # non-iterable cum_weights
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(5, data, cum_weights=[None]*4) # non-numeric cum_weights
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
weighted_choices(5, data, range(4), cum_weights=range(4)) # both weights and cum_weights
|
||||||
|
for weights in [
|
||||||
|
[15, 10, 25, 30], # integer cum_weights
|
||||||
|
[15.1, 10.2, 25.2, 30.3], # float cum_weights
|
||||||
|
[Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights
|
||||||
|
]:
|
||||||
|
self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data))
|
||||||
|
|
||||||
def test_gauss(self):
|
def test_gauss(self):
|
||||||
# Ensure that the seed() method initializes all the hidden state. In
|
# Ensure that the seed() method initializes all the hidden state. In
|
||||||
# particular, through 2.2.1 it failed to reset a piece of state used
|
# particular, through 2.2.1 it failed to reset a piece of state used
|
||||||
|
|
|
@ -101,6 +101,8 @@ Library
|
||||||
- Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
|
- Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
|
||||||
fields in X.509 certs.
|
fields in X.509 certs.
|
||||||
|
|
||||||
|
- Issue #18844: Add random.weighted_choices().
|
||||||
|
|
||||||
- Issue #25761: Improved error reporting about truncated pickle data in
|
- Issue #25761: Improved error reporting about truncated pickle data in
|
||||||
C implementation of unpickler. UnpicklingError is now raised instead of
|
C implementation of unpickler. UnpicklingError is now raised instead of
|
||||||
AttributeError and ValueError in some cases.
|
AttributeError and ValueError in some cases.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue