bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)

* Inlined code from variance functions

* Added helper functions for the float square root of a fraction

* Call helper functions

* Add blurb

* Fix over-specified test

* Add a test for the _sqrt_frac() helper function

* Increase the tested range

* Add type hints to the internal function.

* Fix test for correct rounding

* Simplify ⌊√(n/m)⌋ calculation

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>

* Add comment and beef-up tests

* Test for zero denominator

* Add algorithmic references

* Add test for the _isqrt_frac_rto() helper function.

* Compute the 109 instead of hard-wiring it

* Stronger test for _isqrt_frac_rto()

* Bigger range

* Bigger range

* Replace float() call with int/int division to be parallel with the other code path.

* Factor out division. Update proof link. Remove internal type declaration

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
This commit is contained in:
Raymond Hettinger 2021-11-26 22:54:50 -07:00 committed by GitHub
parent db55f3faba
commit af9ee57b96
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 16 deletions

View file

@ -9,13 +9,14 @@ import collections.abc
import copy
import decimal
import doctest
import itertools
import math
import pickle
import random
import sys
import unittest
from test import support
from test.support import import_helper
from test.support import import_helper, requires_IEEE_754
from decimal import Decimal
from fractions import Fraction
@ -2161,6 +2162,66 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase):
self.assertEqual(self.func(data), 2.5)
self.assertEqual(self.func(data, mu=0.5), 6.5)
class TestSqrtHelpers(unittest.TestCase):
def test_isqrt_frac_rto(self):
for n, m in itertools.product(range(100), range(1, 1000)):
r = statistics._isqrt_frac_rto(n, m)
self.assertIsInstance(r, int)
if r*r*m == n:
# Root is exact
continue
# Inexact, so the root should be odd
self.assertEqual(r&1, 1)
# Verify correct rounding
self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
@requires_IEEE_754
def test_sqrt_frac(self):
def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
if not x:
return root == 0.0
# Extract adjacent representable floats
r_up: float = math.nextafter(root, math.inf)
r_down: float = math.nextafter(root, -math.inf)
assert r_down < root < r_up
# Convert to fractions for exact arithmetic
frac_root: Fraction = Fraction(root)
half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2
# Check a closed interval.
# Does not test for a midpoint rounding rule.
return half_way_down ** 2 <= x <= half_way_up ** 2
randrange = random.randrange
for i in range(60_000):
numerator: int = randrange(10 ** randrange(50))
denonimator: int = randrange(10 ** randrange(50)) + 1
with self.subTest(numerator=numerator, denonimator=denonimator):
x: Fraction = Fraction(numerator, denonimator)
root: float = statistics._sqrt_frac(numerator, denonimator)
self.assertTrue(is_root_correctly_rounded(x, root))
# Verify that corner cases and error handling match math.sqrt()
self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
with self.assertRaises(ValueError):
statistics._sqrt_frac(-1, 1)
with self.assertRaises(ValueError):
statistics._sqrt_frac(1, -1)
# Error handling for zero denominator matches that for Fraction(1, 0)
with self.assertRaises(ZeroDivisionError):
statistics._sqrt_frac(1, 0)
# The result is well defined if both inputs are negative
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
class TestStdev(VarianceStdevMixin, NumericTestCase):
# Tests for sample standard deviation.
def setUp(self):
@ -2175,7 +2236,7 @@ class TestStdev(VarianceStdevMixin, NumericTestCase):
# Test that stdev is, in fact, the square root of variance.
data = [random.uniform(-2, 9) for _ in range(1000)]
expected = math.sqrt(statistics.variance(data))
self.assertEqual(self.func(data), expected)
self.assertAlmostEqual(self.func(data), expected)
def test_center_not_at_mean(self):
data = (1.0, 2.0)