bpo-36324: Add inv_cdf() to statistics.NormalDist() (GH-12377)

This commit is contained in:
Raymond Hettinger 2019-03-18 20:17:14 -07:00 committed by GitHub
parent faddaedd05
commit 714c60d7ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 182 additions and 0 deletions

View file

@ -745,6 +745,101 @@ class NormalDist:
raise StatisticsError('cdf() not defined when sigma is zero')
return 0.5 * (1.0 + erf((x - self.mu) / (self.sigma * sqrt(2.0))))
def inv_cdf(self, p):
''' Inverse cumulative distribution function: x : P(X <= x) = p
Finds the value of the random variable such that the probability of the
variable being less than or equal to that value equals the given probability.
This function is also called the percent-point function or quantile function.
'''
if (p <= 0.0 or p >= 1.0):
raise StatisticsError('p must be in the range 0.0 < p < 1.0')
if self.sigma <= 0.0:
raise StatisticsError('cdf() not defined when sigma at or below zero')
# There is no closed-form solution to the inverse CDF for the normal
# distribution, so we use a rational approximation instead:
# Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
# Normal Distribution". Applied Statistics. Blackwell Publishing. 37
# (3): 477484. doi:10.2307/2347330. JSTOR 2347330.
q = p - 0.5
if fabs(q) <= 0.425:
a0 = 3.38713_28727_96366_6080e+0
a1 = 1.33141_66789_17843_7745e+2
a2 = 1.97159_09503_06551_4427e+3
a3 = 1.37316_93765_50946_1125e+4
a4 = 4.59219_53931_54987_1457e+4
a5 = 6.72657_70927_00870_0853e+4
a6 = 3.34305_75583_58812_8105e+4
a7 = 2.50908_09287_30122_6727e+3
b1 = 4.23133_30701_60091_1252e+1
b2 = 6.87187_00749_20579_0830e+2
b3 = 5.39419_60214_24751_1077e+3
b4 = 2.12137_94301_58659_5867e+4
b5 = 3.93078_95800_09271_0610e+4
b6 = 2.87290_85735_72194_2674e+4
b7 = 5.22649_52788_52854_5610e+3
r = 0.180625 - q * q
num = (q * (((((((a7 * r + a6) * r + a5) * r + a4) * r + a3)
* r + a2) * r + a1) * r + a0))
den = ((((((((b7 * r + b6) * r + b5) * r + b4) * r + b3)
* r + b2) * r + b1) * r + 1.0))
x = num / den
return self.mu + (x * self.sigma)
r = p if q <= 0.0 else 1.0 - p
r = sqrt(-log(r))
if r <= 5.0:
c0 = 1.42343_71107_49683_57734e+0
c1 = 4.63033_78461_56545_29590e+0
c2 = 5.76949_72214_60691_40550e+0
c3 = 3.64784_83247_63204_60504e+0
c4 = 1.27045_82524_52368_38258e+0
c5 = 2.41780_72517_74506_11770e-1
c6 = 2.27238_44989_26918_45833e-2
c7 = 7.74545_01427_83414_07640e-4
d1 = 2.05319_16266_37758_82187e+0
d2 = 1.67638_48301_83803_84940e+0
d3 = 6.89767_33498_51000_04550e-1
d4 = 1.48103_97642_74800_74590e-1
d5 = 1.51986_66563_61645_71966e-2
d6 = 5.47593_80849_95344_94600e-4
d7 = 1.05075_00716_44416_84324e-9
r = r - 1.6
num = ((((((((c7 * r + c6) * r + c5) * r + c4) * r + c3)
* r + c2) * r + c1) * r + c0))
den = ((((((((d7 * r + d6) * r + d5) * r + d4) * r + d3)
* r + d2) * r + d1) * r + 1.0))
else:
e0 = 6.65790_46435_01103_77720e+0
e1 = 5.46378_49111_64114_36990e+0
e2 = 1.78482_65399_17291_33580e+0
e3 = 2.96560_57182_85048_91230e-1
e4 = 2.65321_89526_57612_30930e-2
e5 = 1.24266_09473_88078_43860e-3
e6 = 2.71155_55687_43487_57815e-5
e7 = 2.01033_43992_92288_13265e-7
f1 = 5.99832_20655_58879_37690e-1
f2 = 1.36929_88092_27358_05310e-1
f3 = 1.48753_61290_85061_48525e-2
f4 = 7.86869_13114_56132_59100e-4
f5 = 1.84631_83175_10054_68180e-5
f6 = 1.42151_17583_16445_88870e-7
f7 = 2.04426_31033_89939_78564e-15
r = r - 5.0
num = ((((((((e7 * r + e6) * r + e5) * r + e4) * r + e3)
* r + e2) * r + e1) * r + e0))
den = ((((((((f7 * r + f6) * r + f5) * r + f4) * r + f3)
* r + f2) * r + f1) * r + 1.0))
x = num / den
if q < 0.0:
x = -x
return self.mu + (x * self.sigma)
def overlap(self, other):
'''Compute the overlapping coefficient (OVL) between two normal distributions.