This commit is contained in:
Dingning 2025-11-17 13:45:10 +01:00 committed by GitHub
commit e2dc26fe94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 43 additions and 7 deletions

View file

@ -44,6 +44,7 @@ class CacheHandler(BaseConnectionHandler):
params = self.settings[alias].copy() params = self.settings[alias].copy()
backend = params.pop("BACKEND") backend = params.pop("BACKEND")
location = params.pop("LOCATION", "") location = params.pop("LOCATION", "")
params["ALIAS"] = alias
try: try:
backend_cls = import_string(backend) backend_cls = import_string(backend)
except ImportError as e: except ImportError as e:

View file

@ -82,6 +82,7 @@ class BaseCache:
self.key_prefix = params.get("KEY_PREFIX", "") self.key_prefix = params.get("KEY_PREFIX", "")
self.version = params.get("VERSION", 1) self.version = params.get("VERSION", 1)
self.alias = params.get("ALIAS", "default")
self.key_func = get_key_func(params.get("KEY_FUNCTION")) self.key_func = get_key_func(params.get("KEY_FUNCTION"))
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT): def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):

View file

@ -29,9 +29,12 @@ class RedisSerializer:
class RedisCacheClient: class RedisCacheClient:
_pools = {}
def __init__( def __init__(
self, self,
servers, servers,
alias=None,
serializer=None, serializer=None,
pool_class=None, pool_class=None,
parser_class=None, parser_class=None,
@ -41,7 +44,7 @@ class RedisCacheClient:
self._lib = redis self._lib = redis
self._servers = servers self._servers = servers
self._pools = {} self._alias = alias
self._client = self._lib.Redis self._client = self._lib.Redis
@ -70,12 +73,17 @@ class RedisCacheClient:
def _get_connection_pool(self, write): def _get_connection_pool(self, write):
index = self._get_connection_pool_index(write) index = self._get_connection_pool_index(write)
if index not in self._pools: key = f"{self._alias}:{index}"
self._pools[index] = self._pool_class.from_url(
self._servers[index], if key not in self._pools:
**self._pool_options, self._pools.setdefault(
key,
self._pool_class.from_url(
self._servers[index],
**self._pool_options,
),
) )
return self._pools[index] return self._pools[key]
def get_client(self, key=None, *, write=False): def get_client(self, key=None, *, write=False):
# key is used so that the method signature remains the same and custom # key is used so that the method signature remains the same and custom
@ -170,7 +178,7 @@ class RedisCache(BaseCache):
@cached_property @cached_property
def _cache(self): def _cache(self):
return self._class(self._servers, **self._options) return self._class(self._servers, alias=self.alias, **self._options)
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT): def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
if timeout == DEFAULT_TIMEOUT: if timeout == DEFAULT_TIMEOUT:

26
tests/cache/tests.py vendored
View file

@ -1831,6 +1831,10 @@ class RedisCacheTests(BaseCacheTests, TestCase):
super().setUp() super().setUp()
self.lib = redis self.lib = redis
# Clear pools, because the pool is process-global,
# so every case will use the pools created in previous cases
self.addCleanup(cache._cache._pools.clear)
@property @property
def incr_decr_type_error(self): def incr_decr_type_error(self):
return self.lib.ResponseError return self.lib.ResponseError
@ -1902,6 +1906,28 @@ class RedisCacheTests(BaseCacheTests, TestCase):
self.assertEqual(pool.connection_kwargs["socket_timeout"], 0.1) self.assertEqual(pool.connection_kwargs["socket_timeout"], 0.1)
self.assertIs(pool.connection_kwargs["retry_on_timeout"], True) self.assertIs(pool.connection_kwargs["retry_on_timeout"], True)
def test_redis_pool_is_global(self):
class ResultContainer:
def __init__(self):
self.result1 = None
self.result2 = None
result = ResultContainer()
def get_connection_pool(result, slot):
setattr(result, slot, cache._cache._get_connection_pool(write=False))
t1 = threading.Thread(target=get_connection_pool, args=(result, "result1"))
t2 = threading.Thread(target=get_connection_pool, args=(result, "result2"))
t1.start()
t2.start()
t1.join()
t2.join()
self.assertEqual(id(result.result1), id(result.result2))
class FileBasedCachePathLibTests(FileBasedCacheTests): class FileBasedCachePathLibTests(FileBasedCacheTests):
def mkdtemp(self): def mkdtemp(self):