diff --git a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py index 70ee98513ea..7ace5730ad5 100644 --- a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py +++ b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py @@ -19,10 +19,20 @@ def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None: qa = RetrievalQAWithSourcesChain.from_llm( llm=OpenAI(), retriever=docsearch.as_retriever() ) - qa("What did the president say about Ketanji Brown Jackson?") - - file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml" + result = qa("What did the president say about Ketanji Brown Jackson?") + assert "question" in result.keys() + assert "answer" in result.keys() + assert "sources" in result.keys() + file_path = str(tmp_path) + "/RetrievalQAWithSourcesChain.yaml" qa.save(file_path=file_path) qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever()) assert qa_loaded == qa + + qa2 = RetrievalQAWithSourcesChain.from_chain_type( + llm=OpenAI(), retriever=docsearch.as_retriever(), chain_type="stuff" + ) + result2 = qa2("What did the president say about Ketanji Brown Jackson?") + assert "question" in result2.keys() + assert "answer" in result2.keys() + assert "sources" in result2.keys()