Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-25 16:13:25 +00:00
enable serde retrieval qa with sources (#10132)
#3983 mentions serialization/deserialization issues with both `RetrievalQA` & `RetrievalQAWithSourcesChain`. `RetrievalQA` has already been fixed in #5818. Mimicking #5818, I added the logic for `RetrievalQAWithSourcesChain`.

---------

Co-authored-by: Markus Tretzmüller <markus.tretzmueller@cortecs.at>
Co-authored-by: Bagatur <baskaryan@gmail.com>
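The diff below registers a loader for `RetrievalQAWithSourcesChain` in `langchain.chains.loading`, adds a `_chain_type` property to the chain, and adds an integration test. A minimal sketch of the save/load round trip this enables, mirroring that test (assumes an OpenAI API key; `docsearch` stands in for an existing vector store such as FAISS):

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
from langchain.llms import OpenAI

# Build the chain against an existing retriever (`docsearch` is assumed).
qa = RetrievalQAWithSourcesChain.from_llm(
    llm=OpenAI(), retriever=docsearch.as_retriever()
)

# Serialize the chain configuration to YAML ...
qa.save(file_path="retrieval_qa_with_sources_chain.yaml")

# ... and rebuild it. The retriever is not serialized, so it has to be
# supplied again as a keyword argument.
qa_loaded = load_chain(
    "retrieval_qa_with_sources_chain.yaml",
    retriever=docsearch.as_retriever(),
)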
This commit is contained in: commit b3a8fc7cb1 (parent 62fa2bc518)
@@ -20,6 +20,7 @@ from langchain.chains.llm_checker.base import LLMCheckerChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.llm_requests import LLMRequestsChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
 from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
 from langchain.llms.loading import load_llm, load_llm_from_config
@@ -424,6 +425,30 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
     )


+def _load_retrieval_qa_with_sources_chain(
+    config: dict, **kwargs: Any
+) -> RetrievalQAWithSourcesChain:
+    if "retriever" in kwargs:
+        retriever = kwargs.pop("retriever")
+    else:
+        raise ValueError("`retriever` must be present.")
+    if "combine_documents_chain" in config:
+        combine_documents_chain_config = config.pop("combine_documents_chain")
+        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+    elif "combine_documents_chain_path" in config:
+        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+    else:
+        raise ValueError(
+            "One of `combine_documents_chain` or "
+            "`combine_documents_chain_path` must be present."
+        )
+    return RetrievalQAWithSourcesChain(
+        combine_documents_chain=combine_documents_chain,
+        retriever=retriever,
+        **config,
+    )
+
+
 def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
     if "vectorstore" in kwargs:
         vectorstore = kwargs.pop("vectorstore")
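Because the retriever is a runtime object rather than part of the serialized config, the loader above insists on receiving it via kwargs; a hedged sketch of the failure mode when it is omitted (file name illustrative):

from langchain.chains.loading import load_chain

# Without the `retriever` kwarg, loading a saved RetrievalQAWithSourcesChain
# hits the ValueError raised in _load_retrieval_qa_with_sources_chain above.
try:
    load_chain("retrieval_qa_with_sources_chain.yaml")
except ValueError as err:
    print(err)  # `retriever` must be present.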
@@ -537,6 +562,7 @@ type_to_loader_dict = {
     "vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
     "vector_db_qa": _load_vector_db_qa,
     "retrieval_qa": _load_retrieval_qa,
+    "retrieval_qa_with_sources_chain": _load_retrieval_qa_with_sources_chain,
     "graph_cypher_chain": _load_graph_cypher_chain,
 }

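The new "retrieval_qa_with_sources_chain" entry is what routes a saved config back to the loader defined above. A simplified, hypothetical sketch of the dispatch performed by the loading module (not the library's exact code; `_dispatch_chain_loader` is an illustrative name):

from typing import Any

def _dispatch_chain_loader(config: dict, **kwargs: Any):
    # The serialized config carries a `_type` key; it selects the loader from
    # type_to_loader_dict, and runtime-only objects such as the retriever are
    # forwarded through kwargs.
    config_type = config.pop("_type")
    if config_type not in type_to_loader_dict:
        raise ValueError(f"Loading {config_type} chain not supported")
    return type_to_loader_dict[config_type](config, **kwargs)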
@@ -60,3 +60,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
             question, callbacks=run_manager.get_child()
         )
         return self._reduce_tokens_below_limit(docs)
+
+    @property
+    def _chain_type(self) -> str:
+        """Return the chain type."""
+        return "retrieval_qa_with_sources_chain"
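The string returned by `_chain_type` must match the key registered in `type_to_loader_dict`, since chain serialization writes it out as the `_type` field. A small sketch of that contract (assumes `qa` is a `RetrievalQAWithSourcesChain` built as in the test below):

# `_type` in the saved config comes from the `_chain_type` property, so the
# YAML produced by qa.save(...) can be routed back to
# _load_retrieval_qa_with_sources_chain on load.
config = qa.dict()
assert config["_type"] == "retrieval_qa_with_sources_chain"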
@@ -0,0 +1,28 @@
+"""Test RetrievalQA functionality."""
+from langchain.chains import RetrievalQAWithSourcesChain
+from langchain.chains.loading import load_chain
+from langchain.document_loaders import DirectoryLoader
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.llms import OpenAI
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import FAISS
+
+
+def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:
+    """Test saving and loading."""
+    loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
+    documents = loader.load()
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_documents(documents)
+    embeddings = OpenAIEmbeddings()
+    docsearch = FAISS.from_documents(texts, embeddings)
+    qa = RetrievalQAWithSourcesChain.from_llm(
+        llm=OpenAI(), retriever=docsearch.as_retriever()
+    )
+    qa("What did the president say about Ketanji Brown Jackson?")
+
+    file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml"
+    qa.save(file_path=file_path)
+    qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
+
+    assert qa_loaded == qa