This commit is contained in:
amosctlee 2025-07-28 22:00:37 +08:00
parent 5617933649
commit 8fdf4c274e

View File

@ -3,7 +3,7 @@
import asyncio import asyncio
import re import re
import time import time
from collections.abc import AsyncIterator, Iterator from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from typing_extensions import override from typing_extensions import override
@ -15,9 +15,8 @@ from langchain_core.callbacks import (
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig, Runnable, RunnableLambda from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from collections.abc import Sequence, Callable
class FakeMessagesListChatModel(BaseChatModel): class FakeMessagesListChatModel(BaseChatModel):
@ -217,25 +216,23 @@ class FakeChatModel(SimpleChatModel):
choice_input: Optional[Union[dict[str, Any], tuple]] = None, choice_input: Optional[Union[dict[str, Any], tuple]] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable: ) -> Runnable:
def _call_tool_chain(input: dict) -> dict: def _call_tool_chain(tool_input: dict) -> dict:
if tool_choice is None: if tool_choice is None:
return {} return {}
if choice_input is None: if choice_input is None:
raise ValueError( msg = "choice_input must be provided if tool_choice is specified."
"choice_input must be provided if tool_choice is specified." raise ValueError(msg)
)
tool = next(t for t in tools if t.name == tool_choice) tool = next(t for t in tools if t.name == tool_choice)
params = choice_input if choice_input is not None else input params = choice_input if choice_input is not None else tool_input
if hasattr(tool, "run"): if hasattr(tool, "run"):
return tool.run(params) return tool.run(params)
else:
raise ValueError( f"Tool {tool} is not callable or does not have a 'run' method."
f"Tool {tool} is not callable or does not have a 'run' method." raise ValueError(msg)
)
return RunnableLambda(func=_call_tool_chain) return RunnableLambda(func=_call_tool_chain)