Optimize fmean() weighted average (#102626)

This commit is contained in:
Raymond Hettinger 2023-03-12 12:48:25 -05:00 committed by GitHub
parent e6210621be
commit 6cd7572f85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -136,9 +136,9 @@ from fractions import Fraction
from decimal import Decimal from decimal import Decimal
from itertools import count, groupby, repeat from itertools import count, groupby, repeat
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
from functools import reduce from functools import reduce
from operator import mul, itemgetter from operator import itemgetter
from collections import Counter, namedtuple, defaultdict from collections import Counter, namedtuple, defaultdict
_SQRT2 = sqrt(2.0) _SQRT2 = sqrt(2.0)
@@ -496,28 +496,26 @@ def fmean(data, weights=None):
>>> fmean([3.5, 4.0, 5.25]) >>> fmean([3.5, 4.0, 5.25])
4.25 4.25
""" """
try:
n = len(data)
except TypeError:
# Handle iterators that do not define __len__().
n = 0
def count(iterable):
nonlocal n
for n, x in enumerate(iterable, start=1):
yield x
data = count(data)
if weights is None: if weights is None:
try:
n = len(data)
except TypeError:
# Handle iterators that do not define __len__().
n = 0
def count(iterable):
nonlocal n
for n, x in enumerate(iterable, start=1):
yield x
data = count(data)
total = fsum(data) total = fsum(data)
if not n: if not n:
raise StatisticsError('fmean requires at least one data point') raise StatisticsError('fmean requires at least one data point')
return total / n return total / n
try: if not isinstance(weights, (list, tuple)):
num_weights = len(weights)
except TypeError:
weights = list(weights) weights = list(weights)
num_weights = len(weights) try:
num = fsum(map(mul, data, weights)) num = sumprod(data, weights)
if n != num_weights: except ValueError:
raise StatisticsError('data and weights must be the same length') raise StatisticsError('data and weights must be the same length')
den = fsum(weights) den = fsum(weights)
if not den: if not den: