Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)

This commit is contained in:
Raymond Hettinger 2023-08-11 17:19:19 +01:00 committed by GitHub
parent 637f7ff2c6
commit 52e0797f8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 6 deletions

View file

@ -28,6 +28,12 @@ import statistics
# === Helper functions and class ===
# Test copied from Lib/test/test_math.py
# detect evidence of double-rounding: fsum is not always correctly
# rounded on machines that suffer from double rounding.
x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
def sign(x):
"""Return -1.0 for negatives, including -0.0, otherwise +1.0."""
return math.copysign(1, x)
@ -2564,6 +2570,79 @@ class TestCorrelationAndCovariance(unittest.TestCase):
self.assertAlmostEqual(statistics.correlation(x, y), 1)
self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
def test_sqrtprod_helper_function_fundamentals(self):
# Verify that results are close to sqrt(x * y)
for i in range(100):
x = random.expovariate()
y = random.expovariate()
expected = math.sqrt(x * y)
actual = statistics._sqrtprod(x, y)
with self.subTest(x=x, y=y, expected=expected, actual=actual):
self.assertAlmostEqual(expected, actual)
x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
self.assertEqual(statistics._sqrtprod(x, y), target)
self.assertNotEqual(math.sqrt(x * y), target)
# Test that range extremes avoid underflow and overflow
smallest = sys.float_info.min * sys.float_info.epsilon
self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest)
biggest = sys.float_info.max
self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest)
# Check special values and the sign of the result
special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0,
math.nan, -math.nan, math.inf, -math.inf]
for x, y in itertools.product(special_values, repeat=2):
try:
expected = math.sqrt(x * y)
except ValueError:
expected = 'ValueError'
try:
actual = statistics._sqrtprod(x, y)
except ValueError:
actual = 'ValueError'
with self.subTest(x=x, y=y, expected=expected, actual=actual):
if isinstance(expected, str) and expected == 'ValueError':
self.assertEqual(actual, 'ValueError')
continue
self.assertIsInstance(actual, float)
if math.isnan(expected):
self.assertTrue(math.isnan(actual))
continue
self.assertEqual(actual, expected)
self.assertEqual(sign(actual), sign(expected))
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"accuracy not guaranteed on machines with double rounding")
@support.cpython_only # Allow for a weaker sumprod() implmentation
def test_sqrtprod_helper_function_improved_accuracy(self):
# Test a known example where accuracy is improved
x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
self.assertEqual(statistics._sqrtprod(x, y), target)
self.assertNotEqual(math.sqrt(x * y), target)
def reference_value(x: float, y: float) -> float:
x = decimal.Decimal(x)
y = decimal.Decimal(y)
with decimal.localcontext() as ctx:
ctx.prec = 200
return float((x * y).sqrt())
# Verify that the new function with improved accuracy
# agrees with a reference value more often than old version.
new_agreements = 0
old_agreements = 0
for i in range(10_000):
x = random.expovariate()
y = random.expovariate()
new = statistics._sqrtprod(x, y)
old = math.sqrt(x * y)
ref = reference_value(x, y)
new_agreements += (new == ref)
old_agreements += (old == ref)
self.assertGreater(new_agreements, old_agreements)
def test_correlation_spearman(self):
# https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php