From bebf46c4a2c0b6eb3274df74a38dc6ca162262f5 Mon Sep 17 00:00:00 2001 From: Pedro Lima Date: Mon, 6 May 2024 21:27:54 +0100 Subject: [PATCH] community: added args_schema to YahooFinanceNewsTool (#21232) Description: this change adds args_schema (pydantic BaseModel) to YahooFinanceNewsTool for correct schema formatting on LLM function calls Issue: currently using YahooFinanceNewsTool with OpenAI function calling returns the following error "TypeError("YahooFinanceNewsTool._run() got an unexpected keyword argument '__arg1'")". This happens because the schema sent to the LLM is "input: "{'__arg1': 'MSFT'}"" while the method should be called with the "query" parameter. Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../langchain_community/tools/yahoo_finance_news.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/tools/yahoo_finance_news.py b/libs/community/langchain_community/tools/yahoo_finance_news.py index c470fa01d2b..4e120e45ce4 100644 --- a/libs/community/langchain_community/tools/yahoo_finance_news.py +++ b/libs/community/langchain_community/tools/yahoo_finance_news.py @@ -1,7 +1,8 @@ -from typing import Iterable, Optional +from typing import Iterable, Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.documents import Document +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from requests.exceptions import HTTPError, ReadTimeout from urllib3.exceptions import ConnectionError @@ -9,6 +10,12 @@ from urllib3.exceptions import ConnectionError from langchain_community.document_loaders.web_base import WebBaseLoader +class YahooFinanceNewsInput(BaseModel): + """Input for the YahooFinanceNews tool.""" + + query: str = Field(description="company ticker query to look up") + + class YahooFinanceNewsTool(BaseTool): """Tool that searches financial news on Yahoo Finance.""" @@ -22,6 +29,8 @@ class YahooFinanceNewsTool(BaseTool): top_k: int = 10 """The number of results to return.""" + args_schema: Type[BaseModel] = YahooFinanceNewsInput + def _run( self, query: str,