This commit is contained in:
Erick Friis 2025-01-08 11:25:46 -08:00
parent 5b329a1dbe
commit 27f450f23f
2 changed files with 95 additions and 4 deletions

View File

@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context.""" """Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -26,7 +26,9 @@ def create_stuff_documents_chain(
prompt: Runnable[dict, LanguageModelInput], prompt: Runnable[dict, LanguageModelInput],
*, *,
output_parser: Optional[BaseOutputParser] = None, output_parser: Optional[BaseOutputParser] = None,
document_prompt: Optional[BasePromptTemplate] = None, document_prompt: Optional[
Union[BasePromptTemplate, Runnable[Document, str]]
] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR, document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
document_variable_name: str = DOCUMENTS_KEY, document_variable_name: str = DOCUMENTS_KEY,
) -> Runnable[Dict[str, Any], Any]: ) -> Runnable[Dict[str, Any], Any]:
@ -36,6 +38,8 @@ def create_stuff_documents_chain(
llm: Language model. llm: Language model.
prompt: Prompt template. Must contain input variable "context" (override by prompt: Prompt template. Must contain input variable "context" (override by
setting document_variable), which will be used for passing in the formatted documents. setting document_variable), which will be used for passing in the formatted documents.
Can also be a Runnable that takes a dictionary with the "context" key and returns
a valid language model input.
output_parser: Output parser. Defaults to StrOutputParser. output_parser: Output parser. Defaults to StrOutputParser.
document_prompt: Prompt used for formatting each document into a string. Input document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all variables can be "page_content" or any metadata keys that are in all
@ -43,6 +47,7 @@ def create_stuff_documents_chain(
`Document.page_content`, and all other inputs variables will be `Document.page_content`, and all other inputs variables will be
automatically retrieved from the `Document.metadata` dictionary. Default to automatically retrieved from the `Document.metadata` dictionary. Default to
a prompt that only contains `Document.page_content`. a prompt that only contains `Document.page_content`.
Can also be a Runnable that takes a Document and returns a string.
document_separator: String separator to use between formatted document strings. document_separator: String separator to use between formatted document strings.
document_variable_name: Variable name to use for the formatted documents in the prompt. document_variable_name: Variable name to use for the formatted documents in the prompt.
Defaults to "context". Defaults to "context".
@ -83,6 +88,8 @@ def create_stuff_documents_chain(
def format_docs(inputs: dict) -> str: def format_docs(inputs: dict) -> str:
return document_separator.join( return document_separator.join(
format_document(doc, _document_prompt) format_document(doc, _document_prompt)
if isinstance(_document_prompt, BasePromptTemplate)
else _document_prompt.invoke(doc)
for doc in inputs[document_variable_name] for doc in inputs[document_variable_name]
) )

View File

@ -1,11 +1,18 @@
"""Test functionality related to combining documents.""" """Test functionality related to combining documents."""
from typing import Any, List from typing import Any, List, Union
import pytest import pytest
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, aformat_document, format_document from langchain_core.prompt_values import StringPromptValue
from langchain_core.prompts import (
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.runnables import RunnableLambda
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.combine_documents.reduce import ( from langchain.chains.combine_documents.reduce import (
collapse_docs, collapse_docs,
split_list_of_docs, split_list_of_docs,
@ -142,3 +149,80 @@ async def test_format_doc_missing_metadata() -> None:
format_document(doc, prompt) format_document(doc, prompt)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await aformat_document(doc, prompt) await aformat_document(doc, prompt)
def test_create_stuff_documents_chain_init_prompt_good() -> None:
    """Chain construction succeeds when the prompt exposes the documents variable."""
    llm = FakeLLM()
    # Default document variable name ("context").
    create_stuff_documents_chain(llm, PromptTemplate.from_template("{context}"))
    # Overridden document variable name also validates.
    create_stuff_documents_chain(
        llm,
        PromptTemplate.from_template("{docs}"),
        document_variable_name="docs",
    )
def test_create_stuff_documents_chain_init_prompt_bad() -> None:
    """Chain construction raises when the prompt lacks the documents variable."""
    bad_prompt = PromptTemplate.from_template("{foo}")
    # "{foo}" does not contain the required "context" input variable.
    with pytest.raises(ValueError):
        create_stuff_documents_chain(FakeLLM(), bad_prompt)
def test_create_stuff_documents_chain_init_sequence_prompt_good() -> None:
    """Construction succeeds for a runnable sequence whose first step is a
    valid prompt."""
    template = PromptTemplate.from_template("{context}")
    # Compose the prompt with a no-op step so the argument is a sequence.
    sequence = template | (lambda value: value)
    create_stuff_documents_chain(FakeLLM(), sequence)
def test_create_stuff_documents_chain_init_sequence_prompt_bad() -> None:
    """Construction raises for a runnable sequence whose first step is a
    prompt missing the documents variable."""
    template = PromptTemplate.from_template("{foo}")
    # Sequence starts with a prompt lacking "context", so validation fails.
    sequence = template | (lambda value: value)
    with pytest.raises(ValueError):
        create_stuff_documents_chain(FakeLLM(), sequence)
def test_create_stuff_documents_chain_init_sequence_no_prompt() -> None:
    """Construction succeeds (skipping prompt validation) when the runnable
    does not start with a prompt template.

    A bare callable cannot be introspected for input variables, so no
    ValueError is raised even though "{context}" is just a literal here.
    """
    llm = FakeLLM()
    create_stuff_documents_chain(llm, lambda x: "{context}")
def test_create_stuff_documents_chain_truncate_docs() -> None:
    """End-to-end chain run where a Runnable document_prompt truncates each doc."""
    prompt = PromptTemplate.from_template("{context}")

    # For the sake of this test, use a fake LLM that just echoes its input;
    # in practice you would use an actual llm.
    @RunnableLambda
    def llm(query: Union[str, StringPromptValue]) -> str:
        return str(query)

    @RunnableLambda
    def doc_prompt(doc: Document) -> str:
        # Single quotes inside the f-string: nested double quotes in a
        # double-quoted f-string are a SyntaxError before Python 3.12.
        doc_format = f"Title: {doc.metadata['title']}\n\n{doc.page_content}"
        # Truncate each formatted document to 200 characters.
        return doc_format[:200]

    chain = create_stuff_documents_chain(llm, prompt, document_prompt=doc_prompt)
    docs = [
        Document(page_content="foo" * 200, metadata={"title": "foo"}),
        Document(page_content="bar" * 200, metadata={"title": "bar"}),
    ]
    output = chain.invoke({"context": docs})
    # Confirm the inputs were truncated: without truncation each of the two
    # documents alone would contribute at least 600 characters.
    assert isinstance(output, str)
    assert len(output) == 415