diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 184a4fcb154..d461c3b0952 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -3,8 +3,8 @@ import asyncio import re import time -from collections.abc import AsyncIterator, Iterator -from typing import Any, Optional, Union, cast +from collections.abc import AsyncIterator, Iterator, Sequence +from typing import Any, Callable, Dict, Optional, Union, cast from typing_extensions import override @@ -17,6 +17,12 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import RunnableConfig +from langchain_core.language_models.base import LanguageModelInput + +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables.config import run_in_executor +from langchain_core.tools import BaseTool + class FakeMessagesListChatModel(BaseChatModel): """Fake ChatModel for testing purposes.""" @@ -367,3 +373,191 @@ class ParrotFakeChatModel(BaseChatModel): @property def _llm_type(self) -> str: return "parrot-fake-chat-model" + + +class FakeToolCallingListChatModel(BaseChatModel): + """Fake Tool calling ChatModel for testing purposes.""" + + responses: list[Union[dict,str,BaseMessage]] + """List of responses to **cycle** through in order.""" + sleep: Optional[float] = None + i: int = 0 + """List of responses to **cycle** through in order.""" + error_on_chunk_number: Optional[int] = None + """Internally incremented after every model invocation.""" + + @property + @override + def _llm_type(self) -> str: + return "fake-list-chat-model" + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + + if isinstance(response,BaseMessage): + message = response + elif isinstance(response,str): + message = AIMessage(content=response) + elif isinstance(response,dict): + message = AIMessage(content=response["content"], tool_calls = response["tool_calls"]) + else: + raise ValueError("Incorrect response type") + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @override + async def _agenerate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return await run_in_executor( + None, + self._generate, + messages, + stop=stop, + run_manager=run_manager.get_sync() if run_manager else None, + **kwargs, + ) + + @override + def _stream( + self, + messages: list[BaseMessage], + stop: Union[list[str], None] = None, + run_manager: Union[CallbackManagerForLLMRun, None] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + + tool_calls = [] + if isinstance(response,BaseMessage): + output = response.content + elif isinstance(response,str): + output = response + elif isinstance(response,dict): + output = response["content"] + tool_calls = response["tool_calls"] + else: + raise ValueError("Incorrect response type") + + for i_c, c in enumerate(output): + if self.sleep is not None: + time.sleep(self.sleep) + if ( + self.error_on_chunk_number is not None + and i_c == self.error_on_chunk_number + ): + raise FakeListChatModelError + + yield ChatGenerationChunk(message=AIMessageChunk(content=c, tool_calls=tool_calls)) + + @override + async def _astream( + self, + messages: list[BaseMessage], + stop: Union[list[str], None] = None, + run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + response = self.responses[self.i] + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + + tool_calls = [] + if isinstance(response,BaseMessage): + output = response.content + elif isinstance(response,str): + output = response + elif isinstance(response,dict): + output = response["content"] + tool_calls = response["tool_calls"] + else: + raise ValueError("Incorrect response type") + + for i_c, c in enumerate(output): + if self.sleep is not None: + await asyncio.sleep(self.sleep) + if ( + self.error_on_chunk_number is not None + and i_c == self.error_on_chunk_number + ): + raise FakeListChatModelError + yield ChatGenerationChunk(message=AIMessageChunk(content=c, tool_calls=tool_calls)) + + @property + @override + def _identifying_params(self) -> dict[str, Any]: + return {"responses": self.responses} + + @override + # manually override batch to preserve batch ordering with no concurrency + def batch( + self, + inputs: list[Any], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> list[BaseMessage]: + if isinstance(config, list): + return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)] + return [self.invoke(m, config, **kwargs) for m in inputs] + + @override + async def abatch( + self, + inputs: list[Any], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> list[BaseMessage]: + if isinstance(config, list): + # do Not use an async iterator here because need explicit ordering + return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)] + # do Not use an async iterator here because need explicit ordering + return [await self.ainvoke(m, config, **kwargs) for m in inputs] + + @override + def bind_tools( + self, + tools: Sequence[ + Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 + ], + *, + tool_choice: Optional[Union[str]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + tool_choice: The tool to use. If "any" then any tool can be used. + + Returns: + A Runnable that returns a message. + """ + kwargs["tool_choice"] = tool_choice + raise super().bind(tools=tools, **kwargs) +