diff --git a/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py b/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py index be594e09212..9ad45ae4e1a 100644 --- a/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py +++ b/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py @@ -68,6 +68,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): session_id_key: str = DEFAULT_SESSION_ID_KEY, history_key: str = DEFAULT_HISTORY_KEY, create_index: bool = True, + history_size: Optional[int] = None, index_kwargs: Optional[Dict] = None, ): """Initialize with a MongoDBChatMessageHistory instance. @@ -88,6 +89,8 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): name of the field that stores the chat history. create_index: Optional[bool] whether to create an index on the session id field. + history_size: Optional[int] + count of (most recent) messages to fetch from MongoDB. index_kwargs: Optional[Dict] additional keyword arguments to pass to the index creation. """ @@ -97,6 +100,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): self.collection_name = collection_name self.session_id_key = session_id_key self.history_key = history_key + self.history_size = history_size try: self.client: MongoClient = MongoClient(connection_string) @@ -114,7 +118,15 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve the messages from MongoDB""" try: - cursor = self.collection.find({self.session_id_key: self.session_id}) + if self.history_size is None: + cursor = self.collection.find({self.session_id_key: self.session_id}) + else: + skip_count = max( + 0, self.collection.count_documents({}) - self.history_size + ) + cursor = self.collection.find( + {self.session_id_key: self.session_id}, skip=skip_count + ) except errors.OperationFailure as error: logger.error(error) diff --git a/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py b/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py index 4f3fa4af2b3..2031602c2b2 100644 --- a/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py +++ b/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py @@ -16,6 +16,7 @@ class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory): self.collection = MockCollection() self.session_id_key = "SessionId" self.history_key = "History" + self.history_size = None def test_memory_with_message_store() -> None: