Fix issue where response_if_no_docs_found is not implemented on async… (#13297)

Response_if_no_docs_found is not implemented in
ConversationalRetrievalChain for async code paths. Implemented it and
added test cases

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
sudranga 2023-11-29 19:06:13 -08:00 committed by GitHub
parent 67c55cb5b0
commit d1d693b2a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 8 deletions

View File

@ -204,14 +204,19 @@ class BaseConversationalRetrievalChain(Chain):
else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
new_inputs = inputs.copy()
if self.rephrase_question:
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer = await self.combine_docs_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
)
output: Dict[str, Any] = {self.output_key: answer}
output: Dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
new_inputs = inputs.copy()
if self.rephrase_question:
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer = await self.combine_docs_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
)
output[self.output_key] = answer
if self.return_source_documents:
output["source_documents"] = docs
if self.return_generated_question:

View File

@ -1,4 +1,5 @@
"""Test conversation chain and memory."""
import pytest
from langchain_core.documents import Document
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
@ -7,6 +8,54 @@ from langchain.memory.buffer import ConversationBufferMemory
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever
@pytest.mark.asyncio
async def atest_simple() -> None:
fixed_resp = "I don't know"
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(sequential_responses=[[]])
memory = ConversationBufferMemory(
k=1, output_key="answer", memory_key="chat_history", return_messages=True
)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
memory=memory,
retriever=retriever,
return_source_documents=True,
rephrase_question=False,
response_if_no_docs_found=fixed_resp,
verbose=True,
)
got = await qa_chain.acall("What is the answer?")
assert got["chat_history"][1].content == fixed_resp
assert got["answer"] == fixed_resp
@pytest.mark.asyncio
async def atest_fixed_message_response_when_docs_found() -> None:
fixed_resp = "I don't know"
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(
sequential_responses=[[Document(page_content=answer)]]
)
memory = ConversationBufferMemory(
k=1, output_key="answer", memory_key="chat_history", return_messages=True
)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
memory=memory,
retriever=retriever,
return_source_documents=True,
rephrase_question=False,
response_if_no_docs_found=fixed_resp,
verbose=True,
)
got = await qa_chain.acall("What is the answer?")
assert got["chat_history"][1].content == answer
assert got["answer"] == answer
def test_fixed_message_response_when_no_docs_found() -> None:
fixed_resp = "I don't know"
answer = "I know the answer!"