Issue #14373: Added C implementation of functools.lru_cache(). Based on

patches by Matt Joiner and Alexey Kachayev.
This commit is contained in:
Serhiy Storchaka 2015-05-23 22:42:49 +03:00
parent c70908558d
commit 1c858c352b
4 changed files with 747 additions and 115 deletions

View file

@ -7,6 +7,10 @@ import sys
from test import support
import unittest
from weakref import proxy
try:
import threading
except ImportError:
threading = None
import functools
@ -912,12 +916,12 @@ class Orderable_LT:
return self.value == other.value
class TestLRU(unittest.TestCase):
class TestLRU:
def test_lru(self):
def orig(x, y):
return 3 * x + y
f = functools.lru_cache(maxsize=20)(orig)
f = self.module.lru_cache(maxsize=20)(orig)
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(maxsize, 20)
self.assertEqual(currsize, 0)
@ -955,7 +959,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 1)
# test size zero (which means "never-cache")
@functools.lru_cache(0)
@self.module.lru_cache(0)
def f():
nonlocal f_cnt
f_cnt += 1
@ -971,7 +975,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 0)
# test size one
@functools.lru_cache(1)
@self.module.lru_cache(1)
def f():
nonlocal f_cnt
f_cnt += 1
@ -987,7 +991,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 1)
# test size two
@functools.lru_cache(2)
@self.module.lru_cache(2)
def f(x):
nonlocal f_cnt
f_cnt += 1
@ -1004,7 +1008,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 2)
def test_lru_with_maxsize_none(self):
@functools.lru_cache(maxsize=None)
@self.module.lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
@ -1012,17 +1016,26 @@ class TestLRU(unittest.TestCase):
self.assertEqual([fib(n) for n in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
def test_lru_with_maxsize_negative(self):
@self.module.lru_cache(maxsize=-10)
def eq(n):
return n
for i in (0, 1):
self.assertEqual([eq(n) for n in range(150)], list(range(150)))
self.assertEqual(eq.cache_info(),
self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
def test_lru_with_exceptions(self):
# Verify that user_function exceptions get passed through without
# creating a hard-to-read chained exception.
# http://bugs.python.org/issue13177
for maxsize in (None, 128):
@functools.lru_cache(maxsize)
@self.module.lru_cache(maxsize)
def func(i):
return 'abc'[i]
self.assertEqual(func(0), 'a')
@ -1035,7 +1048,7 @@ class TestLRU(unittest.TestCase):
def test_lru_with_types(self):
for maxsize in (None, 128):
@functools.lru_cache(maxsize=maxsize, typed=True)
@self.module.lru_cache(maxsize=maxsize, typed=True)
def square(x):
return x * x
self.assertEqual(square(3), 9)
@ -1050,7 +1063,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(square.cache_info().misses, 4)
def test_lru_with_keyword_args(self):
@functools.lru_cache()
@self.module.lru_cache()
def fib(n):
if n < 2:
return n
@ -1060,13 +1073,13 @@ class TestLRU(unittest.TestCase):
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
)
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
def test_lru_with_keyword_args_maxsize_none(self):
@functools.lru_cache(maxsize=None)
@self.module.lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
@ -1074,15 +1087,71 @@ class TestLRU(unittest.TestCase):
self.assertEqual([fib(n=number) for number in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
def test_lru_cache_decoration(self):
def f(zomg: 'zomg_annotation'):
"""f doc string"""
return 42
g = self.module.lru_cache()(f)
for attr in self.module.WRAPPER_ASSIGNMENTS:
self.assertEqual(getattr(g, attr), getattr(f, attr))
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
def orig(x, y):
return 3 * x + y
f = self.module.lru_cache(maxsize=20)(orig)
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(currsize, 0)
def full(f, *args):
for _ in range(10):
f(*args)
def clear(f):
for _ in range(10):
f.cache_clear()
orig_si = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
try:
# create 5 threads in order to fill cache
threads = []
for k in range(5):
t = threading.Thread(target=full, args=[f, k, k])
t.start()
threads.append(t)
for t in threads:
t.join()
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(hits, 45)
self.assertEqual(misses, 5)
self.assertEqual(currsize, 5)
# create 5 threads in order to fill cache and 1 to clear it
cleaner = threading.Thread(target=clear, args=[f])
cleaner.start()
threads = [cleaner]
for k in range(5):
t = threading.Thread(target=full, args=[f, k, k])
t.start()
threads.append(t)
for t in threads:
t.join()
finally:
sys.setswitchinterval(orig_si)
def test_need_for_rlock(self):
# This will deadlock on an LRU cache that uses a regular lock
@functools.lru_cache(maxsize=10)
@self.module.lru_cache(maxsize=10)
def test_func(x):
'Used to demonstrate a reentrant lru_cache call within a single thread'
return x
@ -1110,6 +1179,12 @@ class TestLRU(unittest.TestCase):
def f():
pass
class TestLRUC(TestLRU, unittest.TestCase):
module = c_functools
class TestLRUPy(TestLRU, unittest.TestCase):
module = py_functools
class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):