mirror of
https://github.com/python/cpython.git
synced 2025-08-31 22:18:28 +00:00
bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)
* Inlined code from variance functions
* Added helper functions for the float square root of a fraction
* Call helper functions
* Add blurb
* Fix over-specified test
* Add a test for the _sqrt_frac() helper function
* Increase the tested range
* Add type hints to the internal function.
* Fix test for correct rounding
* Simplify ⌊√(n/m)⌋ calculation

  Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
* Add comment and beef-up tests
* Test for zero denominator
* Add algorithmic references
* Add test for the _isqrt_frac_rto() helper function.
* Compute the 109 instead of hard-wiring it
* Stronger test for _isqrt_frac_rto()
* Bigger range
* Bigger range
* Replace float() call with int/int division to be parallel with the other code path.
* Factor out division. Update proof link. Remove internal type declaration

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
This commit is contained in:
parent
db55f3faba
commit
af9ee57b96
3 changed files with 107 additions and 16 deletions
|
@ -130,6 +130,7 @@ __all__ = [
|
||||||
import math
|
import math
|
||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'):
|
||||||
raise StatisticsError(errmsg)
|
raise StatisticsError(errmsg)
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
|
def _isqrt_frac_rto(n: int, m: int) -> float:
|
||||||
|
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
|
||||||
|
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
|
||||||
|
a = math.isqrt(n // m)
|
||||||
|
return a | (a*a*m != n)
|
||||||
|
|
||||||
|
# Bit-length offset used by _sqrt_frac() when choosing how far to scale
# the fraction before taking its integer square root.  It is derived from
# the platform's float mantissa width rather than hard-wired.
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
|
||||||
|
|
||||||
|
def _sqrt_frac(n: int, m: int) -> float:
    """Square root of n/m as a float, correctly rounded."""
    # See principle and proof sketch at: https://bugs.python.org/msg407078
    half_shift = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
    if half_shift < 0:
        # The quotient is too short: widen the numerator by an even number
        # of bits, take the round-to-odd integer root, then undo half of
        # that shift in the final division.
        widened_root = _isqrt_frac_rto(n << -2 * half_shift, m)
        return widened_root / (1 << -half_shift)  # int / int gives a float
    # The quotient is long enough: widen the denominator instead and scale
    # the integer root back up before converting.
    reduced_root = _isqrt_frac_rto(n, m << 2 * half_shift) << half_shift
    return reduced_root / 1  # int / int gives a float
|
||||||
|
|
||||||
|
|
||||||
# === Measures of central tendency (averages) ===
|
# === Measures of central tendency (averages) ===
|
||||||
|
|
||||||
|
@ -837,14 +859,17 @@ def stdev(data, xbar=None):
|
||||||
1.0810874155219827
|
1.0810874155219827
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
if iter(data) is data:
|
||||||
# remain because there are two rounding steps. The first occurs in
|
data = list(data)
|
||||||
# the _convert() step for variance(), the second occurs in math.sqrt().
|
n = len(data)
|
||||||
var = variance(data, xbar)
|
if n < 2:
|
||||||
try:
|
raise StatisticsError('stdev requires at least two data points')
|
||||||
|
T, ss = _ss(data, xbar)
|
||||||
|
mss = ss / (n - 1)
|
||||||
|
if hasattr(T, 'sqrt'):
|
||||||
|
var = _convert(mss, T)
|
||||||
return var.sqrt()
|
return var.sqrt()
|
||||||
except AttributeError:
|
return _sqrt_frac(mss.numerator, mss.denominator)
|
||||||
return math.sqrt(var)
|
|
||||||
|
|
||||||
|
|
||||||
def pstdev(data, mu=None):
|
def pstdev(data, mu=None):
|
||||||
|
@ -856,14 +881,17 @@ def pstdev(data, mu=None):
|
||||||
0.986893273527251
|
0.986893273527251
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
if iter(data) is data:
|
||||||
# remain because there are two rounding steps. The first occurs in
|
data = list(data)
|
||||||
# the _convert() step for pvariance(), the second occurs in math.sqrt().
|
n = len(data)
|
||||||
var = pvariance(data, mu)
|
if n < 1:
|
||||||
try:
|
raise StatisticsError('pstdev requires at least one data point')
|
||||||
|
T, ss = _ss(data, mu)
|
||||||
|
mss = ss / n
|
||||||
|
if hasattr(T, 'sqrt'):
|
||||||
|
var = _convert(mss, T)
|
||||||
return var.sqrt()
|
return var.sqrt()
|
||||||
except AttributeError:
|
return _sqrt_frac(mss.numerator, mss.denominator)
|
||||||
return math.sqrt(var)
|
|
||||||
|
|
||||||
|
|
||||||
# === Statistics for relations between two inputs ===
|
# === Statistics for relations between two inputs ===
|
||||||
|
|
|
@ -9,13 +9,14 @@ import collections.abc
|
||||||
import copy
|
import copy
|
||||||
import decimal
|
import decimal
|
||||||
import doctest
|
import doctest
|
||||||
|
import itertools
|
||||||
import math
|
import math
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from test import support
|
from test import support
|
||||||
from test.support import import_helper
|
from test.support import import_helper, requires_IEEE_754
|
||||||
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
|
@ -2161,6 +2162,66 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase):
|
||||||
self.assertEqual(self.func(data), 2.5)
|
self.assertEqual(self.func(data), 2.5)
|
||||||
self.assertEqual(self.func(data, mu=0.5), 6.5)
|
self.assertEqual(self.func(data, mu=0.5), 6.5)
|
||||||
|
|
||||||
|
class TestSqrtHelpers(unittest.TestCase):
    # Tests for the private square-root helpers used by stdev() and pstdev().

    def test_isqrt_frac_rto(self):
        # Exhaustively check small numerators and denominators.
        for n, m in itertools.product(range(100), range(1, 1000)):
            r = statistics._isqrt_frac_rto(n, m)
            self.assertIsInstance(r, int)
            if r*r*m == n:
                # Root is exact
                continue
            # Inexact, so the root should be odd
            self.assertEqual(r&1, 1)
            # Verify correct rounding
            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)

    @requires_IEEE_754
    def test_sqrt_frac(self):

        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
            # True when x lies between the squares of the midpoints that
            # separate root from its two neighboring floats.
            if not x:
                return root == 0.0

            # Extract adjacent representable floats
            r_up: float = math.nextafter(root, math.inf)
            r_down: float = math.nextafter(root, -math.inf)
            assert r_down < root < r_up

            # Convert to fractions for exact arithmetic
            frac_root: Fraction = Fraction(root)
            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2

            # Check a closed interval.
            # Does not test for a midpoint rounding rule.
            return half_way_down ** 2 <= x <= half_way_up ** 2

        randrange = random.randrange

        # Random fractions whose numerator and denominator each span up
        # to 50 decimal digits; the "+ 1" keeps the denominator nonzero.
        for _ in range(60_000):
            numerator: int = randrange(10 ** randrange(50))
            denominator: int = randrange(10 ** randrange(50)) + 1
            with self.subTest(numerator=numerator, denominator=denominator):
                x: Fraction = Fraction(numerator, denominator)
                root: float = statistics._sqrt_frac(numerator, denominator)
                self.assertTrue(is_root_correctly_rounded(x, root))

        # Verify that corner cases and error handling match math.sqrt()
        self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(-1, 1)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(1, -1)

        # Error handling for zero denominator matches that for Fraction(1, 0)
        with self.assertRaises(ZeroDivisionError):
            statistics._sqrt_frac(1, 0)

        # The result is well defined if both inputs are negative
        self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
|
||||||
|
|
||||||
|
|
||||||
class TestStdev(VarianceStdevMixin, NumericTestCase):
|
class TestStdev(VarianceStdevMixin, NumericTestCase):
|
||||||
# Tests for sample standard deviation.
|
# Tests for sample standard deviation.
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -2175,7 +2236,7 @@ class TestStdev(VarianceStdevMixin, NumericTestCase):
|
||||||
# Test that stdev is, in fact, the square root of variance.
|
# Test that stdev is, in fact, the square root of variance.
|
||||||
data = [random.uniform(-2, 9) for _ in range(1000)]
|
data = [random.uniform(-2, 9) for _ in range(1000)]
|
||||||
expected = math.sqrt(statistics.variance(data))
|
expected = math.sqrt(statistics.variance(data))
|
||||||
self.assertEqual(self.func(data), expected)
|
self.assertAlmostEqual(self.func(data), expected)
|
||||||
|
|
||||||
def test_center_not_at_mean(self):
|
def test_center_not_at_mean(self):
|
||||||
data = (1.0, 2.0)
|
data = (1.0, 2.0)
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Improve the accuracy of stdev() and pstdev() in the statistics module. When
the inputs are floats or fractions, the output is a correctly rounded float.
|
Loading…
Add table
Add a link
Reference in a new issue