langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache (#15817)

* This PR adds async methods to the LLM cache (a usage sketch follows the commit details below).
* Adds a Redis-backed implementation called AsyncRedisCache.
* Adds a docker compose file under /docker to make it easier to spin up the needed Docker services locally.
* Updates the Redis tests to use a context manager so the cache is always flushed by default.
Author: Dmitry Kankalovich
Date: 2024-02-08 04:06:09 +01:00
Committer: GitHub
Parent: 19546081c6
Commit: f92738a6f6
8 changed files with 472 additions and 133 deletions
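
For orientation, here is a minimal, hedged sketch of how the new async cache might be wired up from application code. The `set_llm_cache` import path and the `redis_` keyword of `AsyncRedisCache` are assumptions modeled on the existing sync `RedisCache`; they are not shown in this diff and may differ.

# Hedged sketch (not from this diff): configuring the new AsyncRedisCache.
# Assumes AsyncRedisCache accepts an async Redis client via `redis_`,
# mirroring the sync RedisCache; check the actual signature before use.
import asyncio

from redis.asyncio import Redis
from langchain_community.cache import AsyncRedisCache
from langchain_core.globals import set_llm_cache


async def main() -> None:
    client = Redis.from_url("redis://localhost:6379")
    set_llm_cache(AsyncRedisCache(redis_=client))
    # From here on, async LLM / chat model calls go through the cache's
    # alookup()/aupdate() methods instead of the blocking lookup()/update().
    await client.close()


asyncio.run(main())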


@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Optional, Sequence
 
 from langchain_core.outputs import Generation
+from langchain_core.runnables import run_in_executor
 
 RETURN_VAL_TYPE = Sequence[Generation]
@@ -22,3 +23,17 @@ class BaseCache(ABC):
     @abstractmethod
     def clear(self, **kwargs: Any) -> None:
         """Clear cache that can take additional keyword arguments."""
+
+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        return await run_in_executor(None, self.lookup, prompt, llm_string)
+
+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        """Update cache based on prompt and llm_string."""
+        return await run_in_executor(None, self.update, prompt, llm_string, return_val)
+
+    async def aclear(self, **kwargs: Any) -> None:
+        """Clear cache that can take additional keyword arguments."""
+        return await run_in_executor(None, self.clear, **kwargs)
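
Because these defaults fall back to run_in_executor, any existing cache that only implements the sync interface keeps working when called from async code. A minimal sketch (the DictCache class below is illustrative, not part of the PR):

# Hedged sketch: a sync-only cache still works from async code because the
# new BaseCache.alookup/aupdate/aclear defaults run the sync methods in an
# executor via run_in_executor.
import asyncio
from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.outputs import Generation


class DictCache(BaseCache):
    """Illustrative in-memory cache implementing only the sync interface."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


async def demo() -> None:
    cache = DictCache()
    # The inherited async defaults delegate to the sync methods above.
    await cache.aupdate("hi", "llm-config", [Generation(text="hello")])
    print(await cache.alookup("hi", "llm-config"))


asyncio.run(demo())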


@@ -622,7 +622,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             llm_string = self._get_llm_string(stop=stop, **kwargs)
             prompt = dumps(messages)
-            cache_val = llm_cache.lookup(prompt, llm_string)
+            cache_val = await llm_cache.alookup(prompt, llm_string)
             if isinstance(cache_val, list):
                 return ChatResult(generations=cache_val)
             else:
@@ -632,7 +632,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                     )
                 else:
                     result = await self._agenerate(messages, stop=stop, **kwargs)
-                llm_cache.update(prompt, llm_string, result.generations)
+                await llm_cache.aupdate(prompt, llm_string, result.generations)
                 return result
 
     @abstractmethod
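
The change above makes the async chat path await the cache on both the read and the write side. The sketch below restates that cache-aside flow in isolation (the `cached_generate` helper is illustrative, not an API from this PR):

# Hedged sketch of the cache-aside flow this hunk gives the async chat path:
# await alookup() first; on a miss, run the model and await aupdate().
from typing import Awaitable, Callable, List

from langchain_core.caches import BaseCache
from langchain_core.outputs import Generation


async def cached_generate(
    cache: BaseCache,
    prompt: str,
    llm_string: str,
    generate: Callable[[str], Awaitable[List[Generation]]],
) -> List[Generation]:
    cache_val = await cache.alookup(prompt, llm_string)
    if isinstance(cache_val, list):
        return cache_val                # cache hit: skip the model call
    result = await generate(prompt)     # cache miss: call the model
    await cache.aupdate(prompt, llm_string, result)
    return result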


@@ -139,6 +139,26 @@ def get_prompts(
     return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
 
 
+async def aget_prompts(
+    params: Dict[str, Any], prompts: List[str]
+) -> Tuple[Dict[int, List], str, List[int], List[str]]:
+    """Get prompts that are already cached. Async version."""
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    missing_prompts = []
+    missing_prompt_idxs = []
+    existing_prompts = {}
+    llm_cache = get_llm_cache()
+    for i, prompt in enumerate(prompts):
+        if llm_cache:
+            cache_val = await llm_cache.alookup(prompt, llm_string)
+            if isinstance(cache_val, list):
+                existing_prompts[i] = cache_val
+            else:
+                missing_prompts.append(prompt)
+                missing_prompt_idxs.append(i)
+    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
+
+
 def update_cache(
     existing_prompts: Dict[int, List],
     llm_string: str,
@@ -157,6 +177,24 @@ def update_cache(
     return llm_output
 
 
+async def aupdate_cache(
+    existing_prompts: Dict[int, List],
+    llm_string: str,
+    missing_prompt_idxs: List[int],
+    new_results: LLMResult,
+    prompts: List[str],
+) -> Optional[dict]:
+    """Update the cache and get the LLM output. Async version"""
+    llm_cache = get_llm_cache()
+    for i, result in enumerate(new_results.generations):
+        existing_prompts[missing_prompt_idxs[i]] = result
+        prompt = prompts[missing_prompt_idxs[i]]
+        if llm_cache:
+            await llm_cache.aupdate(prompt, llm_string, result)
+    llm_output = new_results.llm_output
+    return llm_output
+
+
 class BaseLLM(BaseLanguageModel[str], ABC):
     """Base LLM abstract interface.
@@ -869,7 +907,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             llm_string,
             missing_prompt_idxs,
             missing_prompts,
-        ) = get_prompts(params, prompts)
+        ) = await aget_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
@@ -917,7 +955,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             new_results = await self._agenerate_helper(
                 missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
             )
-            llm_output = update_cache(
+            llm_output = await aupdate_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
             run_info = (
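
For backends with a natively async client (the situation AsyncRedisCache addresses), the executor fallback can be skipped by overriding the async methods directly. A hedged sketch of that pattern follows; the client, key scheme, and JSON serialization are illustrative placeholders, not the PR's actual AsyncRedisCache implementation:

# Hedged sketch: a cache with a natively async backend overrides the async
# methods instead of relying on the run_in_executor defaults. This mirrors
# the idea behind AsyncRedisCache but is NOT its actual implementation.
import json
from typing import Any, Optional

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.outputs import Generation


class NativeAsyncCache(BaseCache):
    """Illustrative async-first cache over any object with async get/set/clear."""

    def __init__(self, client: Any) -> None:
        self.client = client  # e.g. an async key-value client (placeholder)

    # Sync methods still exist for callers on the sync path.
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        raise NotImplementedError("use alookup() with this cache")

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        raise NotImplementedError("use aupdate() with this cache")

    def clear(self, **kwargs: Any) -> None:
        raise NotImplementedError("use aclear() with this cache")

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        raw = await self.client.get(f"{prompt}:{llm_string}")
        if raw is None:
            return None
        return [Generation(text=t) for t in json.loads(raw)]

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        await self.client.set(
            f"{prompt}:{llm_string}", json.dumps([g.text for g in return_val])
        )

    async def aclear(self, **kwargs: Any) -> None:
        await self.client.clear()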