Compare commits

...

7 Commits

Author SHA1 Message Date
jacoblee93
f90e665413 Lint 2024-01-28 20:46:46 -08:00
jacoblee93
fad076fa06 Lint 2024-01-28 20:43:54 -08:00
jacoblee93
59ffccf27d Fix lint 2024-01-28 20:40:39 -08:00
jacoblee93
ac85fca6f0 Switch to messages param 2024-01-28 20:28:00 -08:00
jacoblee93
f29ad020a0 Small tweak 2024-01-28 10:08:41 -08:00
jacoblee93
b67561890b Fix lint + tests 2024-01-28 10:07:21 -08:00
jacoblee93
b970bfe8da Make input param optional for retrieval chain and history aware retriever chain 2024-01-28 09:59:16 -08:00
3 changed files with 80 additions and 5 deletions

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
from typing import Dict
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike
@@ -48,13 +51,31 @@ def create_history_aware_retriever(
chain.invoke({"input": "...", "chat_history": })
"""
if "input" not in prompt.input_variables:
input_vars = prompt.input_variables
if "input" not in input_vars and "messages" not in input_vars:
raise ValueError(
"Expected `input` to be a prompt variable, "
f"but got {prompt.input_variables}"
"Expected either `input` or `messages` to be prompt variables, "
f"but got {input_vars}"
)
def messages_param_is_message_list(x: Dict) -> bool:
    """Return True if ``x["messages"]`` is a non-empty list of BaseMessage.

    Used to decide whether the chain was invoked with a ``messages``
    parameter rather than separate ``input``/``chat_history`` params.
    """
    # Bind the lookup once instead of calling x.get("messages", []) three
    # times; short-circuiting keeps the BaseMessage scan off empty lists.
    messages = x.get("messages", [])
    return (
        isinstance(messages, list)
        and len(messages) > 0
        and all(isinstance(i, BaseMessage) for i in messages)
    )
retrieve_documents: RetrieverOutputLike = RunnableBranch(
(
lambda x: messages_param_is_message_list(x)
and len(x.get("messages", [])) > 1,
prompt | llm | StrOutputParser() | retriever,
),
(
lambda x: messages_param_is_message_list(x)
and len(x.get("messages", [])) == 1,
(lambda x: x["messages"][-1].content) | retriever,
),
(
# Both empty string and empty list evaluate to False
lambda x: not x.get("chat_history", False),

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, Union
from langchain_core.messages import BaseMessage
from langchain_core.retrievers import (
BaseRetriever,
RetrieverOutput,
@@ -55,10 +56,30 @@ def create_retrieval_chain(
chain.invoke({"input": "..."})
"""
def messages_param_is_message_list(x: Dict) -> bool:
    """Return True if ``x["messages"]`` is a non-empty list of BaseMessage.

    Used to decide whether the chain input carries its query as a
    ``messages`` list instead of an ``input`` string.
    """
    # Single lookup instead of three x.get("messages", []) calls; the
    # `and` chain short-circuits before the per-message isinstance scan.
    messages = x.get("messages", [])
    return (
        isinstance(messages, list)
        and len(messages) > 0
        and all(isinstance(i, BaseMessage) for i in messages)
    )
def extract_retriever_input_string(x: Dict) -> str:
    """Extract the retriever query string from the chain input dict.

    Prefers a truthy ``x["input"]``; otherwise falls back to the content
    of the last message in ``x["messages"]``.

    Raises:
        ValueError: if ``input`` is missing/falsy and ``messages`` is not
            a non-empty list of messages.
    """
    # Guard clauses instead of nested if/else.
    if x.get("input"):
        return x["input"]
    if messages_param_is_message_list(x):
        return x["messages"][-1].content
    # Bug fix: the original passed two comma-separated strings to
    # ValueError, so str(exc) rendered a tuple. Implicit string
    # concatenation yields one coherent message.
    raise ValueError(
        "If `input` not provided, "
        "`messages` parameter must be a list of messages."
    )
if not isinstance(retriever, BaseRetriever):
retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
else:
retrieval_docs = (lambda x: x["input"]) | retriever
retrieval_docs = extract_retriever_input_string | retriever
retrieval_chain = (
RunnablePassthrough.assign(

View File

@@ -1,7 +1,12 @@
"""Test conversation chain and memory."""
from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
)
from langchain.chains import create_retrieval_chain
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever
@@ -22,3 +27,31 @@ def test_create() -> None:
}
output = chain.invoke({"input": "What is the answer?", "chat_history": "foo"})
assert output == expected_output
def test_create_with_chat_history_messages_only() -> None:
    """The retrieval chain should accept a bare `messages` list as input."""
    fake_llm = FakeListLLM(responses=["I know the answer!"])
    parrot_retriever = FakeParrotRetriever()
    prompt = ChatPromptTemplate.from_messages(
        [MessagesPlaceholder(variable_name="messages")]
    )
    retrieval_chain = create_retrieval_chain(parrot_retriever, prompt | fake_llm)
    result = retrieval_chain.invoke(
        {"messages": [HumanMessage(content="What is the answer?")]}
    )
    assert result == {
        "answer": "I know the answer!",
        "messages": [HumanMessage(content="What is the answer?")],
        "context": [Document(page_content="What is the answer?")],
    }