diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index da01bd7098a..de057275225 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -2,7 +2,7 @@ """Load tools.""" import warnings from typing import Any, Dict, List, Optional, Callable, Tuple -from mypy_extensions import KwArg +from mypy_extensions import Arg, KwArg from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager @@ -74,7 +74,7 @@ def _get_terminal() -> BaseTool: ) -_BASE_TOOLS = { +_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { "python_repl": _get_python_repl, "requests": _get_tools_requests_get, # preserved for backwards compatability "requests_get": _get_tools_requests_get, @@ -120,7 +120,7 @@ def _get_open_meteo_api(llm: BaseLLM) -> BaseTool: ) -_LLM_TOOLS = { +_LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = { "pal-math": _get_pal_math, "pal-colored-objects": _get_pal_colored_objects, "llm-math": _get_llm_math, @@ -226,7 +226,9 @@ def _get_human_tool(**kwargs: Any) -> BaseTool: return HumanInputRun(**kwargs) -_EXTRA_LLM_TOOLS = { +_EXTRA_LLM_TOOLS: Dict[ + str, Tuple[Callable[[Arg(BaseLLM, "llm"), KwArg(Any)], BaseTool], List[str]] +] = { "news-api": (_get_news_api, ["news_api_key"]), "tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]), "podcast-api": (_get_podcast_api, ["listen_api_key"]),