bpo-44151: linear_regression() minor API improvements (GH-26199)

This commit is contained in:
Zack Kneupper 2021-05-24 20:30:58 -04:00 committed by GitHub
parent 8450e8a81f
commit 2f3a87856c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 26 deletions

View file

@ -94,7 +94,7 @@ for two inputs:
>>> correlation(x, y) #doctest: +ELLIPSIS
0.31622776601...
>>> linear_regression(x, y) #doctest:
LinearRegression(intercept=1.5, slope=0.1)
LinearRegression(slope=0.1, intercept=1.5)
Exceptions
@ -932,18 +932,18 @@ def correlation(x, y, /):
raise StatisticsError('at least one of the inputs is constant')
LinearRegression = namedtuple('LinearRegression', ['intercept', 'slope'])
LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept'))
def linear_regression(regressor, dependent_variable, /):
def linear_regression(x, y, /):
"""Intercept and slope for simple linear regression
Return the intercept and slope of simple linear regression
parameters estimated using ordinary least squares. Simple linear
regression describes relationship between *regressor* and
*dependent variable* in terms of linear function:
regression describes relationship between *x* and
*y* in terms of linear function:
dependent_variable = intercept + slope * regressor + noise
y = intercept + slope * x + noise
where *intercept* and *slope* are the regression parameters that are
estimated, and noise represents the variability of the data that was
@ -953,19 +953,18 @@ def linear_regression(regressor, dependent_variable, /):
The parameters are returned as a named tuple.
>>> regressor = [1, 2, 3, 4, 5]
>>> x = [1, 2, 3, 4, 5]
>>> noise = NormalDist().samples(5, seed=42)
>>> dependent_variable = [2 + 3 * regressor[i] + noise[i] for i in range(5)]
>>> linear_regression(regressor, dependent_variable) #doctest: +ELLIPSIS
LinearRegression(intercept=1.75684970486..., slope=3.09078914170...)
>>> y = [2 + 3 * x[i] + noise[i] for i in range(5)]
>>> linear_regression(x, y) #doctest: +ELLIPSIS
LinearRegression(slope=3.09078914170..., intercept=1.75684970486...)
"""
n = len(regressor)
if len(dependent_variable) != n:
n = len(x)
if len(y) != n:
raise StatisticsError('linear regression requires that both inputs have same number of data points')
if n < 2:
raise StatisticsError('linear regression requires at least two data points')
x, y = regressor, dependent_variable
xbar = fsum(x) / n
ybar = fsum(y) / n
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
@ -973,9 +972,9 @@ def linear_regression(regressor, dependent_variable, /):
try:
slope = sxy / s2x # equivalent to: covariance(x, y) / variance(x)
except ZeroDivisionError:
raise StatisticsError('regressor is constant')
raise StatisticsError('x is constant')
intercept = ybar - slope * xbar
return LinearRegression(intercept=intercept, slope=slope)
return LinearRegression(slope=slope, intercept=intercept)
## Normal Distribution #####################################################