diff --git a/libs/community/langchain_community/tools/tavily_search/tool.py b/libs/community/langchain_community/tools/tavily_search/tool.py index 765f99daecf..32cd21d4197 100644 --- a/libs/community/langchain_community/tools/tavily_search/tool.py +++ b/libs/community/langchain_community/tools/tavily_search/tool.py @@ -1,6 +1,6 @@ """Tool for the Tavily search API.""" -from typing import Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -149,6 +149,15 @@ class TavilySearchResults(BaseTool): # type: ignore[override, override] api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper) # type: ignore[arg-type] response_format: Literal["content_and_artifact"] = "content_and_artifact" + def __init__(self, **kwargs: Any) -> None: + # Create api_wrapper with tavily_api_key if provided + if "tavily_api_key" in kwargs: + kwargs["api_wrapper"] = TavilySearchAPIWrapper( + tavily_api_key=kwargs["tavily_api_key"] + ) + + super().__init__(**kwargs) + def _run( self, query: str,