diff --git a/libs/community/langchain_community/tools/yahoo_finance_news.py b/libs/community/langchain_community/tools/yahoo_finance_news.py index c470fa01d2b..4e120e45ce4 100644 --- a/libs/community/langchain_community/tools/yahoo_finance_news.py +++ b/libs/community/langchain_community/tools/yahoo_finance_news.py @@ -1,7 +1,8 @@ -from typing import Iterable, Optional +from typing import Iterable, Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.documents import Document +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from requests.exceptions import HTTPError, ReadTimeout from urllib3.exceptions import ConnectionError @@ -9,6 +10,12 @@ from urllib3.exceptions import ConnectionError from langchain_community.document_loaders.web_base import WebBaseLoader +class YahooFinanceNewsInput(BaseModel): + """Input for the YahooFinanceNews tool.""" + + query: str = Field(description="company ticker query to look up") + + class YahooFinanceNewsTool(BaseTool): """Tool that searches financial news on Yahoo Finance.""" @@ -22,6 +29,8 @@ class YahooFinanceNewsTool(BaseTool): top_k: int = 10 """The number of results to return.""" + args_schema: Type[BaseModel] = YahooFinanceNewsInput + def _run( self, query: str,