Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-12 12:59:07 +00:00)
community[patch]: Add OpenSearch as semantic cache (#20254)
### Description

Use OpenSearch vector store as Semantic Cache.

### Twitter Handle

**@OpenSearchProj**

---------

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
Co-authored-by: Harish Tatikonda <harishtatikonda@Harishs-MacBook-Air.local>
Co-authored-by: EC2 Default User <ec2-user@ip-172-31-31-155.ec2.internal>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
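For reference, a minimal usage sketch (not part of this diff, and hedged): it assumes an OpenSearch cluster reachable at `http://localhost:9200`, the `langchain-openai` package, and an `OPENAI_API_KEY` in the environment, and it uses `set_llm_cache` rather than the older `langchain.llm_cache` assignment shown in the class docstring below.

```python
# Hedged usage sketch: wire the new OpenSearchSemanticCache into LangChain's
# global LLM cache. Assumes OpenSearch at localhost:9200 and an OpenAI key.
from langchain_community.cache import OpenSearchSemanticCache
from langchain_core.globals import set_llm_cache
from langchain_openai import OpenAIEmbeddings

set_llm_cache(
    OpenSearchSemanticCache(
        opensearch_url="http://localhost:9200",
        embedding=OpenAIEmbeddings(),
        score_threshold=0.2,  # minimum similarity score for a cache hit
    )
)
```

With the cache set, subsequent LLM calls made through LangChain first check the OpenSearch index for a semantically similar prompt before invoking the model.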
@@ -19,6 +19,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons.

    BaseCache --> <name>Cache  # Examples: InMemoryCache, RedisCache, GPTCache
"""
from __future__ import annotations

import hashlib

@@ -76,6 +77,9 @@ from langchain_community.utilities.astradb import (
    _AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_community.vectorstores import (
    OpenSearchVectorSearch as OpenSearchVectorStore,
)
from langchain_community.vectorstores.redis import Redis as RedisVectorstore

logger = logging.getLogger(__file__)

@@ -2049,3 +2053,105 @@ class AzureCosmosDBSemanticCache(BaseCache):

def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
    if not isinstance(value, enum_type):
        raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")


class OpenSearchSemanticCache(BaseCache):
    """Cache that uses OpenSearch vector store backend."""

    def __init__(
        self, opensearch_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """
        Args:
            opensearch_url (str): URL to connect to OpenSearch.
            embedding (Embeddings): Embedding provider for semantic encoding and search.
            score_threshold (float): Minimum similarity score for a cached result
                to count as a hit. Defaults to 0.2.
        Example:
            .. code-block:: python

                import langchain
                from langchain.cache import OpenSearchSemanticCache
                from langchain.embeddings import OpenAIEmbeddings

                langchain.llm_cache = OpenSearchSemanticCache(
                    opensearch_url="http://localhost:9200",
                    embedding=OpenAIEmbeddings()
                )
        """
        self._cache_dict: Dict[str, OpenSearchVectorStore] = {}
        self.opensearch_url = opensearch_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache_{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> OpenSearchVectorStore:
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        self._cache_dict[index_name] = OpenSearchVectorStore(
            opensearch_url=self.opensearch_url,
            index_name=index_name,
            embedding_function=self.embedding,
        )

        # create index for the vectorstore
        vectorstore = self._cache_dict[index_name]
        if not vectorstore.index_exists():
            _embedding = self.embedding.embed_query(text="test")
            vectorstore.create_index(len(_embedding), index_name)
        return vectorstore

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Search the cache index for semantically similar prompts
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
            score_threshold=self.score_threshold,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )

                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "OpenSearchSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].delete_index(index_name=index_name)
            del self._cache_dict[index_name]
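To make the `lookup`/`update` contract concrete, here is a small round-trip sketch (again a hedged example, not part of the PR; it assumes a reachable OpenSearch cluster and real embeddings) that calls the cache directly instead of going through an LLM:

```python
# Hedged round-trip sketch for OpenSearchSemanticCache, assuming OpenSearch is
# reachable at localhost:9200 and langchain-openai is installed.
from langchain_community.cache import OpenSearchSemanticCache
from langchain_core.outputs import Generation
from langchain_openai import OpenAIEmbeddings

cache = OpenSearchSemanticCache(
    opensearch_url="http://localhost:9200",
    embedding=OpenAIEmbeddings(),
)

# Each distinct llm_string gets its own index named "cache_<hash>", created
# lazily by _get_llm_cache on first use.
llm_string = "example-llm-config"

cache.update(
    "What is the capital of France?", llm_string, [Generation(text="Paris")]
)

# lookup runs a k=1 similarity search filtered by score_threshold, so a
# semantically similar prompt can still hit the cached entry.
hit = cache.lookup("What's France's capital city?", llm_string)
print(hit)  # [Generation(text="Paris")] on a hit, otherwise None
```

Whether the paraphrased prompt actually hits depends on the embedding model and the `score_threshold` chosen at construction time.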