mongodb[patch]: Remove in-memory cache from cache abstractions (#18987)

## Description
* The in-memory cache easily gets out of sync with the server-side cache, so we
are removing it entirely to reduce issues caused by invalidated caches.

## Dependencies
None

- [x] If you're adding a new integration, please include:
  1. a test for the integration, preferably unit tests that do not rely on
     network access, and
  2. an example notebook showing its use, which lives in the
     `docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jib 2024-03-18 15:44:34 -04:00 committed by GitHub
parent 866d6408af
commit ec026004cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 0 additions and 29 deletions

View File

@@ -130,7 +130,6 @@ class MongoDBCache(BaseCache):
     PROMPT = "prompt"
     LLM = "llm"
     RETURN_VAL = "return_val"
-    _local_cache: Dict[str, Any]

     def __init__(
         self,
@@ -153,7 +152,6 @@ class MongoDBCache(BaseCache):
         self.client = _generate_mongo_client(connection_string)
         self.__database_name = database_name
         self.__collection_name = collection_name
-        self._local_cache = {}

         if self.__collection_name not in self.database.list_collection_names():
             self.database.create_collection(self.__collection_name)
@@ -172,10 +170,6 @@ class MongoDBCache(BaseCache):
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        if cache_key in self._local_cache:
-            return self._local_cache[cache_key]
-
         return_doc = (
             self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
         )
@@ -184,9 +178,6 @@ class MongoDBCache(BaseCache):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        self._local_cache[cache_key] = return_val
-
         self.collection.update_one(
             {**self._generate_keys(prompt, llm_string)},
             {"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
@@ -197,10 +188,6 @@ class MongoDBCache(BaseCache):
         """Create keyed fields for caching layer"""
         return {self.PROMPT: prompt, self.LLM: llm_string}

-    def _generate_local_key(self, prompt: str, llm_string: str) -> str:
-        """Create keyed fields for local caching layer"""
-        return f"{prompt}#{llm_string}"
-
     def clear(self, **kwargs: Any) -> None:
         """Clear cache that can take additional keyword arguments.
         Any additional arguments will propagate as filtration criteria for
@@ -221,7 +208,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
     LLM = "llm_string"
     RETURN_VAL = "return_val"
-    _local_cache: Dict[str, Any]

     def __init__(
         self,
@@ -250,20 +236,15 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         self.collection = client[database_name][collection_name]
         self._wait_until_ready = wait_until_ready
         super().__init__(self.collection, embedding, **kwargs)  # type: ignore
-        self._local_cache = dict()

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        if cache_key in self._local_cache:
-            return self._local_cache[cache_key]
-
         search_response = self.similarity_search_with_score(
             prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
         )
         if search_response:
             return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
             response = _loads_generations(return_val) or return_val  # type: ignore
-            self._local_cache[cache_key] = response
             return response
         return None
@@ -275,9 +256,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         wait_until_ready: Optional[bool] = None,
     ) -> None:
         """Update cache based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        self._local_cache[cache_key] = return_val
-
         self.add_texts(
             [prompt],
             [
@@ -295,10 +273,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         if wait:
             _wait_until(is_indexed, return_val)

-    def _generate_local_key(self, prompt: str, llm_string: str) -> str:
-        """Create keyed fields for local caching layer"""
-        return f"{prompt}#{llm_string}"
-
     def clear(self, **kwargs: Any) -> None:
         """Clear cache that can take additional keyword arguments.
         Any additional arguments will propagate as filtration criteria for
@@ -309,4 +283,3 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         self.clear(llm_string="fake-model")
         """
         self.collection.delete_many({**kwargs})
-        self._local_cache.clear()

View File

@@ -30,7 +30,6 @@ class PatchedMongoDBCache(MongoDBCache):
         self.__database_name = database_name
         self.__collection_name = collection_name
         self.client = {self.__database_name: {self.__collection_name: MockCollection()}}  # type: ignore
-        self._local_cache = {}

     @property
     def database(self) -> Any:  # type: ignore
@@ -55,7 +54,6 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
     ):
         self.collection = MockCollection()
         self._wait_until_ready = False
-        self._local_cache = dict()
         MongoDBAtlasVectorSearch.__init__(
             self,
             self.collection,