diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 721176af33d..d683fbdac34 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -3,7 +3,7 @@ import asyncio import re import time -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from typing import Any, Optional, Union, cast from typing_extensions import override @@ -15,9 +15,8 @@ from langchain_core.callbacks import ( from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.runnables import RunnableConfig, Runnable, RunnableLambda +from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda from langchain_core.tools import BaseTool -from collections.abc import Sequence, Callable class FakeMessagesListChatModel(BaseChatModel): @@ -217,25 +216,23 @@ class FakeChatModel(SimpleChatModel): choice_input: Optional[Union[dict[str, Any], tuple]] = None, **kwargs: Any, ) -> Runnable: - def _call_tool_chain(input: dict) -> dict: + def _call_tool_chain(tool_input: dict) -> dict: if tool_choice is None: return {} if choice_input is None: - raise ValueError( - "choice_input must be provided if tool_choice is specified." - ) + msg = "choice_input must be provided if tool_choice is specified." + raise ValueError(msg) tool = next(t for t in tools if t.name == tool_choice) - params = choice_input if choice_input is not None else input + params = choice_input if choice_input is not None else tool_input if hasattr(tool, "run"): return tool.run(params) - else: - raise ValueError( - f"Tool {tool} is not callable or does not have a 'run' method." - ) + + f"Tool {tool} is not callable or does not have a 'run' method." + raise ValueError(msg) return RunnableLambda(func=_call_tool_chain)