mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
Fix issue where response_if_no_docs_found is not implemented on async… (#13297)
Response_if_no_docs_found is not implemented in ConversationalRetrievalChain for async code paths. Implemented it and added test cases Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
67c55cb5b0
commit
d1d693b2a7
@ -204,14 +204,19 @@ class BaseConversationalRetrievalChain(Chain):
|
|||||||
else:
|
else:
|
||||||
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
|
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
|
||||||
|
|
||||||
new_inputs = inputs.copy()
|
output: Dict[str, Any] = {}
|
||||||
if self.rephrase_question:
|
if self.response_if_no_docs_found is not None and len(docs) == 0:
|
||||||
new_inputs["question"] = new_question
|
output[self.output_key] = self.response_if_no_docs_found
|
||||||
new_inputs["chat_history"] = chat_history_str
|
else:
|
||||||
answer = await self.combine_docs_chain.arun(
|
new_inputs = inputs.copy()
|
||||||
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
if self.rephrase_question:
|
||||||
)
|
new_inputs["question"] = new_question
|
||||||
output: Dict[str, Any] = {self.output_key: answer}
|
new_inputs["chat_history"] = chat_history_str
|
||||||
|
answer = await self.combine_docs_chain.arun(
|
||||||
|
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
||||||
|
)
|
||||||
|
output[self.output_key] = answer
|
||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
output["source_documents"] = docs
|
output["source_documents"] = docs
|
||||||
if self.return_generated_question:
|
if self.return_generated_question:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test conversation chain and memory."""
|
"""Test conversation chain and memory."""
|
||||||
|
import pytest
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
|
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
|
||||||
@ -7,6 +8,54 @@ from langchain.memory.buffer import ConversationBufferMemory
|
|||||||
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever
|
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def atest_simple() -> None:
|
||||||
|
fixed_resp = "I don't know"
|
||||||
|
answer = "I know the answer!"
|
||||||
|
llm = FakeListLLM(responses=[answer])
|
||||||
|
retriever = SequentialRetriever(sequential_responses=[[]])
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
k=1, output_key="answer", memory_key="chat_history", return_messages=True
|
||||||
|
)
|
||||||
|
qa_chain = ConversationalRetrievalChain.from_llm(
|
||||||
|
llm=llm,
|
||||||
|
memory=memory,
|
||||||
|
retriever=retriever,
|
||||||
|
return_source_documents=True,
|
||||||
|
rephrase_question=False,
|
||||||
|
response_if_no_docs_found=fixed_resp,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
got = await qa_chain.acall("What is the answer?")
|
||||||
|
assert got["chat_history"][1].content == fixed_resp
|
||||||
|
assert got["answer"] == fixed_resp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def atest_fixed_message_response_when_docs_found() -> None:
|
||||||
|
fixed_resp = "I don't know"
|
||||||
|
answer = "I know the answer!"
|
||||||
|
llm = FakeListLLM(responses=[answer])
|
||||||
|
retriever = SequentialRetriever(
|
||||||
|
sequential_responses=[[Document(page_content=answer)]]
|
||||||
|
)
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
k=1, output_key="answer", memory_key="chat_history", return_messages=True
|
||||||
|
)
|
||||||
|
qa_chain = ConversationalRetrievalChain.from_llm(
|
||||||
|
llm=llm,
|
||||||
|
memory=memory,
|
||||||
|
retriever=retriever,
|
||||||
|
return_source_documents=True,
|
||||||
|
rephrase_question=False,
|
||||||
|
response_if_no_docs_found=fixed_resp,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
got = await qa_chain.acall("What is the answer?")
|
||||||
|
assert got["chat_history"][1].content == answer
|
||||||
|
assert got["answer"] == answer
|
||||||
|
|
||||||
|
|
||||||
def test_fixed_message_response_when_no_docs_found() -> None:
|
def test_fixed_message_response_when_no_docs_found() -> None:
|
||||||
fixed_resp = "I don't know"
|
fixed_resp = "I don't know"
|
||||||
answer = "I know the answer!"
|
answer = "I know the answer!"
|
||||||
|
Loading…
Reference in New Issue
Block a user