mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 04:55:14 +00:00
cosmosdbnosql: Added Cosmos DB NoSQL Semantic Cache Integration with tests and jupyter notebook (#24424)
* Added Cosmos DB NoSQL Semantic Cache Integration with tests and jupyter notebook --------- Co-authored-by: Aayush Kataria <aayushkataria3011@gmail.com> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
@@ -80,7 +80,10 @@ from langchain_community.utilities.astradb import (
|
||||
from langchain_community.utilities.astradb import (
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
|
||||
from langchain_community.vectorstores import (
|
||||
AzureCosmosDBNoSqlVectorSearch,
|
||||
AzureCosmosDBVectorSearch,
|
||||
)
|
||||
from langchain_community.vectorstores import (
|
||||
OpenSearchVectorSearch as OpenSearchVectorStore,
|
||||
)
|
||||
@@ -93,6 +96,7 @@ if TYPE_CHECKING:
|
||||
import momento
|
||||
import pymemcache
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
from azure.cosmos.cosmos_client import CosmosClient
|
||||
from cassandra.cluster import Session as CassandraSession
|
||||
|
||||
|
||||
@@ -2103,7 +2107,7 @@ class AzureCosmosDBSemanticCache(BaseCache):
|
||||
ef_construction: int = 64,
|
||||
ef_search: int = 40,
|
||||
score_threshold: Optional[float] = None,
|
||||
application_name: str = "LANGCHAIN_CACHING_PYTHON",
|
||||
application_name: str = "LangChain-CDBMongoVCore-SemanticCache-Python",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -2268,7 +2272,6 @@ class AzureCosmosDBSemanticCache(BaseCache):
|
||||
index_name = self._index_name(kwargs["llm_string"])
|
||||
if index_name in self._cache_dict:
|
||||
self._cache_dict[index_name].get_collection().delete_many({})
|
||||
# self._cache_dict[index_name].clear_collection()
|
||||
|
||||
@staticmethod
|
||||
def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
|
||||
@@ -2276,6 +2279,111 @@ class AzureCosmosDBSemanticCache(BaseCache):
|
||||
raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")
|
||||
|
||||
|
||||
class AzureCosmosDBNoSqlSemanticCache(BaseCache):
    """Cache that uses Cosmos DB NoSQL backend.

    Cached generations are stored as metadata on documents written through an
    ``AzureCosmosDBNoSqlVectorSearch`` vector store. ``lookup`` asks that
    vector store for the single most similar stored prompt (``k=1``); this
    class itself passes no score threshold to the search.
    """

    def __init__(
        self,
        embedding: Embeddings,
        cosmos_client: CosmosClient,
        database_name: str = "CosmosNoSqlCacheDB",
        container_name: str = "CosmosNoSqlCacheContainer",
        *,
        vector_embedding_policy: Dict[str, Any],
        indexing_policy: Dict[str, Any],
        cosmos_container_properties: Dict[str, Any],
        cosmos_database_properties: Dict[str, Any],
        create_container: bool = True,
    ):
        """Store the settings used to build vector stores on demand.

        No vector store is created here; one is built lazily by
        ``_get_llm_cache`` the first time a given ``llm_string`` is seen.

        Args:
            embedding: Embeddings object handed to the vector store.
            cosmos_client: Cosmos client handed to the vector store.
            database_name: Database name handed to the vector store.
            container_name: Container name handed to the vector store.
            vector_embedding_policy: Passed through unchanged to
                ``AzureCosmosDBNoSqlVectorSearch``.
            indexing_policy: Passed through unchanged to
                ``AzureCosmosDBNoSqlVectorSearch``.
            cosmos_container_properties: Passed through unchanged to
                ``AzureCosmosDBNoSqlVectorSearch``.
            cosmos_database_properties: Passed through unchanged to
                ``AzureCosmosDBNoSqlVectorSearch``.
            create_container: Passed through unchanged to
                ``AzureCosmosDBNoSqlVectorSearch``.
        """
        self.cosmos_client = cosmos_client
        self.database_name = database_name
        self.container_name = container_name
        self.embedding = embedding
        self.vector_embedding_policy = vector_embedding_policy
        self.indexing_policy = indexing_policy
        self.cosmos_container_properties = cosmos_container_properties
        self.cosmos_database_properties = cosmos_database_properties
        self.create_container = create_container
        # Maps a cache name (see ``_cache_name``) to its vector store client.
        self._cache_dict: Dict[str, AzureCosmosDBNoSqlVectorSearch] = {}

    def _cache_name(self, llm_string: str) -> str:
        """Return the ``_cache_dict`` key for ``llm_string``."""
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBNoSqlVectorSearch:
        """Return the vector store for ``llm_string``, creating it if needed.

        Raises:
            ValueError: If no vector store exists yet for ``llm_string`` and
                ``cosmos_client`` is not set, so none can be created.
        """
        cache_name = self._cache_name(llm_string)

        # Reuse the vector store client already built for this llm string.
        if cache_name in self._cache_dict:
            return self._cache_dict[cache_name]

        # Fail with a clear message instead of an opaque KeyError on the
        # hashed cache name further down.
        if not self.cosmos_client:
            raise ValueError(
                "A cosmos_client is required to create the Cosmos DB NoSQL "
                "semantic cache."
            )

        # Create a new vector store client to back the cache.
        self._cache_dict[cache_name] = AzureCosmosDBNoSqlVectorSearch(
            cosmos_client=self.cosmos_client,
            embedding=self.embedding,
            vector_embedding_policy=self.vector_embedding_policy,
            indexing_policy=self.indexing_policy,
            cosmos_container_properties=self.cosmos_container_properties,
            cosmos_database_properties=self.cosmos_database_properties,
            database_name=self.database_name,
            container_name=self.container_name,
            create_container=self.create_container,
        )
        return self._cache_dict[cache_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt.

        Returns:
            The generations stored with the most similar cached prompt, or
            ``None`` when the search returns nothing or no generations could
            be loaded.
        """
        llm_cache = self._get_llm_cache(llm_string)
        # Ask the vector store for the single nearest stored prompt.
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
        )
        if not results:
            return None

        generations: List = []
        for document in results:
            try:
                generations.extend(loads(document.metadata["return_val"]))
            except Exception:
                # Deliberate best-effort: fall back to the JSON loader for
                # entries that ``loads`` cannot handle.
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                generations.extend(
                    _load_generations_from_json(document.metadata["return_val"])
                )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Raises:
            ValueError: If any element of ``return_val`` is not a
                ``Generation``.
        """
        # Validate everything first so nothing is written on bad input.
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "CosmosDBNoSqlSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps(list(return_val)),
        }
        # The prompt is the embedded text; the generations ride along as
        # metadata and are read back by ``lookup``.
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string.

        Args:
            **kwargs: Must contain ``llm_string``.

        Raises:
            KeyError: If ``llm_string`` is not supplied.
        """
        cache_name = self._cache_name(llm_string=kwargs["llm_string"])
        if cache_name in self._cache_dict:
            container = self._cache_dict[cache_name].get_container()
            # NOTE(review): every vector store is built with the same
            # database/container names, so this presumably removes entries
            # for all llm_strings, not only this one - confirm.
            # NOTE(review): azure-cosmos' delete_item appears to take a
            # partition_key argument that is not passed here - confirm
            # against the SDK version in use and the container's partition
            # key path.
            for item in container.read_all_items():
                container.delete_item(item)
|
||||
|
||||
|
||||
class OpenSearchSemanticCache(BaseCache):
|
||||
"""Cache that uses OpenSearch vector store backend"""
|
||||
|
||||
|
Reference in New Issue
Block a user