bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)

This commit is contained in:
Raymond Hettinger 2021-11-30 18:20:08 -06:00 committed by GitHub
parent 8a45ca542a
commit a39f46afde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 22 deletions

View file

@ -137,7 +137,7 @@ from decimal import Decimal
from itertools import groupby, repeat from itertools import groupby, repeat
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import itemgetter, mul from operator import mul
from collections import Counter, namedtuple from collections import Counter, namedtuple
_SQRT2 = sqrt(2.0) _SQRT2 = sqrt(2.0)
@ -248,6 +248,28 @@ def _exact_ratio(x):
x is expected to be an int, Fraction, Decimal or float. x is expected to be an int, Fraction, Decimal or float.
""" """
# XXX We should revisit whether using fractions to accumulate exact
# ratios is the right way to go.
# The integer ratios for binary floats can have numerators or
# denominators with over 300 decimal digits. The problem is more
# acute with decimal floats where the default decimal context
# supports a huge range of exponents from Emin=-999999 to
# Emax=999999. When expanded with as_integer_ratio(), numbers like
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
# numerators or denominators that will slow computation.
# When the integer ratios are accumulated as fractions, the size
# grows to cover the full range from the smallest magnitude to the
# largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300)
# has a 616 digit numerator. Likewise,
# Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
# has a 10,003 digit numerator.
# This doesn't seem to have been a problem in practice, but it is a
# potential pitfall.
try: try:
return x.as_integer_ratio() return x.as_integer_ratio()
except AttributeError: except AttributeError:
@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg) raise StatisticsError(errmsg)
yield x yield x
def _isqrt_frac_rto(n: int, m: int) -> float:
def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
"""Square root of n/m, rounded to the nearest integer using round-to-odd.""" """Square root of n/m, rounded to the nearest integer using round-to-odd."""
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
a = math.isqrt(n // m) a = math.isqrt(n // m)
return a | (a*a*m != n) return a | (a*a*m != n)
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
def _sqrt_frac(n: int, m: int) -> float: # For 53 bit precision floats, the bit width used in
# _float_sqrt_of_frac() is 109.
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
def _float_sqrt_of_frac(n: int, m: int) -> float:
"""Square root of n/m as a float, correctly rounded.""" """Square root of n/m as a float, correctly rounded."""
# See principle and proof sketch at: https://bugs.python.org/msg407078 # See principle and proof sketch at: https://bugs.python.org/msg407078
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
if q >= 0: if q >= 0:
numerator = _isqrt_frac_rto(n, m << 2 * q) << q numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
denominator = 1 denominator = 1
else: else:
numerator = _isqrt_frac_rto(n << -2 * q, m) numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
denominator = 1 << -q denominator = 1 << -q
return numerator / denominator # Convert to float return numerator / denominator # Convert to float
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
"""Square root of n/m as a Decimal, correctly rounded."""
# Premise: For decimal, computing (n/m).sqrt() can be off
# by 1 ulp from the correctly rounded result.
# Method: Check the result, moving up or down a step if needed.
if n <= 0:
if not n:
return Decimal('0.0')
n, m = -n, -m
root = (Decimal(n) / Decimal(m)).sqrt()
nr, dr = root.as_integer_ratio()
plus = root.next_plus()
np, dp = plus.as_integer_ratio()
# test: n / m > ((root + plus) / 2) ** 2
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
return plus
minus = root.next_minus()
nm, dm = minus.as_integer_ratio()
# test: n / m < ((root + minus) / 2) ** 2
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
return minus
return root
# === Measures of central tendency (averages) === # === Measures of central tendency (averages) ===
def mean(data): def mean(data):
@ -869,7 +923,7 @@ def stdev(data, xbar=None):
if hasattr(T, 'sqrt'): if hasattr(T, 'sqrt'):
var = _convert(mss, T) var = _convert(mss, T)
return var.sqrt() return var.sqrt()
return _sqrt_frac(mss.numerator, mss.denominator) return _float_sqrt_of_frac(mss.numerator, mss.denominator)
def pstdev(data, mu=None): def pstdev(data, mu=None):
@ -888,10 +942,9 @@ def pstdev(data, mu=None):
raise StatisticsError('pstdev requires at least one data point') raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu) T, ss = _ss(data, mu)
mss = ss / n mss = ss / n
if hasattr(T, 'sqrt'): if issubclass(T, Decimal):
var = _convert(mss, T) return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
return var.sqrt() return _float_sqrt_of_frac(mss.numerator, mss.denominator)
return _sqrt_frac(mss.numerator, mss.denominator)
# === Statistics for relations between two inputs === # === Statistics for relations between two inputs ===

View file

@ -2164,9 +2164,9 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase):
class TestSqrtHelpers(unittest.TestCase): class TestSqrtHelpers(unittest.TestCase):
def test_isqrt_frac_rto(self): def test_integer_sqrt_of_frac_rto(self):
for n, m in itertools.product(range(100), range(1, 1000)): for n, m in itertools.product(range(100), range(1, 1000)):
r = statistics._isqrt_frac_rto(n, m) r = statistics._integer_sqrt_of_frac_rto(n, m)
self.assertIsInstance(r, int) self.assertIsInstance(r, int)
if r*r*m == n: if r*r*m == n:
# Root is exact # Root is exact
@ -2177,7 +2177,7 @@ class TestSqrtHelpers(unittest.TestCase):
self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
@requires_IEEE_754 @requires_IEEE_754
def test_sqrt_frac(self): def test_float_sqrt_of_frac(self):
def is_root_correctly_rounded(x: Fraction, root: float) -> bool: def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
if not x: if not x:
@ -2204,22 +2204,59 @@ class TestSqrtHelpers(unittest.TestCase):
denonimator: int = randrange(10 ** randrange(50)) + 1 denonimator: int = randrange(10 ** randrange(50)) + 1
with self.subTest(numerator=numerator, denonimator=denonimator): with self.subTest(numerator=numerator, denonimator=denonimator):
x: Fraction = Fraction(numerator, denonimator) x: Fraction = Fraction(numerator, denonimator)
root: float = statistics._sqrt_frac(numerator, denonimator) root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
self.assertTrue(is_root_correctly_rounded(x, root)) self.assertTrue(is_root_correctly_rounded(x, root))
# Verify that corner cases and error handling match math.sqrt() # Verify that corner cases and error handling match math.sqrt()
self.assertEqual(statistics._sqrt_frac(0, 1), 0.0) self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
statistics._sqrt_frac(-1, 1) statistics._float_sqrt_of_frac(-1, 1)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
statistics._sqrt_frac(1, -1) statistics._float_sqrt_of_frac(1, -1)
# Error handling for zero denominator matches that for Fraction(1, 0) # Error handling for zero denominator matches that for Fraction(1, 0)
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
statistics._sqrt_frac(1, 0) statistics._float_sqrt_of_frac(1, 0)
# The result is well defined if both inputs are negative # The result is well defined if both inputs are negative
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0)) self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
def test_decimal_sqrt_of_frac(self):
# Exercises statistics._decimal_sqrt_of_frac() on three fixed fractions
# and then on its corner cases (zero, negative ratio, zero denominator,
# both inputs negative).
root: Decimal
numerator: int
denominator: int
# Each case is (expected root, numerator, denominator).  The trailing
# notes ("No adj", "Adj up", "Adj down") presumably record whether the
# helper keeps the initial sqrt() result or steps it by one ulp.
for root, numerator, denominator in [
(Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj
(Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up
(Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down
]:
# Run the helper under a copy of DefaultContext so the result does not
# depend on whatever context the test runner left active.
with decimal.localcontext(decimal.DefaultContext):
self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
# Confirm expected root with a quad precision decimal computation
with decimal.localcontext(decimal.DefaultContext) as ctx:
# Work at four times the default precision.
ctx.prec *= 4
high_prec_ratio = Decimal(numerator) / Decimal(denominator)
# NOTE(review): ROUND_05UP is presumably chosen so that rounding the
# high precision root a second time below cannot double-round.
ctx.rounding = decimal.ROUND_05UP
high_prec_root = high_prec_ratio.sqrt()
with decimal.localcontext(decimal.DefaultContext):
# Unary plus rounds the value to the active context's precision.
target_root = +high_prec_root
self.assertEqual(root, target_root)
# Verify that corner cases and error handling match Decimal.sqrt()
self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
with self.assertRaises(decimal.InvalidOperation):
statistics._decimal_sqrt_of_frac(-1, 1)
with self.assertRaises(decimal.InvalidOperation):
statistics._decimal_sqrt_of_frac(1, -1)
# Error handling for zero denominator matches that for Fraction(1, 0)
with self.assertRaises(ZeroDivisionError):
statistics._decimal_sqrt_of_frac(1, 0)
# The result is well defined if both inputs are negative
self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
class TestStdev(VarianceStdevMixin, NumericTestCase): class TestStdev(VarianceStdevMixin, NumericTestCase):