mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)
This commit is contained in:
parent
2effef7453
commit
81a5fc38e8
4 changed files with 116 additions and 13 deletions
|
@ -9,7 +9,7 @@ from functools import partial
|
|||
from math import log, exp, pi, fsum, sin, factorial
|
||||
from test import support
|
||||
from fractions import Fraction
|
||||
|
||||
from collections import Counter
|
||||
|
||||
class TestBasicOps:
|
||||
# Superclass with tests common to all generators.
|
||||
|
@ -161,6 +161,77 @@ class TestBasicOps:
|
|||
population = {10, 20, 30, 40, 50, 60, 70}
|
||||
self.gen.sample(population, k=5)
|
||||
|
||||
def test_sample_with_counts(self):
|
||||
sample = self.gen.sample
|
||||
|
||||
# General case
|
||||
colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber']
|
||||
counts = [500, 200, 20, 10, 5, 0, 1 ]
|
||||
k = 700
|
||||
summary = Counter(sample(colors, counts=counts, k=k))
|
||||
self.assertEqual(sum(summary.values()), k)
|
||||
for color, weight in zip(colors, counts):
|
||||
self.assertLessEqual(summary[color], weight)
|
||||
self.assertNotIn('brown', summary)
|
||||
|
||||
# Case that exhausts the population
|
||||
k = sum(counts)
|
||||
summary = Counter(sample(colors, counts=counts, k=k))
|
||||
self.assertEqual(sum(summary.values()), k)
|
||||
for color, weight in zip(colors, counts):
|
||||
self.assertLessEqual(summary[color], weight)
|
||||
self.assertNotIn('brown', summary)
|
||||
|
||||
# Case with population size of 1
|
||||
summary = Counter(sample(['x'], counts=[10], k=8))
|
||||
self.assertEqual(summary, Counter(x=8))
|
||||
|
||||
# Case with all counts equal.
|
||||
nc = len(colors)
|
||||
summary = Counter(sample(colors, counts=[10]*nc, k=10*nc))
|
||||
self.assertEqual(summary, Counter(10*colors))
|
||||
|
||||
# Test error handling
|
||||
with self.assertRaises(TypeError):
|
||||
sample(['red', 'green', 'blue'], counts=10, k=10) # counts not iterable
|
||||
with self.assertRaises(ValueError):
|
||||
sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2) # counts are negative
|
||||
with self.assertRaises(ValueError):
|
||||
sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2) # counts are zero
|
||||
with self.assertRaises(ValueError):
|
||||
sample(['red', 'green'], counts=[10, 10], k=21) # population too small
|
||||
with self.assertRaises(ValueError):
|
||||
sample(['red', 'green', 'blue'], counts=[1, 2], k=2) # too few counts
|
||||
with self.assertRaises(ValueError):
|
||||
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
|
||||
|
||||
def test_sample_counts_equivalence(self):
|
||||
# Test the documented strong equivalence to a sample with repeated elements.
|
||||
# We run this test on random.Random() which makes deterministic selections
|
||||
# for a given seed value.
|
||||
sample = random.sample
|
||||
seed = random.seed
|
||||
|
||||
colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
|
||||
counts = [500, 200, 20, 10, 5, 1 ]
|
||||
k = 700
|
||||
seed(8675309)
|
||||
s1 = sample(colors, counts=counts, k=k)
|
||||
seed(8675309)
|
||||
expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
|
||||
self.assertEqual(len(expanded), sum(counts))
|
||||
s2 = sample(expanded, k=k)
|
||||
self.assertEqual(s1, s2)
|
||||
|
||||
pop = 'abcdefghi'
|
||||
counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
|
||||
seed(8675309)
|
||||
s1 = ''.join(sample(pop, counts=counts, k=30))
|
||||
expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
|
||||
seed(8675309)
|
||||
s2 = ''.join(sample(expanded, k=30))
|
||||
self.assertEqual(s1, s2)
|
||||
|
||||
def test_choices(self):
|
||||
choices = self.gen.choices
|
||||
data = ['red', 'green', 'blue', 'yellow']
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue