Remove the lfu_cache. Add more tests.

This commit is contained in:
Raymond Hettinger 2010-08-15 03:30:45 +00:00
parent 0f56e90f05
commit f309828175
3 changed files with 29 additions and 129 deletions

View file

@ -110,58 +110,6 @@ def cmp_to_key(mycmp):
raise TypeError('hash not implemented')
return K
def lfu_cache(maxsize=100):
"""Least-frequently-used cache decorator.
Arguments to the cached function must be hashable.
Cache performance statistics stored in f.hits and f.misses.
Clear the cache using f.clear().
http://en.wikipedia.org/wiki/Cache_algorithms#Least-Frequently_Used
"""
def decorating_function(user_function, tuple=tuple, sorted=sorted,
len=len, KeyError=KeyError):
cache = {} # mapping of args to results
use_count = Counter() # times each key has been accessed
kwd_mark = object() # separate positional and keyword args
lock = Lock()
@wraps(user_function)
def wrapper(*args, **kwds):
key = args
if kwds:
key += (kwd_mark,) + tuple(sorted(kwds.items()))
try:
with lock:
use_count[key] += 1 # count a use of this key
result = cache[key]
wrapper.hits += 1
except KeyError:
result = user_function(*args, **kwds)
with lock:
use_count[key] += 1 # count a use of this key
cache[key] = result
wrapper.misses += 1
if len(cache) > maxsize:
# purge the 10% least frequently used entries
for key, _ in nsmallest(maxsize // 10 or 1,
use_count.items(),
key=itemgetter(1)):
del cache[key], use_count[key]
return result
def clear():
"""Clear the cache and cache statistics"""
with lock:
cache.clear()
use_count.clear()
wrapper.hits = wrapper.misses = 0
wrapper.hits = wrapper.misses = 0
wrapper.clear = clear
return wrapper
return decorating_function
def lru_cache(maxsize=100):
"""Least-recently-used cache decorator.

View file

@ -483,73 +483,38 @@ class TestLRU(unittest.TestCase):
self.assertEqual(f.misses, 1)
# test size zero (which means "never-cache")
f_cnt = 0
@functools.lru_cache(0)
def f():
nonlocal f_cnt
f_cnt += 1
return 20
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f_cnt, 3)
f_cnt = 0
for i in range(5):
self.assertEqual(f(), 20)
self.assertEqual(f_cnt, 5)
# test size one
f_cnt = 0
@functools.lru_cache(1)
def f():
nonlocal f_cnt
f_cnt += 1
return 20
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
f_cnt = 0
for i in range(5):
self.assertEqual(f(), 20)
self.assertEqual(f_cnt, 1)
def test_lfu(self):
def orig(x, y):
return 3*x+y
f = functools.lfu_cache(maxsize=20)(orig)
domain = range(5)
for i in range(1000):
x, y = choice(domain), choice(domain)
actual = f(x, y)
expected = orig(x, y)
self.assertEquals(actual, expected)
self.assert_(f.hits > f.misses)
self.assertEquals(f.hits + f.misses, 1000)
f.clear() # test clearing
self.assertEqual(f.hits, 0)
self.assertEqual(f.misses, 0)
f(x, y)
self.assertEqual(f.hits, 0)
self.assertEqual(f.misses, 1)
# test size zero (which means "never-cache")
f_cnt = 0
@functools.lfu_cache(0)
def f():
# test size two
@functools.lru_cache(2)
def f(x):
nonlocal f_cnt
f_cnt += 1
return 20
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f_cnt, 3)
# test size one
return x*10
f_cnt = 0
@functools.lfu_cache(1)
def f():
nonlocal f_cnt
f_cnt += 1
return 20
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f(), 20)
self.assertEqual(f_cnt, 1)
for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
# * * * *
self.assertEqual(f(x), x*10)
self.assertEqual(f_cnt, 4)
def test_main(verbose=None):
test_classes = (