bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)

This commit is contained in:
Raymond Hettinger 2021-11-30 18:20:08 -06:00 committed by GitHub
parent 8a45ca542a
commit a39f46afde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 22 deletions

View file

@ -137,7 +137,7 @@ from decimal import Decimal
from itertools import groupby, repeat
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import itemgetter, mul
from operator import mul
from collections import Counter, namedtuple
_SQRT2 = sqrt(2.0)
@ -248,6 +248,28 @@ def _exact_ratio(x):
x is expected to be an int, Fraction, Decimal or float.
"""
# XXX We should revisit whether using fractions to accumulate exact
# ratios is the right way to go.
# The integer ratios for binary floats can have numerators or
# denominators with over 300 decimal digits. The problem is more
# acute with decimal floats where the the default decimal context
# supports a huge range of exponents from Emin=-999999 to
# Emax=999999. When expanded with as_integer_ratio(), numbers like
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
# numerators or denominators that will slow computation.
# When the integer ratios are accumulated as fractions, the size
# grows to cover the full range from the smallest magnitude to the
# largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
# has a 616 digit numerator. Likewise,
# Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
# has 10,003 digit numerator.
# This doesn't seem to have been problem in practice, but it is a
# potential pitfall.
try:
return x.as_integer_ratio()
except AttributeError:
@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg)
yield x
def _isqrt_frac_rto(n: int, m: int) -> float:
def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
a = math.isqrt(n // m)
return a | (a*a*m != n)
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
def _sqrt_frac(n: int, m: int) -> float:
# For 53 bit precision floats, the bit width used in
# _float_sqrt_of_frac() is 109.
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
def _float_sqrt_of_frac(n: int, m: int) -> float:
"""Square root of n/m as a float, correctly rounded."""
# See principle and proof sketch at: https://bugs.python.org/msg407078
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
if q >= 0:
numerator = _isqrt_frac_rto(n, m << 2 * q) << q
numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
denominator = 1
else:
numerator = _isqrt_frac_rto(n << -2 * q, m)
numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
denominator = 1 << -q
return numerator / denominator # Convert to float
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
"""Square root of n/m as a Decimal, correctly rounded."""
# Premise: For decimal, computing (n/m).sqrt() can be off
# by 1 ulp from the correctly rounded result.
# Method: Check the result, moving up or down a step if needed.
if n <= 0:
if not n:
return Decimal('0.0')
n, m = -n, -m
root = (Decimal(n) / Decimal(m)).sqrt()
nr, dr = root.as_integer_ratio()
plus = root.next_plus()
np, dp = plus.as_integer_ratio()
# test: n / m > ((root + plus) / 2) ** 2
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
return plus
minus = root.next_minus()
nm, dm = minus.as_integer_ratio()
# test: n / m < ((root + minus) / 2) ** 2
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
return minus
return root
# === Measures of central tendency (averages) ===
def mean(data):
@ -869,7 +923,7 @@ def stdev(data, xbar=None):
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
return _sqrt_frac(mss.numerator, mss.denominator)
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
def pstdev(data, mu=None):
@ -888,10 +942,9 @@ def pstdev(data, mu=None):
raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
return _sqrt_frac(mss.numerator, mss.denominator)
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
# === Statistics for relations between two inputs ===