commit 7069a0661d
Sijil S, 2025-07-28 17:48:23 -07:00 (committed via GitHub)
@@ -3,8 +3,8 @@
import asyncio
import re
import time
-from collections.abc import AsyncIterator, Iterator
-from typing import Any, Optional, Union, cast
+from collections.abc import AsyncIterator, Iterator, Sequence
+from typing import Any, Callable, Dict, Optional, Union, cast
from typing_extensions import override
@@ -17,6 +17,12 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import RunnableConfig
+from langchain_core.language_models.base import LanguageModelInput
+from langchain_core.runnables import Runnable, RunnableConfig
+from langchain_core.runnables.config import run_in_executor
+from langchain_core.tools import BaseTool
class FakeMessagesListChatModel(BaseChatModel):
    """Fake ChatModel for testing purposes."""
@@ -367,3 +373,191 @@ class ParrotFakeChatModel(BaseChatModel):
    @property
    def _llm_type(self) -> str:
        return "parrot-fake-chat-model"


class FakeToolCallingListChatModel(BaseChatModel):
    """Fake tool-calling ChatModel for testing purposes."""

    responses: list[Union[dict, str, BaseMessage]]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Optional delay, in seconds, between streamed chunks."""
    i: int = 0
    """Internally incremented after every model invocation."""
    error_on_chunk_number: Optional[int] = None
    """If set, raise a ``FakeListChatModelError`` at this chunk index while streaming."""

    @property
    @override
    def _llm_type(self) -> str:
        return "fake-tool-calling-list-chat-model"

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        response = self.responses[self.i]
        # Advance the response index, wrapping around to cycle.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        if isinstance(response, BaseMessage):
            message = response
        elif isinstance(response, str):
            message = AIMessage(content=response)
        elif isinstance(response, dict):
            message = AIMessage(
                content=response["content"], tool_calls=response["tool_calls"]
            )
        else:
            msg = f"Unsupported response type: {type(response)}"
            raise ValueError(msg)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @override
    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Delegate to the sync implementation in an executor thread.
        return await run_in_executor(
            None,
            self._generate,
            messages,
            stop=stop,
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
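
    # Usage sketch (assumed): the async path mirrors the sync one, since
    # _agenerate simply runs _generate in an executor thread.
    #
    #     result = await model.ainvoke("hello")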

    @override
    def _stream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        # Advance the response index, wrapping around to cycle.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        tool_calls = []
        if isinstance(response, BaseMessage):
            output = response.content
        elif isinstance(response, str):
            output = response
        elif isinstance(response, dict):
            output = response["content"]
            tool_calls = response["tool_calls"]
        else:
            msg = f"Unsupported response type: {type(response)}"
            raise ValueError(msg)
        for i_c, c in enumerate(output):
            if self.sleep is not None:
                time.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
        # Emit tool calls once, in a trailing chunk, so that merging the
        # streamed chunks does not duplicate them (and so tool calls still
        # surface when the content is empty).
        if tool_calls:
            yield ChatGenerationChunk(
                message=AIMessageChunk(content="", tool_calls=tool_calls)
            )
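
    # Streaming sketch (assumed names): content arrives one character per
    # chunk, then a single trailing chunk carries the tool calls, if any.
    #
    #     for chunk in model.stream("add 1 and 2"):
    #         print(repr(chunk.content), chunk.tool_calls)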

    @override
    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: Union[list[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        # Advance the response index, wrapping around to cycle.
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        tool_calls = []
        if isinstance(response, BaseMessage):
            output = response.content
        elif isinstance(response, str):
            output = response
        elif isinstance(response, dict):
            output = response["content"]
            tool_calls = response["tool_calls"]
        else:
            msg = f"Unsupported response type: {type(response)}"
            raise ValueError(msg)
        for i_c, c in enumerate(output):
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            if (
                self.error_on_chunk_number is not None
                and i_c == self.error_on_chunk_number
            ):
                raise FakeListChatModelError
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))
        # As in _stream, emit tool calls once in a trailing chunk.
        if tool_calls:
            yield ChatGenerationChunk(
                message=AIMessageChunk(content="", tool_calls=tool_calls)
            )

    @property
    @override
    def _identifying_params(self) -> dict[str, Any]:
        return {"responses": self.responses}

    @override
    def batch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        # Manually override batch to preserve input order with no concurrency.
        if isinstance(config, list):
            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
        return [self.invoke(m, config, **kwargs) for m in inputs]
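
    # Ordering sketch (assumed): outputs line up one-to-one with inputs, and
    # responses are consumed in cycle order.
    #
    #     model = FakeToolCallingListChatModel(responses=["a", "b"])
    #     [m.content for m in model.batch(["x", "y"])]  # -> ["a", "b"]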

    @override
    async def abatch(
        self,
        inputs: list[Any],
        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> list[BaseMessage]:
        # Await sequentially rather than gathering concurrently, because the
        # cycling response index requires explicit ordering.
        if isinstance(config, list):
            return [
                await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)
            ]
        return [await self.ainvoke(m, config, **kwargs) for m in inputs]

    @override
    def bind_tools(
        self,
        tools: Sequence[
            Union[Dict[str, Any], type, Callable, BaseTool]  # noqa: UP006
        ],
        *,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the model.

        Args:
            tools: Sequence of tools to bind to the model.
            tool_choice: The tool to use. If "any", any bound tool may be used.

        Returns:
            A Runnable that returns a message.
        """
        kwargs["tool_choice"] = tool_choice
        return super().bind(tools=tools, **kwargs)
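
# A hedged bind_tools sketch: the `add` tool below is illustrative, not part
# of this module. bind_tools defers to Runnable.bind(), so the bound model
# carries `tools` and `tool_choice` through to every invocation.
#
#     from langchain_core.tools import tool
#
#     @tool
#     def add(a: int, b: int) -> int:
#         """Add two integers."""
#         return a + b
#
#     bound = FakeToolCallingListChatModel(responses=["ok"]).bind_tools([add])
#     bound.invoke("add 1 and 2")  # -> AIMessage("ok")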