mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-17 04:18:53 +00:00
Update VectorDBQA to RetrievalQA in tools (#3698)
Because `VectorDBQA` and `VectorDBQAWithSourcesChain` are deprecated
This commit is contained in:
parent
32793f94fd
commit
1bf1c37c0c
@ -5,8 +5,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
|
||||||
from langchain.chains.retrieval_qa.base import VectorDBQA
|
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.llms.openai import OpenAI
|
from langchain.llms.openai import OpenAI
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
@ -45,12 +44,14 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
|
|||||||
|
|
||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
|
chain = RetrievalQA.from_chain_type(
|
||||||
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
|
)
|
||||||
return chain.run(query)
|
return chain.run(query)
|
||||||
|
|
||||||
async def _arun(self, query: str) -> str:
|
async def _arun(self, query: str) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
raise NotImplementedError("VectorDBQATool does not support async")
|
raise NotImplementedError("VectorStoreQATool does not support async")
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
||||||
@ -71,11 +72,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
|||||||
|
|
||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
chain = VectorDBQAWithSourcesChain.from_chain_type(
|
chain = RetrievalQAWithSourcesChain.from_chain_type(
|
||||||
self.llm, vectorstore=self.vectorstore
|
self.llm, retriever=self.vectorstore.as_retriever()
|
||||||
)
|
)
|
||||||
return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))
|
return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))
|
||||||
|
|
||||||
async def _arun(self, query: str) -> str:
|
async def _arun(self, query: str) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
raise NotImplementedError("VectorDBQATool does not support async")
|
raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async")
|
||||||
|
Loading…
Reference in New Issue
Block a user