community[patch]: Add OpenSearch as semantic cache (#20254)

### Description
Use the OpenSearch vector store as a semantic cache for LLM prompts and responses.
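
A minimal usage sketch (not part of this diff), assuming a local OpenSearch instance at `http://localhost:9200` and the `langchain_openai` package for the embeddings and LLM; `set_llm_cache` from `langchain.globals` is LangChain's standard global cache hook, not something added here:

```python
from langchain.globals import set_llm_cache
from langchain_community.cache import OpenSearchSemanticCache
from langchain_openai import OpenAI, OpenAIEmbeddings

# Register the semantic cache globally: a prompt whose embedding scores above
# score_threshold against a cached prompt reuses the cached generations.
set_llm_cache(
    OpenSearchSemanticCache(
        opensearch_url="http://localhost:9200",  # assumed local OpenSearch endpoint
        embedding=OpenAIEmbeddings(),
        score_threshold=0.2,  # constructor default in this PR
    )
)

llm = OpenAI()
llm.invoke("Tell me a joke")    # first call hits the LLM and populates the cache
llm.invoke("Tell me one joke")  # a semantically similar prompt can be served from the cache
```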

### Twitter Handle
**@OpenSearchProj**

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
Co-authored-by: Harish Tatikonda <harishtatikonda@Harishs-MacBook-Air.local>
Co-authored-by: EC2 Default User <ec2-user@ip-172-31-31-155.ec2.internal>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Naveen Tatikonda
Date: 2024-04-26 19:20:24 -05:00 (committed by GitHub)
Commit: 8bbdb4f6a0 (parent 61f14f00d7)
4 changed files with 431 additions and 21 deletions


@@ -19,6 +19,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
BaseCache --> <name>Cache # Examples: InMemoryCache, RedisCache, GPTCache
"""
from __future__ import annotations
import hashlib
@@ -76,6 +77,9 @@ from langchain_community.utilities.astradb import (
    _AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_community.vectorstores import (
    OpenSearchVectorSearch as OpenSearchVectorStore,
)
from langchain_community.vectorstores.redis import Redis as RedisVectorstore

logger = logging.getLogger(__file__)
@@ -2049,3 +2053,105 @@ class AzureCosmosDBSemanticCache(BaseCache):
    def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
        if not isinstance(value, enum_type):
            raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")


class OpenSearchSemanticCache(BaseCache):
    """Cache that uses OpenSearch vector store backend."""

    def __init__(
        self, opensearch_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """
        Args:
            opensearch_url (str): URL to connect to OpenSearch.
            embedding (Embeddings): Embedding provider for
                semantic encoding and search.
            score_threshold (float, optional): Minimum similarity score a cached
                entry must reach to be returned as a hit. Defaults to 0.2.

        Example:
            .. code-block:: python

                import langchain
                from langchain.cache import OpenSearchSemanticCache
                from langchain.embeddings import OpenAIEmbeddings

                langchain.llm_cache = OpenSearchSemanticCache(
                    opensearch_url="http://localhost:9200",
                    embedding=OpenAIEmbeddings()
                )
        """
        self._cache_dict: Dict[str, OpenSearchVectorStore] = {}
        self.opensearch_url = opensearch_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache_{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> OpenSearchVectorStore:
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        self._cache_dict[index_name] = OpenSearchVectorStore(
            opensearch_url=self.opensearch_url,
            index_name=index_name,
            embedding_function=self.embedding,
        )

        # create index for the vectorstore
        vectorstore = self._cache_dict[index_name]
        if not vectorstore.index_exists():
            _embedding = self.embedding.embed_query(text="test")
            vectorstore.create_index(len(_embedding), index_name)
        return vectorstore

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Retrieve the closest cached prompt for this llm_string's index
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            score_threshold=self.score_threshold,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "OpenSearchSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].delete_index(index_name=index_name)
            del self._cache_dict[index_name]
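
For orientation, a hedged sketch of how the new `update`/`lookup`/`clear` interface behaves when driven directly, assuming a local OpenSearch at `http://localhost:9200` and OpenAI embeddings via `langchain_openai` (the prompt, answer, and `llm_string` values are illustrative only):

```python
from langchain_community.cache import OpenSearchSemanticCache
from langchain_core.outputs import Generation
from langchain_openai import OpenAIEmbeddings

cache = OpenSearchSemanticCache(
    opensearch_url="http://localhost:9200",  # assumed local OpenSearch endpoint
    embedding=OpenAIEmbeddings(),
)

llm_string = "example-llm-config"  # normally the serialized LLM parameters

# update() embeds the prompt and stores it in a per-llm_string index
# ("cache_<hash>") with the serialized generations kept as metadata.
cache.update("What is 2 + 2?", llm_string, [Generation(text="4")])

# lookup() embeds the prompt, runs a k=1 similarity search against that index,
# and returns the cached generations when the score clears score_threshold.
print(cache.lookup("What is 2 + 2?", llm_string))

# clear() deletes the backing index for the given llm_string.
cache.clear(llm_string=llm_string)
```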