mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
community[patch]: Add async methods to VectorStoreQATool (#16949)
This commit is contained in:
parent
fb7552bfcf
commit
ab025507bc
@ -3,7 +3,10 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForToolRun,
|
||||||
|
CallbackManagerForToolRun,
|
||||||
|
)
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@ -51,9 +54,30 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
|
|||||||
chain = RetrievalQA.from_chain_type(
|
chain = RetrievalQA.from_chain_type(
|
||||||
self.llm, retriever=self.vectorstore.as_retriever()
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
)
|
)
|
||||||
return chain.run(
|
return chain.invoke(
|
||||||
query, callbacks=run_manager.get_child() if run_manager else None
|
{chain.input_key: query},
|
||||||
|
config={"callbacks": [run_manager.get_child() if run_manager else None]},
|
||||||
|
)[chain.output_key]
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
from langchain.chains.retrieval_qa.base import RetrievalQA
|
||||||
|
|
||||||
|
chain = RetrievalQA.from_chain_type(
|
||||||
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
)
|
)
|
||||||
|
return (
|
||||||
|
await chain.ainvoke(
|
||||||
|
{chain.input_key: query},
|
||||||
|
config={
|
||||||
|
"callbacks": [run_manager.get_child() if run_manager else None]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)[chain.output_key]
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
||||||
@ -87,7 +111,28 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
|||||||
self.llm, retriever=self.vectorstore.as_retriever()
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
)
|
)
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
chain(
|
chain.invoke(
|
||||||
|
{chain.question_key: query},
|
||||||
|
return_only_outputs=True,
|
||||||
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
from langchain.chains.qa_with_sources.retrieval import (
|
||||||
|
RetrievalQAWithSourcesChain,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = RetrievalQAWithSourcesChain.from_chain_type(
|
||||||
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
|
)
|
||||||
|
return json.dumps(
|
||||||
|
await chain.ainvoke(
|
||||||
{chain.question_key: query},
|
{chain.question_key: query},
|
||||||
return_only_outputs=True,
|
return_only_outputs=True,
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
Loading…
Reference in New Issue
Block a user