This commit is contained in:
Erick Friis 2025-01-08 11:25:46 -08:00
parent 5b329a1dbe
commit 27f450f23f
2 changed files with 95 additions and 4 deletions

View File

@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -26,7 +26,9 @@ def create_stuff_documents_chain(
prompt: Runnable[dict, LanguageModelInput],
*,
output_parser: Optional[BaseOutputParser] = None,
document_prompt: Optional[BasePromptTemplate] = None,
document_prompt: Optional[
Union[BasePromptTemplate, Runnable[Document, str]]
] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
document_variable_name: str = DOCUMENTS_KEY,
) -> Runnable[Dict[str, Any], Any]:
@ -36,6 +38,8 @@ def create_stuff_documents_chain(
llm: Language model.
prompt: Prompt template. Must contain input variable "context" (override by
setting document_variable), which will be used for passing in the formatted documents.
Can also be a Runnable that takes a dictionary with the "context" key and returns
a valid language model input.
output_parser: Output parser. Defaults to StrOutputParser.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
@ -43,6 +47,7 @@ def create_stuff_documents_chain(
`Document.page_content`, and all other inputs variables will be
automatically retrieved from the `Document.metadata` dictionary. Default to
a prompt that only contains `Document.page_content`.
Can also be a Runnable that takes a Document and returns a string.
document_separator: String separator to use between formatted document strings.
document_variable_name: Variable name to use for the formatted documents in the prompt.
Defaults to "context".
@ -83,6 +88,8 @@ def create_stuff_documents_chain(
def format_docs(inputs: dict) -> str:
    """Format every document in ``inputs`` and join them into one string.

    Documents are read from ``inputs[document_variable_name]``. A
    ``BasePromptTemplate`` document prompt is applied through
    ``format_document``; any other document prompt is invoked directly with
    the document.
    """
    docs = inputs[document_variable_name]
    if isinstance(_document_prompt, BasePromptTemplate):
        rendered = [format_document(doc, _document_prompt) for doc in docs]
    else:
        rendered = [_document_prompt.invoke(doc) for doc in docs]
    return document_separator.join(rendered)

View File

@ -1,11 +1,18 @@
"""Test functionality related to combining documents."""
from typing import Any, List
from typing import Any, List, Union
import pytest
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts import (
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.runnables import RunnableLambda
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.combine_documents.reduce import (
collapse_docs,
split_list_of_docs,
@ -142,3 +149,80 @@ async def test_format_doc_missing_metadata() -> None:
format_document(doc, prompt)
with pytest.raises(ValueError):
await aformat_document(doc, prompt)
def test_create_stuff_documents_chain_init_prompt_good() -> None:
    """Check that the chain can be built from a prompt using the document variable."""
    fake_llm = FakeLLM()

    # Default document variable name ("context").
    default_prompt = PromptTemplate.from_template("{context}")
    create_stuff_documents_chain(fake_llm, default_prompt)

    # A custom document_variable_name works when the prompt uses that name.
    custom_prompt = PromptTemplate.from_template("{docs}")
    create_stuff_documents_chain(
        fake_llm,
        custom_prompt,
        document_variable_name="docs",
    )
def test_create_stuff_documents_chain_init_prompt_bad() -> None:
    """Check that building the chain raises ValueError for a prompt lacking "context"."""
    fake_llm = FakeLLM()
    # The only input variable is "foo", not the default document variable.
    bad_prompt = PromptTemplate.from_template("{foo}")
    with pytest.raises(ValueError):
        create_stuff_documents_chain(fake_llm, bad_prompt)
def test_create_stuff_documents_chain_init_sequence_prompt_good() -> None:
    """Check that a sequence whose first step is a valid prompt is accepted."""
    fake_llm = FakeLLM()
    base_prompt = PromptTemplate.from_template("{context}")
    # Piping into an identity function yields a sequence that starts with the prompt.
    prompt_sequence = base_prompt | (lambda x: x)
    create_stuff_documents_chain(fake_llm, prompt_sequence)
def test_create_stuff_documents_chain_init_sequence_prompt_bad() -> None:
    """Check that a sequence whose first step is a bad prompt raises ValueError."""
    fake_llm = FakeLLM()
    # The leading prompt exposes "foo" rather than the default "context" variable.
    base_prompt = PromptTemplate.from_template("{foo}")
    prompt_sequence = base_prompt | (lambda x: x)
    with pytest.raises(ValueError):
        create_stuff_documents_chain(fake_llm, prompt_sequence)
def test_create_stuff_documents_chain_init_sequence_no_prompt() -> None:
    """Check that a prompt that is a plain callable, not a prompt template, is accepted.

    The chain is only constructed here; the test passes if no exception is raised.
    """
    fake_llm = FakeLLM()
    create_stuff_documents_chain(fake_llm, lambda x: "{context}")
def test_create_stuff_documents_chain_truncate_docs() -> None:
    """Test a full chain that truncates the context variable.

    A Runnable is passed as ``document_prompt`` so that each document is
    formatted and cut to 200 characters before the formatted documents are
    joined and handed to the prompt.
    """
    prompt = PromptTemplate.from_template("{context}")

    # For the sake of this test, we will use a fake LLM that just returns the query.
    # if doing this in practice, you would use an actual llm.
    @RunnableLambda
    def llm(query: Union[str, StringPromptValue]) -> str:
        return str(query)

    @RunnableLambda
    def doc_prompt(doc: Document) -> str:
        # Use single quotes inside the replacement field: reusing the enclosing
        # double quote is only valid syntax on Python 3.12+ (PEP 701).
        doc_format = f"Title: {doc.metadata['title']}\n\n{doc.page_content}"
        # truncate docs to 200 characters
        return doc_format[:200]

    chain = create_stuff_documents_chain(llm, prompt, document_prompt=doc_prompt)

    docs = [
        Document(page_content="foo" * 200, metadata={"title": "foo"}),
        Document(page_content="bar" * 200, metadata={"title": "bar"}),
    ]
    output = chain.invoke({"context": docs})

    # confirm that it successfully truncated the inputs
    # without truncation, we should get at least 600 characters from each of the
    # 2 documents
    assert isinstance(output, str)
    # NOTE(review): the exact length depends on how str() renders the prompt
    # value passed to the fake llm -- confirm if that representation changes.
    assert len(output) == 415