bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)

* Inlined code from variance functions

* Added helper functions for the float square root of a fraction

* Call helper functions

* Add blurb

* Fix over-specified test

* Add a test for the _sqrt_frac() helper function

* Increase the tested range

* Add type hints to the internal function.

* Fix test for correct rounding

* Simplify ⌊√(n/m)⌋ calculation

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>

* Add comment and beef up tests

* Test for zero denominator

* Add algorithmic references

* Add test for the _isqrt_frac_rto() helper function.

* Compute the 109 instead of hard-wiring it

* Stronger test for _isqrt_frac_rto()

* Bigger range

* Bigger range

* Replace float() call with int/int division to be parallel with the other code path.

* Factor out division. Update proof link. Remove internal type declaration

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
This commit is contained in:
Raymond Hettinger 2021-11-26 22:54:50 -07:00 committed by GitHub
parent db55f3faba
commit af9ee57b96
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 16 deletions

View file

@ -130,6 +130,7 @@ __all__ = [
import math import math
import numbers import numbers
import random import random
import sys
from fractions import Fraction from fractions import Fraction
from decimal import Decimal from decimal import Decimal
@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg) raise StatisticsError(errmsg)
yield x yield x
def _isqrt_frac_rto(n: int, m: int) -> float:
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
a = math.isqrt(n // m)
return a | (a*a*m != n)
# With 53 bit float mantissas this evaluates to 2 * 53 + 3 == 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3


def _sqrt_frac(n: int, m: int) -> float:
    """Return the square root of the fraction n/m as a correctly rounded float.

    The fraction is rescaled by an even power of two chosen from the
    bit lengths of n and m, an integer round-to-odd square root is taken
    with _isqrt_frac_rto(), and a single int / int division undoes the
    scaling and converts the result to a float.
    """
    # See principle and proof sketch at: https://bugs.python.org/msg407078
    half_shift = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
    if half_shift < 0:
        # Scale the numerator up before taking the integer root, then
        # divide the scaling back out.
        up = -half_shift
        root = _isqrt_frac_rto(n << 2 * up, m)
        return root / (1 << up)  # Convert to float
    # Scale the denominator up instead, then shift the root back.
    root = _isqrt_frac_rto(n, m << 2 * half_shift) << half_shift
    return root / 1  # Convert to float
# === Measures of central tendency (averages) === # === Measures of central tendency (averages) ===
@ -837,14 +859,17 @@ def stdev(data, xbar=None):
1.0810874155219827 1.0810874155219827
""" """
# Fixme: Despite the exact sum of squared deviations, some inaccuracy if iter(data) is data:
# remain because there are two rounding steps. The first occurs in data = list(data)
# the _convert() step for variance(), the second occurs in math.sqrt(). n = len(data)
var = variance(data, xbar) if n < 2:
try: raise StatisticsError('stdev requires at least two data points')
T, ss = _ss(data, xbar)
mss = ss / (n - 1)
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt() return var.sqrt()
except AttributeError: return _sqrt_frac(mss.numerator, mss.denominator)
return math.sqrt(var)
def pstdev(data, mu=None): def pstdev(data, mu=None):
@ -856,14 +881,17 @@ def pstdev(data, mu=None):
0.986893273527251 0.986893273527251
""" """
# Fixme: Despite the exact sum of squared deviations, some inaccuracy if iter(data) is data:
# remain because there are two rounding steps. The first occurs in data = list(data)
# the _convert() step for pvariance(), the second occurs in math.sqrt(). n = len(data)
var = pvariance(data, mu) if n < 1:
try: raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt() return var.sqrt()
except AttributeError: return _sqrt_frac(mss.numerator, mss.denominator)
return math.sqrt(var)
# === Statistics for relations between two inputs === # === Statistics for relations between two inputs ===

View file

@ -9,13 +9,14 @@ import collections.abc
import copy import copy
import decimal import decimal
import doctest import doctest
import itertools
import math import math
import pickle import pickle
import random import random
import sys import sys
import unittest import unittest
from test import support from test import support
from test.support import import_helper from test.support import import_helper, requires_IEEE_754
from decimal import Decimal from decimal import Decimal
from fractions import Fraction from fractions import Fraction
@ -2161,6 +2162,66 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase):
self.assertEqual(self.func(data), 2.5) self.assertEqual(self.func(data), 2.5)
self.assertEqual(self.func(data, mu=0.5), 6.5) self.assertEqual(self.func(data, mu=0.5), 6.5)
class TestSqrtHelpers(unittest.TestCase):
    """Tests for the private square-root helpers of the statistics module."""

    def test_isqrt_frac_rto(self):
        # Sweep small numerators against a range of positive denominators.
        for n, m in itertools.product(range(100), range(1, 1000)):
            r = statistics._isqrt_frac_rto(n, m)
            self.assertIsInstance(r, int)
            if r*r*m == n:
                # Root is exact
                continue
            # Inexact, so the root should be odd
            self.assertEqual(r&1, 1)
            # Verify correct rounding
            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)

    @requires_IEEE_754
    def test_sqrt_frac(self):

        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
            """Return True if *root* lies within half a float step of sqrt(x)."""
            if not x:
                return root == 0.0

            # Extract adjacent representable floats
            r_up: float = math.nextafter(root, math.inf)
            r_down: float = math.nextafter(root, -math.inf)
            assert r_down < root < r_up

            # Convert to fractions for exact arithmetic
            frac_root: Fraction = Fraction(root)
            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2

            # Check a closed interval.
            # Does not test for a midpoint rounding rule.
            return half_way_down ** 2 <= x <= half_way_up ** 2

        randrange = random.randrange

        # Random fractions with numerators and denominators of widely
        # varying magnitude; the denominator is kept strictly positive.
        for _ in range(60_000):
            numerator: int = randrange(10 ** randrange(50))
            denominator: int = randrange(10 ** randrange(50)) + 1
            with self.subTest(numerator=numerator, denominator=denominator):
                x: Fraction = Fraction(numerator, denominator)
                root: float = statistics._sqrt_frac(numerator, denominator)
                self.assertTrue(is_root_correctly_rounded(x, root))

        # Verify that corner cases and error handling match math.sqrt()
        self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(-1, 1)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(1, -1)

        # Error handling for zero denominator matches that for Fraction(1, 0)
        with self.assertRaises(ZeroDivisionError):
            statistics._sqrt_frac(1, 0)

        # The result is well defined if both inputs are negative
        self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
class TestStdev(VarianceStdevMixin, NumericTestCase): class TestStdev(VarianceStdevMixin, NumericTestCase):
# Tests for sample standard deviation. # Tests for sample standard deviation.
def setUp(self): def setUp(self):
@ -2175,7 +2236,7 @@ class TestStdev(VarianceStdevMixin, NumericTestCase):
# Test that stdev is, in fact, the square root of variance. # Test that stdev is, in fact, the square root of variance.
data = [random.uniform(-2, 9) for _ in range(1000)] data = [random.uniform(-2, 9) for _ in range(1000)]
expected = math.sqrt(statistics.variance(data)) expected = math.sqrt(statistics.variance(data))
self.assertEqual(self.func(data), expected) self.assertAlmostEqual(self.func(data), expected)
def test_center_not_at_mean(self): def test_center_not_at_mean(self):
data = (1.0, 2.0) data = (1.0, 2.0)

View file

@ -0,0 +1,2 @@
Improve the accuracy of stdev() and pstdev() in the statistics module. When
the inputs are floats or fractions, the output is a correctly rounded float.