This commit is contained in:
amosctlee 2025-07-28 22:00:37 +08:00
parent 5617933649
commit 8fdf4c274e

View File

@ -3,7 +3,7 @@
import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import Any, Optional, Union, cast
from typing_extensions import override
@ -15,9 +15,8 @@ from langchain_core.callbacks import (
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig, Runnable, RunnableLambda
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from collections.abc import Sequence, Callable
class FakeMessagesListChatModel(BaseChatModel):
@ -217,25 +216,23 @@ class FakeChatModel(SimpleChatModel):
choice_input: Optional[Union[dict[str, Any], tuple]] = None,
**kwargs: Any,
) -> Runnable:
def _call_tool_chain(input: dict) -> dict:
def _call_tool_chain(tool_input: dict) -> dict:
if tool_choice is None:
return {}
if choice_input is None:
raise ValueError(
"choice_input must be provided if tool_choice is specified."
)
msg = "choice_input must be provided if tool_choice is specified."
raise ValueError(msg)
tool = next(t for t in tools if t.name == tool_choice)
params = choice_input if choice_input is not None else input
params = choice_input if choice_input is not None else tool_input
if hasattr(tool, "run"):
return tool.run(params)
else:
raise ValueError(
f"Tool {tool} is not callable or does not have a 'run' method."
)
raise ValueError(msg)
return RunnableLambda(func=_call_tool_chain)