mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
parent
04bc5f3b18
commit
89be10f6b4
@ -216,10 +216,25 @@ class SQLiteCache(SQLAlchemyCache):
|
||||
class RedisCache(BaseCache):
|
||||
"""Cache that uses Redis as a backend."""
|
||||
|
||||
# TODO - implement a TTL policy in Redis
|
||||
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
|
||||
"""
|
||||
Initialize an instance of RedisCache.
|
||||
|
||||
def __init__(self, redis_: Any):
|
||||
"""Initialize by passing in Redis instance."""
|
||||
This method initializes an object with Redis caching capabilities.
|
||||
It takes a `redis_` parameter, which should be an instance of a Redis
|
||||
client class, allowing the object to interact with a Redis
|
||||
server for caching purposes.
|
||||
|
||||
Parameters:
|
||||
redis_ (Any): An instance of a Redis client class
|
||||
(e.g., redis.Redis) used for caching.
|
||||
This allows the object to communicate with a
|
||||
Redis server for caching operations.
|
||||
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
|
||||
If provided, it sets the time duration for how long cached
|
||||
items will remain valid. If not provided, cached items will not
|
||||
have an automatic expiration.
|
||||
"""
|
||||
try:
|
||||
from redis import Redis
|
||||
except ImportError:
|
||||
@ -230,6 +245,7 @@ class RedisCache(BaseCache):
|
||||
if not isinstance(redis_, Redis):
|
||||
raise ValueError("Please pass in Redis object.")
|
||||
self.redis = redis_
|
||||
self.ttl = ttl
|
||||
|
||||
def _key(self, prompt: str, llm_string: str) -> str:
|
||||
"""Compute key from prompt and llm_string"""
|
||||
@ -261,12 +277,19 @@ class RedisCache(BaseCache):
|
||||
return
|
||||
# Write to a Redis HASH
|
||||
key = self._key(prompt, llm_string)
|
||||
self.redis.hset(
|
||||
key,
|
||||
mapping={
|
||||
str(idx): generation.text for idx, generation in enumerate(return_val)
|
||||
},
|
||||
)
|
||||
|
||||
with self.redis.pipeline() as pipe:
|
||||
pipe.hset(
|
||||
key,
|
||||
mapping={
|
||||
str(idx): generation.text
|
||||
for idx, generation in enumerate(return_val)
|
||||
},
|
||||
)
|
||||
if self.ttl is not None:
|
||||
pipe.expire(key, self.ttl)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
|
||||
|
@ -11,6 +11,15 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
REDIS_TEST_URL = "redis://localhost:6379"
|
||||
|
||||
|
||||
def test_redis_cache_ttl() -> None:
|
||||
import redis
|
||||
|
||||
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=1)
|
||||
langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
|
||||
key = langchain.llm_cache._key("foo", "bar")
|
||||
assert langchain.llm_cache.redis.pttl(key) > 0
|
||||
|
||||
|
||||
def test_redis_cache() -> None:
|
||||
import redis
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user