From ec026004cb83cc430ed9866fb776018639f79fa6 Mon Sep 17 00:00:00 2001
From: Jib
Date: Mon, 18 Mar 2024 15:44:34 -0400
Subject: [PATCH] mongodb[patch]: Remove in-memory cache from cache abstractions (#18987)

## Description
* The in-memory cache easily gets out of sync with the server-side cache, so we remove it entirely to reduce the issues around invalidated caches.

## Dependencies
None

- [x] If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access,
  2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Co-authored-by: Erick Friis
---
 .../mongodb/langchain_mongodb/cache.py     | 27 -------------------
 .../mongodb/tests/unit_tests/test_cache.py |  2 --
 2 files changed, 29 deletions(-)

diff --git a/libs/partners/mongodb/langchain_mongodb/cache.py b/libs/partners/mongodb/langchain_mongodb/cache.py
index 1017948b175..b7ace63a1e7 100644
--- a/libs/partners/mongodb/langchain_mongodb/cache.py
+++ b/libs/partners/mongodb/langchain_mongodb/cache.py
@@ -130,7 +130,6 @@ class MongoDBCache(BaseCache):
     PROMPT = "prompt"
     LLM = "llm"
     RETURN_VAL = "return_val"
-    _local_cache: Dict[str, Any]
 
     def __init__(
         self,
@@ -153,7 +152,6 @@ class MongoDBCache(BaseCache):
         self.client = _generate_mongo_client(connection_string)
         self.__database_name = database_name
         self.__collection_name = collection_name
-        self._local_cache = {}
 
         if self.__collection_name not in self.database.list_collection_names():
             self.database.create_collection(self.__collection_name)
@@ -172,10 +170,6 @@ class MongoDBCache(BaseCache):
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        if cache_key in self._local_cache:
-            return self._local_cache[cache_key]
-
         return_doc = (
             self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
         )
@@ -184,9 +178,6 @@
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        self._local_cache[cache_key] = return_val
-
         self.collection.update_one(
             {**self._generate_keys(prompt, llm_string)},
             {"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
@@ -197,10 +188,6 @@
         """Create keyed fields for caching layer"""
         return {self.PROMPT: prompt, self.LLM: llm_string}
 
-    def _generate_local_key(self, prompt: str, llm_string: str) -> str:
-        """Create keyed fields for local caching layer"""
-        return f"{prompt}#{llm_string}"
-
     def clear(self, **kwargs: Any) -> None:
         """Clear cache that can take additional keyword arguments.
         Any additional arguments will propagate as filtration criteria for
@@ -221,7 +208,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
 
     LLM = "llm_string"
     RETURN_VAL = "return_val"
-    _local_cache: Dict[str, Any]
 
     def __init__(
         self,
@@ -250,20 +236,15 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         self.collection = client[database_name][collection_name]
         self._wait_until_ready = wait_until_ready
         super().__init__(self.collection, embedding, **kwargs)  # type: ignore
-        self._local_cache = dict()
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        if cache_key in self._local_cache:
-            return self._local_cache[cache_key]
         search_response = self.similarity_search_with_score(
             prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
         )
         if search_response:
             return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
             response = _loads_generations(return_val) or return_val  # type: ignore
-            self._local_cache[cache_key] = response
             return response
         return None
 
@@ -275,9 +256,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         wait_until_ready: Optional[bool] = None,
     ) -> None:
         """Update cache based on prompt and llm_string."""
-        cache_key = self._generate_local_key(prompt, llm_string)
-        self._local_cache[cache_key] = return_val
-
         self.add_texts(
             [prompt],
             [
@@ -295,10 +273,6 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         if wait:
             _wait_until(is_indexed, return_val)
 
-    def _generate_local_key(self, prompt: str, llm_string: str) -> str:
-        """Create keyed fields for local caching layer"""
-        return f"{prompt}#{llm_string}"
-
     def clear(self, **kwargs: Any) -> None:
         """Clear cache that can take additional keyword arguments.
         Any additional arguments will propagate as filtration criteria for
@@ -309,4 +283,3 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
             self.clear(llm_string="fake-model")
         """
         self.collection.delete_many({**kwargs})
-        self._local_cache.clear()
diff --git a/libs/partners/mongodb/tests/unit_tests/test_cache.py b/libs/partners/mongodb/tests/unit_tests/test_cache.py
index 326372a3ed2..6b1932ad8cc 100644
--- a/libs/partners/mongodb/tests/unit_tests/test_cache.py
+++ b/libs/partners/mongodb/tests/unit_tests/test_cache.py
@@ -30,7 +30,6 @@ class PatchedMongoDBCache(MongoDBCache):
         self.__database_name = database_name
         self.__collection_name = collection_name
         self.client = {self.__database_name: {self.__collection_name: MockCollection()}}  # type: ignore
-        self._local_cache = {}
 
     @property
     def database(self) -> Any:  # type: ignore
@@ -55,7 +54,6 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
     ):
         self.collection = MockCollection()
         self._wait_until_ready = False
-        self._local_cache = dict()
         MongoDBAtlasVectorSearch.__init__(
             self,
             self.collection,
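
Note for reviewers: after this patch, `lookup` and `update` always round-trip to MongoDB, so there is no process-local state left to drift from the server. A minimal sketch of the resulting behavior (it needs a reachable MongoDB; the connection string, database/collection names, and prompt/config strings below are placeholders, not part of this patch):

```python
from langchain_core.outputs import Generation
from langchain_mongodb.cache import MongoDBCache

# Placeholder deployment details -- substitute your own cluster and names.
cache = MongoDBCache(
    connection_string="mongodb://localhost:27017",
    database_name="langchain",
    collection_name="llm_cache",
)

# update() now writes straight through to the collection ...
cache.update("my prompt", "my llm config", [Generation(text="cached answer")])

# ... and lookup() always reads back from MongoDB, so a second process
# sharing this collection can never be served a stale in-memory entry.
assert cache.lookup("my prompt", "my llm config") == [Generation(text="cached answer")]

# clear() is likewise visible to every client as soon as it returns;
# there is no local layer left to flush.
cache.clear()
assert cache.lookup("my prompt", "my llm config") is None
```

The trade-off is one server round-trip per `lookup` instead of a dictionary hit, which is the price of making the MongoDB collection the single source of truth across processes.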