bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)

This commit is contained in:
Raymond Hettinger 2020-05-08 07:53:15 -07:00 committed by GitHub
parent 2effef7453
commit 81a5fc38e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 13 deletions

View file

@ -331,7 +331,7 @@ class Random(_random.Random):
j = _int(random() * (i+1))
x[i], x[j] = x[j], x[i]
def sample(self, population, k):
def sample(self, population, k, *, counts=None):
"""Chooses k unique random elements from a population sequence or set.
Returns a new list containing elements from the population while
@ -344,9 +344,21 @@ class Random(_random.Random):
population contains repeats, then each occurrence is a possible
selection in the sample.
To choose a sample in a range of integers, use range as an argument.
This is especially fast and space efficient for sampling from a
large population: sample(range(10000000), 60)
Repeated elements can be specified one at a time or with the optional
counts parameter. For example:
sample(['red', 'blue'], counts=[4, 2], k=5)
is equivalent to:
sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
To choose a sample from a range of integers, use range() for the
population argument. This is especially fast and space efficient
for sampling from a large population:
sample(range(10000000), 60)
"""
# Sampling without replacement entails tracking either potential
@ -379,8 +391,20 @@ class Random(_random.Random):
population = tuple(population)
if not isinstance(population, _Sequence):
raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).")
randbelow = self._randbelow
n = len(population)
if counts is not None:
cum_counts = list(_accumulate(counts))
if len(cum_counts) != n:
raise ValueError('The number of counts does not match the population')
total = cum_counts.pop()
if not isinstance(total, int):
raise TypeError('Counts must be integers')
if total <= 0:
raise ValueError('Total of counts must be greater than zero')
selections = sample(range(total), k=k)
bisect = _bisect
return [population[bisect(cum_counts, s)] for s in selections]
randbelow = self._randbelow
if not 0 <= k <= n:
raise ValueError("Sample larger than population or is negative")
result = [None] * k