GH-100425: Improve accuracy of builtin sum() for float inputs (GH-100426)

This commit is contained in:
Raymond Hettinger 2022-12-23 14:35:58 -08:00 committed by GitHub
parent 1ecfd1ebf1
commit 5d84966cce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 8 deletions

View file

@ -1733,6 +1733,10 @@ are always available. They are listed here in alphabetical order.
.. versionchanged:: 3.8 .. versionchanged:: 3.8
The *start* parameter can be specified as a keyword argument. The *start* parameter can be specified as a keyword argument.
.. versionchanged:: 3.12 Summation of floats switched to an algorithm
that gives higher accuracy on most builds.
.. class:: super() .. class:: super()
super(type, object_or_type=None) super(type, object_or_type=None)

View file

@ -108,12 +108,7 @@ Number-theoretic and representation functions
.. function:: fsum(iterable) .. function:: fsum(iterable)
Return an accurate floating point sum of values in the iterable. Avoids Return an accurate floating point sum of values in the iterable. Avoids
loss of precision by tracking multiple intermediate partial sums: loss of precision by tracking multiple intermediate partial sums.
>>> sum([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1])
0.9999999999999999
>>> fsum([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1])
1.0
The algorithm's accuracy depends on IEEE-754 arithmetic guarantees and the The algorithm's accuracy depends on IEEE-754 arithmetic guarantees and the
typical case where the rounding mode is half-even. On some non-Windows typical case where the rounding mode is half-even. On some non-Windows

View file

@ -192,7 +192,7 @@ added onto a running total. That can make a difference in overall accuracy
so that the errors do not accumulate to the point where they affect the so that the errors do not accumulate to the point where they affect the
final total: final total:
>>> sum([0.1] * 10) == 1.0 >>> 0.1 + 0.1 + 0.1 + 0.1 + 0.1 + 0.1 + 0.1 + 0.1 + 0.1 + 0.1 == 1.0
False False
>>> math.fsum([0.1] * 10) == 1.0 >>> math.fsum([0.1] * 10) == 1.0
True True

View file

@ -9,6 +9,7 @@ import fractions
import gc import gc
import io import io
import locale import locale
import math
import os import os
import pickle import pickle
import platform import platform
@ -31,6 +32,7 @@ from test.support import (swap_attr, maybe_get_event_loop_policy)
from test.support.os_helper import (EnvironmentVarGuard, TESTFN, unlink) from test.support.os_helper import (EnvironmentVarGuard, TESTFN, unlink)
from test.support.script_helper import assert_python_ok from test.support.script_helper import assert_python_ok
from test.support.warnings_helper import check_warnings from test.support.warnings_helper import check_warnings
from test.support import requires_IEEE_754
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
try: try:
import pty, signal import pty, signal
@ -38,6 +40,12 @@ except ImportError:
pty = signal = None pty = signal = None
# Detect evidence of double-rounding: sum() does not always
# get improved accuracy on machines that suffer from double rounding.
x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
class Squares: class Squares:
def __init__(self, max): def __init__(self, max):
@ -1617,6 +1625,8 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(repr(sum([-0.0])), '0.0') self.assertEqual(repr(sum([-0.0])), '0.0')
self.assertEqual(repr(sum([-0.0], -0.0)), '-0.0') self.assertEqual(repr(sum([-0.0], -0.0)), '-0.0')
self.assertEqual(repr(sum([], -0.0)), '-0.0') self.assertEqual(repr(sum([], -0.0)), '-0.0')
self.assertTrue(math.isinf(sum([float("inf"), float("inf")])))
self.assertTrue(math.isinf(sum([1e308, 1e308])))
self.assertRaises(TypeError, sum) self.assertRaises(TypeError, sum)
self.assertRaises(TypeError, sum, 42) self.assertRaises(TypeError, sum, 42)
@ -1641,6 +1651,14 @@ class BuiltinTest(unittest.TestCase):
sum(([x] for x in range(10)), empty) sum(([x] for x in range(10)), empty)
self.assertEqual(empty, []) self.assertEqual(empty, [])
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sum accuracy not guaranteed on machines with double rounding")
@support.cpython_only # Other implementations may choose a different algorithm
def test_sum_accuracy(self):
self.assertEqual(sum([0.1] * 10), 1.0)
self.assertEqual(sum([1.0, 10E100, 1.0, -10E100]), 2.0)
def test_type(self): def test_type(self):
self.assertEqual(type(''), type('123')) self.assertEqual(type(''), type('123'))
self.assertNotEqual(type(''), type(())) self.assertNotEqual(type(''), type(()))

View file

@ -0,0 +1 @@
Improve the accuracy of ``sum()`` with compensated summation.

View file

@ -2532,6 +2532,7 @@ builtin_sum_impl(PyObject *module, PyObject *iterable, PyObject *start)
if (PyFloat_CheckExact(result)) { if (PyFloat_CheckExact(result)) {
double f_result = PyFloat_AS_DOUBLE(result); double f_result = PyFloat_AS_DOUBLE(result);
double c = 0.0;
Py_SETREF(result, NULL); Py_SETREF(result, NULL);
while(result == NULL) { while(result == NULL) {
item = PyIter_Next(iter); item = PyIter_Next(iter);
@ -2539,10 +2540,25 @@ builtin_sum_impl(PyObject *module, PyObject *iterable, PyObject *start)
Py_DECREF(iter); Py_DECREF(iter);
if (PyErr_Occurred()) if (PyErr_Occurred())
return NULL; return NULL;
/* Avoid losing the sign on a negative result,
and don't let adding the compensation convert
an infinite or overflowed sum to a NaN. */
if (c && Py_IS_FINITE(c)) {
f_result += c;
}
return PyFloat_FromDouble(f_result); return PyFloat_FromDouble(f_result);
} }
if (PyFloat_CheckExact(item)) { if (PyFloat_CheckExact(item)) {
f_result += PyFloat_AS_DOUBLE(item); // Improved KahanBabuška algorithm by Arnold Neumaier
// https://www.mat.univie.ac.at/~neum/scan/01.pdf
double x = PyFloat_AS_DOUBLE(item);
double t = f_result + x;
if (fabs(f_result) >= fabs(x)) {
c += (f_result - t) + x;
} else {
c += (x - t) + f_result;
}
f_result = t;
_Py_DECREF_SPECIALIZED(item, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(item, _PyFloat_ExactDealloc);
continue; continue;
} }
@ -2556,6 +2572,9 @@ builtin_sum_impl(PyObject *module, PyObject *iterable, PyObject *start)
continue; continue;
} }
} }
if (c && Py_IS_FINITE(c)) {
f_result += c;
}
result = PyFloat_FromDouble(f_result); result = PyFloat_FromDouble(f_result);
if (result == NULL) { if (result == NULL) {
Py_DECREF(item); Py_DECREF(item);