From d1d693b2a77940d264cd7df3906c58d1dedb3f58 Mon Sep 17 00:00:00 2001 From: sudranga Date: Wed, 29 Nov 2023 19:06:13 -0800 Subject: [PATCH] =?UTF-8?q?Fix=20issue=20where=20response=5Fif=5Fno=5Fdocs?= =?UTF-8?q?=5Ffound=20is=20not=20implemented=20on=20async=E2=80=A6=20(#132?= =?UTF-8?q?97)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Response_if_no_docs_found is not implemented in ConversationalRetrievalChain for async code paths. Implemented it and added test cases Co-authored-by: Harrison Chase --- .../chains/conversational_retrieval/base.py | 21 +++++--- .../chains/test_conversation_retrieval.py | 49 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 3b3fd149fbd..dbe28df9b27 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -204,14 +204,19 @@ class BaseConversationalRetrievalChain(Chain): else: docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg] - new_inputs = inputs.copy() - if self.rephrase_question: - new_inputs["question"] = new_question - new_inputs["chat_history"] = chat_history_str - answer = await self.combine_docs_chain.arun( - input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs - ) - output: Dict[str, Any] = {self.output_key: answer} + output: Dict[str, Any] = {} + if self.response_if_no_docs_found is not None and len(docs) == 0: + output[self.output_key] = self.response_if_no_docs_found + else: + new_inputs = inputs.copy() + if self.rephrase_question: + new_inputs["question"] = new_question + new_inputs["chat_history"] = chat_history_str + answer = await self.combine_docs_chain.arun( + input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + ) + output[self.output_key] = answer + 
if self.return_source_documents: output["source_documents"] = docs if self.return_generated_question: diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py index a97e7786763..d7a56603bbb 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py @@ -1,4 +1,5 @@ """Test conversation chain and memory.""" +import pytest from langchain_core.documents import Document from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain @@ -7,6 +8,54 @@ from langchain.memory.buffer import ConversationBufferMemory from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever +@pytest.mark.asyncio +async def test_async_fixed_message_response_when_no_docs_found() -> None: + fixed_resp = "I don't know" + answer = "I know the answer!" + llm = FakeListLLM(responses=[answer]) + retriever = SequentialRetriever(sequential_responses=[[]]) + memory = ConversationBufferMemory( + k=1, output_key="answer", memory_key="chat_history", return_messages=True + ) + qa_chain = ConversationalRetrievalChain.from_llm( + llm=llm, + memory=memory, + retriever=retriever, + return_source_documents=True, + rephrase_question=False, + response_if_no_docs_found=fixed_resp, + verbose=True, + ) + got = await qa_chain.acall("What is the answer?") + assert got["chat_history"][1].content == fixed_resp + assert got["answer"] == fixed_resp + + +@pytest.mark.asyncio +async def test_async_fixed_message_response_when_docs_found() -> None: + fixed_resp = "I don't know" + answer = "I know the answer!" 
+ llm = FakeListLLM(responses=[answer]) + retriever = SequentialRetriever( + sequential_responses=[[Document(page_content=answer)]] + ) + memory = ConversationBufferMemory( + k=1, output_key="answer", memory_key="chat_history", return_messages=True + ) + qa_chain = ConversationalRetrievalChain.from_llm( + llm=llm, + memory=memory, + retriever=retriever, + return_source_documents=True, + rephrase_question=False, + response_if_no_docs_found=fixed_resp, + verbose=True, + ) + got = await qa_chain.acall("What is the answer?") + assert got["chat_history"][1].content == answer + assert got["answer"] == answer + + def test_fixed_message_response_when_no_docs_found() -> None: fixed_resp = "I don't know" answer = "I know the answer!"