mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-20499: Rounding error in statistics.pvariance (GH-28230)
This commit is contained in:
parent
f235dd0784
commit
4a5cccb02b
3 changed files with 51 additions and 56 deletions
|
@ -147,21 +147,17 @@ class StatisticsError(ValueError):
|
|||
|
||||
# === Private utilities ===
|
||||
|
||||
def _sum(data, start=0):
|
||||
"""_sum(data [, start]) -> (type, sum, count)
|
||||
def _sum(data):
|
||||
"""_sum(data) -> (type, sum, count)
|
||||
|
||||
Return a high-precision sum of the given numeric data as a fraction,
|
||||
together with the type to be converted to and the count of items.
|
||||
|
||||
If optional argument ``start`` is given, it is added to the total.
|
||||
If ``data`` is empty, ``start`` (defaulting to 0) is returned.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
|
||||
(<class 'float'>, Fraction(11, 1), 5)
|
||||
>>> _sum([3, 2.25, 4.5, -0.5, 0.25])
|
||||
(<class 'float'>, Fraction(19, 2), 5)
|
||||
|
||||
Some sources of round-off error will be avoided:
|
||||
|
||||
|
@ -184,10 +180,9 @@ def _sum(data, start=0):
|
|||
allowed.
|
||||
"""
|
||||
count = 0
|
||||
n, d = _exact_ratio(start)
|
||||
partials = {d: n}
|
||||
partials = {}
|
||||
partials_get = partials.get
|
||||
T = _coerce(int, type(start))
|
||||
T = int
|
||||
for typ, values in groupby(data, type):
|
||||
T = _coerce(T, typ) # or raise TypeError
|
||||
for n, d in map(_exact_ratio, values):
|
||||
|
@ -200,8 +195,7 @@ def _sum(data, start=0):
|
|||
assert not _isfinite(total)
|
||||
else:
|
||||
# Sum all the partial sums using builtin sum.
|
||||
# FIXME is this faster if we sum them in order of the denominator?
|
||||
total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
|
||||
total = sum(Fraction(n, d) for d, n in partials.items())
|
||||
return (T, total, count)
|
||||
|
||||
|
||||
|
@ -252,27 +246,19 @@ def _exact_ratio(x):
|
|||
x is expected to be an int, Fraction, Decimal or float.
|
||||
"""
|
||||
try:
|
||||
# Optimise the common case of floats. We expect that the most often
|
||||
# used numeric type will be builtin floats, so try to make this as
|
||||
# fast as possible.
|
||||
if type(x) is float or type(x) is Decimal:
|
||||
return x.as_integer_ratio()
|
||||
try:
|
||||
# x may be an int, Fraction, or Integral ABC.
|
||||
return (x.numerator, x.denominator)
|
||||
except AttributeError:
|
||||
try:
|
||||
# x may be a float or Decimal subclass.
|
||||
return x.as_integer_ratio()
|
||||
except AttributeError:
|
||||
# Just give up?
|
||||
pass
|
||||
return x.as_integer_ratio()
|
||||
except AttributeError:
|
||||
pass
|
||||
except (OverflowError, ValueError):
|
||||
# float NAN or INF.
|
||||
assert not _isfinite(x)
|
||||
return (x, None)
|
||||
msg = "can't convert type '{}' to numerator/denominator"
|
||||
raise TypeError(msg.format(type(x).__name__))
|
||||
try:
|
||||
# x may be an Integral ABC.
|
||||
return (x.numerator, x.denominator)
|
||||
except AttributeError:
|
||||
msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def _convert(value, T):
|
||||
|
@ -730,18 +716,20 @@ def _ss(data, c=None):
|
|||
if c is not None:
|
||||
T, total, count = _sum((d := x - c) * d for x in data)
|
||||
return (T, total)
|
||||
# Compute the mean accurate to within 1/2 ulp
|
||||
c = mean(data)
|
||||
# Initial computation for the sum of square deviations
|
||||
T, total, count = _sum((d := x - c) * d for x in data)
|
||||
# Correct any remaining inaccuracy in the mean c.
|
||||
# The following sum should mathematically equal zero,
|
||||
# but due to the final rounding of the mean, it may not.
|
||||
U, error, count2 = _sum((x - c) for x in data)
|
||||
assert count == count2
|
||||
correction = error * error / len(data)
|
||||
total -= correction
|
||||
assert not total < 0, 'negative sum of square deviations: %f' % total
|
||||
T, total, count = _sum(data)
|
||||
mean_n, mean_d = (total / count).as_integer_ratio()
|
||||
partials = Counter()
|
||||
for n, d in map(_exact_ratio, data):
|
||||
diff_n = n * mean_d - d * mean_n
|
||||
diff_d = d * mean_d
|
||||
partials[diff_d * diff_d] += diff_n * diff_n
|
||||
if None in partials:
|
||||
# The sum will be a NAN or INF. We can ignore all the finite
|
||||
# partials, and just look at this special one.
|
||||
total = partials[None]
|
||||
assert not _isfinite(total)
|
||||
else:
|
||||
total = sum(Fraction(n, d) for d, n in partials.items())
|
||||
return (T, total)
|
||||
|
||||
|
||||
|
@ -845,6 +833,9 @@ def stdev(data, xbar=None):
|
|||
1.0810874155219827
|
||||
|
||||
"""
|
||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
||||
# remain because there are two rounding steps. The first occurs in
|
||||
# the _convert() step for variance(), the second occurs in math.sqrt().
|
||||
var = variance(data, xbar)
|
||||
try:
|
||||
return var.sqrt()
|
||||
|
@ -861,6 +852,9 @@ def pstdev(data, mu=None):
|
|||
0.986893273527251
|
||||
|
||||
"""
|
||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
||||
# remain because there are two rounding steps. The first occurs in
|
||||
# the _convert() step for pvariance(), the second occurs in math.sqrt().
|
||||
var = pvariance(data, mu)
|
||||
try:
|
||||
return var.sqrt()
|
||||
|
|
|
@ -1250,20 +1250,14 @@ class TestSum(NumericTestCase):
|
|||
# Override test for empty data.
|
||||
for data in ([], (), iter([])):
|
||||
self.assertEqual(self.func(data), (int, Fraction(0), 0))
|
||||
self.assertEqual(self.func(data, 23), (int, Fraction(23), 0))
|
||||
self.assertEqual(self.func(data, 2.3), (float, Fraction(2.3), 0))
|
||||
|
||||
def test_ints(self):
|
||||
self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]),
|
||||
(int, Fraction(60), 8))
|
||||
self.assertEqual(self.func([4, 2, 3, -8, 7], 1000),
|
||||
(int, Fraction(1008), 5))
|
||||
|
||||
def test_floats(self):
|
||||
self.assertEqual(self.func([0.25]*20),
|
||||
(float, Fraction(5.0), 20))
|
||||
self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5),
|
||||
(float, Fraction(3.125), 4))
|
||||
|
||||
def test_fractions(self):
|
||||
self.assertEqual(self.func([Fraction(1, 1000)]*500),
|
||||
|
@ -1284,14 +1278,6 @@ class TestSum(NumericTestCase):
|
|||
data = [random.uniform(-100, 1000) for _ in range(1000)]
|
||||
self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16)
|
||||
|
||||
def test_start_argument(self):
|
||||
# Test that the optional start argument works correctly.
|
||||
data = [random.uniform(1, 1000) for _ in range(100)]
|
||||
t = self.func(data)[1]
|
||||
self.assertEqual(t+42, self.func(data, 42)[1])
|
||||
self.assertEqual(t-23, self.func(data, -23)[1])
|
||||
self.assertEqual(t+Fraction(1e20), self.func(data, 1e20)[1])
|
||||
|
||||
def test_strings_fail(self):
|
||||
# Sum of strings should fail.
|
||||
self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
|
||||
|
@ -2101,6 +2087,13 @@ class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
|||
self.assertEqual(result, exact)
|
||||
self.assertIsInstance(result, Decimal)
|
||||
|
||||
def test_accuracy_bug_20499(self):
|
||||
data = [0, 0, 1]
|
||||
exact = 2 / 9
|
||||
result = self.func(data)
|
||||
self.assertEqual(result, exact)
|
||||
self.assertIsInstance(result, float)
|
||||
|
||||
|
||||
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
||||
# Tests for sample variance.
|
||||
|
@ -2141,6 +2134,13 @@ class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
|||
self.assertEqual(self.func(data), 0.5)
|
||||
self.assertEqual(self.func(data, xbar=2.0), 1.0)
|
||||
|
||||
def test_accuracy_bug_20499(self):
|
||||
data = [0, 0, 2]
|
||||
exact = 4 / 3
|
||||
result = self.func(data)
|
||||
self.assertEqual(result, exact)
|
||||
self.assertIsInstance(result, float)
|
||||
|
||||
class TestPStdev(VarianceStdevMixin, NumericTestCase):
|
||||
# Tests for population standard deviation.
|
||||
def setUp(self):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Improve the speed and accuracy of statistics.pvariance().
|
Loading…
Add table
Add a link
Reference in a new issue