GH-102670: Use sumprod() to simplify, speed up, and improve accuracy of statistics functions (GH-102649)

This commit is contained in:
Raymond Hettinger 2023-03-13 20:06:43 -05:00 committed by GitHub
parent 61479d4684
commit 457e4d1a51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 13 deletions

View file

@@ -1036,7 +1036,7 @@ def covariance(x, y, /):
         raise StatisticsError('covariance requires at least two data points')
     xbar = fsum(x) / n
     ybar = fsum(y) / n
-    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
+    sxy = sumprod((xi - xbar for xi in x), (yi - ybar for yi in y))
     return sxy / (n - 1)
@@ -1074,11 +1074,14 @@ def correlation(x, y, /, *, method='linear'):
         start = (n - 1) / -2 # Center rankings around zero
         x = _rank(x, start=start)
         y = _rank(y, start=start)
-    xbar = fsum(x) / n
-    ybar = fsum(y) / n
-    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
-    sxx = fsum((d := xi - xbar) * d for xi in x)
-    syy = fsum((d := yi - ybar) * d for yi in y)
+    else:
+        xbar = fsum(x) / n
+        ybar = fsum(y) / n
+        x = [xi - xbar for xi in x]
+        y = [yi - ybar for yi in y]
+    sxy = sumprod(x, y)
+    sxx = sumprod(x, x)
+    syy = sumprod(y, y)
     try:
         return sxy / sqrt(sxx * syy)
     except ZeroDivisionError:
@@ -1131,14 +1134,13 @@ def linear_regression(x, y, /, *, proportional=False):
         raise StatisticsError('linear regression requires that both inputs have same number of data points')
     if n < 2:
         raise StatisticsError('linear regression requires at least two data points')
-    if proportional:
-        sxy = fsum(xi * yi for xi, yi in zip(x, y))
-        sxx = fsum(xi * xi for xi in x)
-    else:
+    if not proportional:
         xbar = fsum(x) / n
         ybar = fsum(y) / n
-        sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
-        sxx = fsum((d := xi - xbar) * d for xi in x)
+        x = [xi - xbar for xi in x] # List because used three times below
+        y = (yi - ybar for yi in y) # Generator because only used once below
+    sxy = sumprod(x, y) + 0.0 # Add zero to coerce result to a float
+    sxx = sumprod(x, x)
     try:
         slope = sxy / sxx # equivalent to: covariance(x, y) / variance(x)
     except ZeroDivisionError: