langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache (#15817)

* This PR adds async methods to the LLM cache. 
* Adds an implementation using Redis called AsyncRedisCache.
* Adds a docker compose file at the /docker to help spin up docker
* Updates redis tests to use a context manager so flushing always happens by default
This commit is contained in:
Dmitry Kankalovich
2024-02-08 04:06:09 +01:00
committed by GitHub
parent 19546081c6
commit f92738a6f6
8 changed files with 472 additions and 133 deletions

View File

@@ -27,6 +27,7 @@ import json
import logging
import uuid
import warnings
from abc import ABC
from datetime import timedelta
from functools import lru_cache
from typing import (
@@ -351,49 +352,26 @@ class UpstashRedisCache(BaseCache):
self.redis.flushdb(flush_type=asynchronous)
class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of RedisCache.
This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class, allowing the object to interact with a Redis
server for caching purposes.
Parameters:
redis_ (Any): An instance of a Redis client class
(e.g., redis.Redis) used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis import Redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass in Redis object.")
self.redis = redis_
self.ttl = ttl
def _key(self, prompt: str, llm_string: str) -> str:
class _RedisCacheBase(BaseCache, ABC):
@staticmethod
def _key(prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@staticmethod
def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None:
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
@staticmethod
def _get_generations(
results: dict[str | bytes, str | bytes],
) -> Optional[List[Generation]]:
generations = []
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
if results:
for _, text in results.items():
try:
@@ -410,28 +388,69 @@ class RedisCache(BaseCache):
generations.append(Generation(text=text))
return generations if generations else None
@staticmethod
def _configure_pipeline_for_update(
key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
) -> None:
pipe.hset(
key,
mapping={
str(idx): dumps(generation) for idx, generation in enumerate(return_val)
},
)
if ttl is not None:
pipe.expire(key, ttl)
class RedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows to use a sync `redis.Redis` client.
"""
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of RedisCache.
This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.Redis`), allowing the object
to interact with a Redis server for caching purposes.
Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis import Redis
except ImportError:
raise ValueError(
"Could not import `redis` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass a valid `redis.Redis` client.")
self.redis = redis_
self.ttl = ttl
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results) # type: ignore[arg-type]
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
# Write to a Redis HASH
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)
with self.redis.pipeline() as pipe:
pipe.hset(
key,
mapping={
str(idx): dumps(generation)
for idx, generation in enumerate(return_val)
},
)
if self.ttl is not None:
pipe.expire(key, self.ttl)
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
pipe.execute()
def clear(self, **kwargs: Any) -> None:
@@ -440,6 +459,89 @@ class RedisCache(BaseCache):
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class AsyncRedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows to use an
async `redis.asyncio.Redis` client.
"""
def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of AsyncRedisCache.
This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.asyncio.Redis`), allowing the object
to interact with a Redis server for caching purposes.
Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.asyncio.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis.asyncio import Redis
except ImportError:
raise ValueError(
"Could not import `redis.asyncio` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass a valid `redis.asyncio.Redis` client.")
self.redis = redis_
self.ttl = ttl
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `lookup()` method. "
"Consider using the async `alookup()` version."
)
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string. Async version."""
results = await self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results) # type: ignore[arg-type]
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `update()` method. "
"Consider using the async `aupdate()` version."
)
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string. Async version."""
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)
async with self.redis.pipeline() as pipe:
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
await pipe.execute() # type: ignore[attr-defined]
def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
raise NotImplementedError(
"This async Redis cache does not implement `clear()` method. "
"Consider using the async `aclear()` version."
)
async def aclear(self, **kwargs: Any) -> None:
"""
Clear cache. If `asynchronous` is True, flush asynchronously.
Async version.
"""
asynchronous = kwargs.get("asynchronous", False)
await self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""