mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)
This commit is contained in:
parent
2effef7453
commit
81a5fc38e8
4 changed files with 116 additions and 13 deletions
|
@ -217,7 +217,7 @@ Functions for sequences
|
||||||
The optional parameter *random*.
|
The optional parameter *random*.
|
||||||
|
|
||||||
|
|
||||||
.. function:: sample(population, k)
|
.. function:: sample(population, k, *, counts=None)
|
||||||
|
|
||||||
Return a *k* length list of unique elements chosen from the population sequence
|
Return a *k* length list of unique elements chosen from the population sequence
|
||||||
or set. Used for random sampling without replacement.
|
or set. Used for random sampling without replacement.
|
||||||
|
@ -231,6 +231,11 @@ Functions for sequences
|
||||||
Members of the population need not be :term:`hashable` or unique. If the population
|
Members of the population need not be :term:`hashable` or unique. If the population
|
||||||
contains repeats, then each occurrence is a possible selection in the sample.
|
contains repeats, then each occurrence is a possible selection in the sample.
|
||||||
|
|
||||||
|
Repeated elements can be specified one at a time or with the optional
|
||||||
|
keyword-only *counts* parameter. For example, ``sample(['red', 'blue'],
|
||||||
|
counts=[4, 2], k=5)`` is equivalent to ``sample(['red', 'red', 'red', 'red',
|
||||||
|
'blue', 'blue'], k=5)``.
|
||||||
|
|
||||||
To choose a sample from a range of integers, use a :func:`range` object as an
|
To choose a sample from a range of integers, use a :func:`range` object as an
|
||||||
argument. This is especially fast and space efficient for sampling from a large
|
argument. This is especially fast and space efficient for sampling from a large
|
||||||
population: ``sample(range(10000000), k=60)``.
|
population: ``sample(range(10000000), k=60)``.
|
||||||
|
@ -238,6 +243,9 @@ Functions for sequences
|
||||||
If the sample size is larger than the population size, a :exc:`ValueError`
|
If the sample size is larger than the population size, a :exc:`ValueError`
|
||||||
is raised.
|
is raised.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.9
|
||||||
|
Added the *counts* parameter.
|
||||||
|
|
||||||
.. deprecated:: 3.9
|
.. deprecated:: 3.9
|
||||||
In the future, the *population* must be a sequence. Instances of
|
In the future, the *population* must be a sequence. Instances of
|
||||||
:class:`set` are no longer supported. The set must first be converted
|
:class:`set` are no longer supported. The set must first be converted
|
||||||
|
@ -420,12 +428,11 @@ Simulations::
|
||||||
>>> choices(['red', 'black', 'green'], [18, 18, 2], k=6)
|
>>> choices(['red', 'black', 'green'], [18, 18, 2], k=6)
|
||||||
['red', 'green', 'black', 'black', 'red', 'black']
|
['red', 'green', 'black', 'black', 'red', 'black']
|
||||||
|
|
||||||
>>> # Deal 20 cards without replacement from a deck of 52 playing cards
|
>>> # Deal 20 cards without replacement from a deck
|
||||||
>>> # and determine the proportion of cards with a ten-value
|
>>> # of 52 playing cards, and determine the proportion of cards
|
||||||
>>> # (a ten, jack, queen, or king).
|
>>> # with a ten-value: ten, jack, queen, or king.
|
||||||
>>> deck = collections.Counter(tens=16, low_cards=36)
|
>>> dealt = sample(['tens', 'low cards'], counts=[16, 36], k=20)
|
||||||
>>> seen = sample(list(deck.elements()), k=20)
|
>>> dealt.count('tens') / 20
|
||||||
>>> seen.count('tens') / 20
|
|
||||||
0.15
|
0.15
|
||||||
|
|
||||||
>>> # Estimate the probability of getting 5 or more heads from 7 spins
|
>>> # Estimate the probability of getting 5 or more heads from 7 spins
|
||||||
|
|
|
@ -331,7 +331,7 @@ class Random(_random.Random):
|
||||||
j = _int(random() * (i+1))
|
j = _int(random() * (i+1))
|
||||||
x[i], x[j] = x[j], x[i]
|
x[i], x[j] = x[j], x[i]
|
||||||
|
|
||||||
def sample(self, population, k):
|
def sample(self, population, k, *, counts=None):
|
||||||
"""Chooses k unique random elements from a population sequence or set.
|
"""Chooses k unique random elements from a population sequence or set.
|
||||||
|
|
||||||
Returns a new list containing elements from the population while
|
Returns a new list containing elements from the population while
|
||||||
|
@ -344,9 +344,21 @@ class Random(_random.Random):
|
||||||
population contains repeats, then each occurrence is a possible
|
population contains repeats, then each occurrence is a possible
|
||||||
selection in the sample.
|
selection in the sample.
|
||||||
|
|
||||||
To choose a sample in a range of integers, use range as an argument.
|
Repeated elements can be specified one at a time or with the optional
|
||||||
This is especially fast and space efficient for sampling from a
|
counts parameter. For example:
|
||||||
large population: sample(range(10000000), 60)
|
|
||||||
|
sample(['red', 'blue'], counts=[4, 2], k=5)
|
||||||
|
|
||||||
|
is equivalent to:
|
||||||
|
|
||||||
|
sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
|
||||||
|
|
||||||
|
To choose a sample from a range of integers, use range() for the
|
||||||
|
population argument. This is especially fast and space efficient
|
||||||
|
for sampling from a large population:
|
||||||
|
|
||||||
|
sample(range(10000000), 60)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sampling without replacement entails tracking either potential
|
# Sampling without replacement entails tracking either potential
|
||||||
|
@ -379,8 +391,20 @@ class Random(_random.Random):
|
||||||
population = tuple(population)
|
population = tuple(population)
|
||||||
if not isinstance(population, _Sequence):
|
if not isinstance(population, _Sequence):
|
||||||
raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).")
|
raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).")
|
||||||
randbelow = self._randbelow
|
|
||||||
n = len(population)
|
n = len(population)
|
||||||
|
if counts is not None:
|
||||||
|
cum_counts = list(_accumulate(counts))
|
||||||
|
if len(cum_counts) != n:
|
||||||
|
raise ValueError('The number of counts does not match the population')
|
||||||
|
total = cum_counts.pop()
|
||||||
|
if not isinstance(total, int):
|
||||||
|
raise TypeError('Counts must be integers')
|
||||||
|
if total <= 0:
|
||||||
|
raise ValueError('Total of counts must be greater than zero')
|
||||||
|
selections = sample(range(total), k=k)
|
||||||
|
bisect = _bisect
|
||||||
|
return [population[bisect(cum_counts, s)] for s in selections]
|
||||||
|
randbelow = self._randbelow
|
||||||
if not 0 <= k <= n:
|
if not 0 <= k <= n:
|
||||||
raise ValueError("Sample larger than population or is negative")
|
raise ValueError("Sample larger than population or is negative")
|
||||||
result = [None] * k
|
result = [None] * k
|
||||||
|
|
|
@ -9,7 +9,7 @@ from functools import partial
|
||||||
from math import log, exp, pi, fsum, sin, factorial
|
from math import log, exp, pi, fsum, sin, factorial
|
||||||
from test import support
|
from test import support
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
class TestBasicOps:
|
class TestBasicOps:
|
||||||
# Superclass with tests common to all generators.
|
# Superclass with tests common to all generators.
|
||||||
|
@ -161,6 +161,77 @@ class TestBasicOps:
|
||||||
population = {10, 20, 30, 40, 50, 60, 70}
|
population = {10, 20, 30, 40, 50, 60, 70}
|
||||||
self.gen.sample(population, k=5)
|
self.gen.sample(population, k=5)
|
||||||
|
|
||||||
|
def test_sample_with_counts(self):
|
||||||
|
sample = self.gen.sample
|
||||||
|
|
||||||
|
# General case
|
||||||
|
colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber']
|
||||||
|
counts = [500, 200, 20, 10, 5, 0, 1 ]
|
||||||
|
k = 700
|
||||||
|
summary = Counter(sample(colors, counts=counts, k=k))
|
||||||
|
self.assertEqual(sum(summary.values()), k)
|
||||||
|
for color, weight in zip(colors, counts):
|
||||||
|
self.assertLessEqual(summary[color], weight)
|
||||||
|
self.assertNotIn('brown', summary)
|
||||||
|
|
||||||
|
# Case that exhausts the population
|
||||||
|
k = sum(counts)
|
||||||
|
summary = Counter(sample(colors, counts=counts, k=k))
|
||||||
|
self.assertEqual(sum(summary.values()), k)
|
||||||
|
for color, weight in zip(colors, counts):
|
||||||
|
self.assertLessEqual(summary[color], weight)
|
||||||
|
self.assertNotIn('brown', summary)
|
||||||
|
|
||||||
|
# Case with population size of 1
|
||||||
|
summary = Counter(sample(['x'], counts=[10], k=8))
|
||||||
|
self.assertEqual(summary, Counter(x=8))
|
||||||
|
|
||||||
|
# Case with all counts equal.
|
||||||
|
nc = len(colors)
|
||||||
|
summary = Counter(sample(colors, counts=[10]*nc, k=10*nc))
|
||||||
|
self.assertEqual(summary, Counter(10*colors))
|
||||||
|
|
||||||
|
# Test error handling
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
sample(['red', 'green', 'blue'], counts=10, k=10) # counts not iterable
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2) # counts are negative
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2) # counts are zero
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
sample(['red', 'green'], counts=[10, 10], k=21) # population too small
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
sample(['red', 'green', 'blue'], counts=[1, 2], k=2) # too few counts
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
|
||||||
|
|
||||||
|
def test_sample_counts_equivalence(self):
|
||||||
|
# Test the documented strong equivalence to a sample with repeated elements.
|
||||||
|
# We run this test on random.Random() which makes deterministic selections
|
||||||
|
# for a given seed value.
|
||||||
|
sample = random.sample
|
||||||
|
seed = random.seed
|
||||||
|
|
||||||
|
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
|
||||||
|
counts = [500, 200, 20, 10, 5, 1 ]
|
||||||
|
k = 700
|
||||||
|
seed(8675309)
|
||||||
|
s1 = sample(colors, counts=counts, k=k)
|
||||||
|
seed(8675309)
|
||||||
|
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
|
||||||
|
self.assertEqual(len(expanded), sum(counts))
|
||||||
|
s2 = sample(expanded, k=k)
|
||||||
|
self.assertEqual(s1, s2)
|
||||||
|
|
||||||
|
pop = 'abcdefghi'
|
||||||
|
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
|
||||||
|
seed(8675309)
|
||||||
|
s1 = ''.join(sample(pop, counts=counts, k=30))
|
||||||
|
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
|
||||||
|
seed(8675309)
|
||||||
|
s2 = ''.join(sample(expanded, k=30))
|
||||||
|
self.assertEqual(s1, s2)
|
||||||
|
|
||||||
def test_choices(self):
|
def test_choices(self):
|
||||||
choices = self.gen.choices
|
choices = self.gen.choices
|
||||||
data = ['red', 'green', 'blue', 'yellow']
|
data = ['red', 'green', 'blue', 'yellow']
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Added an optional *counts* parameter to random.sample().
|
Loading…
Add table
Add a link
Reference in a new issue