diff --git a/libs/langchain/langchain/chat_models/fake.py b/libs/langchain/langchain/chat_models/fake.py index 1fe54fef64d..97596631e8c 100644 --- a/libs/langchain/langchain/chat_models/fake.py +++ b/libs/langchain/langchain/chat_models/fake.py @@ -7,9 +7,35 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.chat_models.base import SimpleChatModel +from langchain.chat_models.base import BaseChatModel, SimpleChatModel +from langchain.schema import ChatResult from langchain.schema.messages import AIMessageChunk, BaseMessage -from langchain.schema.output import ChatGenerationChunk +from langchain.schema.output import ChatGeneration, ChatGenerationChunk + + +class FakeMessagesListChatModel(BaseChatModel): + responses: List[BaseMessage] + sleep: Optional[float] = None + i: int = 0 + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + generation = ChatGeneration(message=response) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + return "fake-messages-list-chat-model" class FakeListChatModel(SimpleChatModel):