diff --git a/libs/community/langchain_community/tools/vectorstore/tool.py b/libs/community/langchain_community/tools/vectorstore/tool.py index e51deaeaa1d..c0cbdb67062 100644 --- a/libs/community/langchain_community/tools/vectorstore/tool.py +++ b/libs/community/langchain_community/tools/vectorstore/tool.py @@ -3,7 +3,10 @@ import json from typing import Any, Dict, Optional -from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool @@ -51,9 +54,30 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool): chain = RetrievalQA.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) - return chain.run( - query, callbacks=run_manager.get_child() if run_manager else None + return chain.invoke( + {chain.input_key: query}, + config={"callbacks": [run_manager.get_child() if run_manager else None]}, + )[chain.output_key] + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + """Use the tool asynchronously.""" + from langchain.chains.retrieval_qa.base import RetrievalQA + + chain = RetrievalQA.from_chain_type( + self.llm, retriever=self.vectorstore.as_retriever() ) + return ( + await chain.ainvoke( + {chain.input_key: query}, + config={ + "callbacks": [run_manager.get_child() if run_manager else None] + }, + ) + )[chain.output_key] class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): @@ -87,7 +111,28 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): self.llm, retriever=self.vectorstore.as_retriever() ) return json.dumps( - chain( + chain.invoke( + {chain.question_key: query}, + return_only_outputs=True, + callbacks=run_manager.get_child() if run_manager else None, + ) + ) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + """Use the tool asynchronously.""" + from langchain.chains.qa_with_sources.retrieval import ( + RetrievalQAWithSourcesChain, + ) + + chain = RetrievalQAWithSourcesChain.from_chain_type( + self.llm, retriever=self.vectorstore.as_retriever() + ) + return json.dumps( + await chain.ainvoke( {chain.question_key: query}, return_only_outputs=True, callbacks=run_manager.get_child() if run_manager else None,