diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py
index aee5ffa2aea..89882ede89c 100644
--- a/libs/langchain/langchain/chains/conversational_retrieval/base.py
+++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py
@@ -79,6 +79,9 @@ class BaseConversationalRetrievalChain(Chain):
     get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
     """An optional function to get a string of the chat history.
     If None is provided, will use a default."""
+    response_if_no_docs_found: Optional[str]
+    """If specified, the chain will return a fixed response if no docs
+    are found for the question. """
 
     class Config:
         """Configuration for this pydantic object."""
@@ -143,14 +146,19 @@ class BaseConversationalRetrievalChain(Chain):
             docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
         else:
             docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
-        new_inputs = inputs.copy()
-        if self.rephrase_question:
-            new_inputs["question"] = new_question
-        new_inputs["chat_history"] = chat_history_str
-        answer = self.combine_docs_chain.run(
-            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
-        )
-        output: Dict[str, Any] = {self.output_key: answer}
+        output: Dict[str, Any] = {}
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+            answer = self.combine_docs_chain.run(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
         if self.return_source_documents:
             output["source_documents"] = docs
         if self.return_generated_question:
diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py
new file mode 100644
index 00000000000..038b476cf4b
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py
@@ -0,0 +1,52 @@
+"""Test conversation chain and memory."""
+from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
+from langchain.llms.fake import FakeListLLM
+from langchain.memory.buffer import ConversationBufferMemory
+from langchain.schema import Document
+from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever
+
+
+def test_fixed_message_response_when_no_docs_found() -> None:
+    fixed_resp = "I don't know"
+    answer = "I know the answer!"
+    llm = FakeListLLM(responses=[answer])
+    retriever = SequentialRetriever(sequential_responses=[[]])
+    memory = ConversationBufferMemory(
+        k=1, output_key="answer", memory_key="chat_history", return_messages=True
+    )
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever,
+        return_source_documents=True,
+        rephrase_question=False,
+        response_if_no_docs_found=fixed_resp,
+        verbose=True,
+    )
+    got = qa_chain("What is the answer?")
+    assert got["chat_history"][1].content == fixed_resp
+    assert got["answer"] == fixed_resp
+
+
+def test_fixed_message_response_when_docs_found() -> None:
+    fixed_resp = "I don't know"
+    answer = "I know the answer!"
+    llm = FakeListLLM(responses=[answer])
+    retriever = SequentialRetriever(
+        sequential_responses=[[Document(page_content=answer)]]
+    )
+    memory = ConversationBufferMemory(
+        k=1, output_key="answer", memory_key="chat_history", return_messages=True
+    )
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever,
+        return_source_documents=True,
+        rephrase_question=False,
+        response_if_no_docs_found=fixed_resp,
+        verbose=True,
+    )
+    got = qa_chain("What is the answer?")
+    assert got["chat_history"][1].content == answer
+    assert got["answer"] == answer
diff --git a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
new file mode 100644
index 00000000000..b75913a9617
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
@@ -0,0 +1,26 @@
+from typing import List
+
+from langchain.schema import BaseRetriever, Document
+
+
+class SequentialRetriever(BaseRetriever):
+    """Test util that returns a sequence of documents"""
+
+    sequential_responses: List[List[Document]]
+    response_index: int = 0
+
+    def _get_relevant_documents(  # type: ignore[override]
+        self,
+        query: str,
+    ) -> List[Document]:
+        if self.response_index >= len(self.sequential_responses):
+            return []
+        else:
+            self.response_index += 1
+            return self.sequential_responses[self.response_index - 1]
+
+    async def _aget_relevant_documents(  # type: ignore[override]
+        self,
+        query: str,
+    ) -> List[Document]:
+        return self._get_relevant_documents(query)