mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
x
This commit is contained in:
parent
5b329a1dbe
commit
27f450f23f
@ -1,6 +1,6 @@
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import Callbacks
|
||||
@ -26,7 +26,9 @@ def create_stuff_documents_chain(
|
||||
prompt: Runnable[dict, LanguageModelInput],
|
||||
*,
|
||||
output_parser: Optional[BaseOutputParser] = None,
|
||||
document_prompt: Optional[BasePromptTemplate] = None,
|
||||
document_prompt: Optional[
|
||||
Union[BasePromptTemplate, Runnable[Document, str]]
|
||||
] = None,
|
||||
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
||||
document_variable_name: str = DOCUMENTS_KEY,
|
||||
) -> Runnable[Dict[str, Any], Any]:
|
||||
@ -36,6 +38,8 @@ def create_stuff_documents_chain(
|
||||
llm: Language model.
|
||||
prompt: Prompt template. Must contain input variable "context" (override by
|
||||
setting document_variable_name), which will be used for passing in the formatted documents.
|
||||
Can also be a Runnable that takes a dictionary with the "context" key and returns
|
||||
a valid language model input.
|
||||
output_parser: Output parser. Defaults to StrOutputParser.
|
||||
document_prompt: Prompt used for formatting each document into a string. Input
|
||||
variables can be "page_content" or any metadata keys that are in all
|
||||
@ -43,6 +47,7 @@ def create_stuff_documents_chain(
|
||||
`Document.page_content`, and all other input variables will be
|
||||
automatically retrieved from the `Document.metadata` dictionary. Defaults to
|
||||
a prompt that only contains `Document.page_content`.
|
||||
Can also be a Runnable that takes a Document and returns a string.
|
||||
document_separator: String separator to use between formatted document strings.
|
||||
document_variable_name: Variable name to use for the formatted documents in the prompt.
|
||||
Defaults to "context".
|
||||
@ -83,6 +88,8 @@ def create_stuff_documents_chain(
|
||||
def format_docs(inputs: dict) -> str:
    """Render the documents found in ``inputs`` into one string.

    The documents are read from ``inputs[document_variable_name]``. Each one
    is rendered with ``format_document`` when the document prompt is a
    ``BasePromptTemplate``; otherwise the document prompt is invoked directly
    with the document. The rendered pieces are joined with
    ``document_separator``.
    """

    def _render(single_doc):  # type: ignore[no-untyped-def]
        # Prompt templates go through format_document; any other runnable
        # receives the document itself.
        if isinstance(_document_prompt, BasePromptTemplate):
            return format_document(single_doc, _document_prompt)
        return _document_prompt.invoke(single_doc)

    rendered = [_render(single_doc) for single_doc in inputs[document_variable_name]]
    return document_separator.join(rendered)
|
||||
|
||||
|
@ -1,11 +1,18 @@
|
||||
"""Test functionality related to combining documents."""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Union
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
|
||||
from langchain_core.prompt_values import StringPromptValue
|
||||
from langchain_core.prompts import (
|
||||
PromptTemplate,
|
||||
aformat_document,
|
||||
format_document,
|
||||
)
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||
from langchain.chains.combine_documents.reduce import (
|
||||
collapse_docs,
|
||||
split_list_of_docs,
|
||||
@ -142,3 +149,80 @@ async def test_format_doc_missing_metadata() -> None:
|
||||
format_document(doc, prompt)
|
||||
with pytest.raises(ValueError):
|
||||
await aformat_document(doc, prompt)
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_init_prompt_good() -> None:
    """Check that the chain can be built from a prompt exposing the document variable."""
    fake_llm = FakeLLM()

    # Prompt using the "context" variable, with no document_variable_name given.
    context_prompt = PromptTemplate.from_template("{context}")
    create_stuff_documents_chain(fake_llm, context_prompt)

    # Same check with a custom document_variable_name.
    docs_prompt = PromptTemplate.from_template("{docs}")
    create_stuff_documents_chain(
        fake_llm,
        docs_prompt,
        document_variable_name="docs",
    )
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_init_prompt_bad() -> None:
    """Check that a prompt without the document variable is rejected.

    The prompt only declares ``{foo}``, so building the chain is expected to
    raise ``ValueError``.
    """
    fake_llm = FakeLLM()
    prompt_without_context = PromptTemplate.from_template("{foo}")

    # Only the chain construction sits inside the raises block, so an error
    # from the setup above cannot make the test pass by accident.
    with pytest.raises(ValueError):
        create_stuff_documents_chain(fake_llm, prompt_without_context)
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_init_sequence_prompt_good() -> None:
    """Check that the chain accepts a sequence whose first step is a valid prompt."""
    fake_llm = FakeLLM()

    # Pipe the prompt into an identity step so that a composed runnable,
    # rather than the bare prompt, is handed to the chain constructor.
    base_prompt = PromptTemplate.from_template("{context}")
    piped_prompt = base_prompt | (lambda x: x)

    create_stuff_documents_chain(fake_llm, piped_prompt)
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_init_sequence_prompt_bad() -> None:
    """Check that a sequence whose first step is a bad prompt is rejected.

    The leading prompt only declares ``{foo}``, so building the chain is
    expected to raise ``ValueError`` even though the prompt is wrapped in a
    sequence.
    """
    fake_llm = FakeLLM()
    base_prompt = PromptTemplate.from_template("{foo}")
    piped_prompt = base_prompt | (lambda x: x)

    with pytest.raises(ValueError):
        create_stuff_documents_chain(fake_llm, piped_prompt)
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_init_sequence_no_prompt() -> None:
    """Check that the chain can be built from a plain callable instead of a prompt.

    The callable ignores its input and returns a fixed string; construction is
    expected to complete without raising.
    """
    # NOTE(review): presumably input-variable validation only applies when the
    # chain starts with a prompt template -- confirm against
    # create_stuff_documents_chain.
    create_stuff_documents_chain(FakeLLM(), lambda x: "{context}")
|
||||
|
||||
|
||||
def test_create_stuff_documents_chain_truncate_docs() -> None:
    """Test a full chain whose document prompt truncates each document.

    A ``RunnableLambda`` is passed as ``document_prompt``. It renders every
    document as its title followed by its content and cuts the result to 200
    characters. The chain is then invoked with two long documents under the
    ``context`` key, and the output length is checked.
    """
    prompt = PromptTemplate.from_template("{context}")

    # For the sake of this test, we will use a fake LLM that just returns the query.
    # if doing this in practice, you would use an actual llm.
    @RunnableLambda
    def llm(query: Union[str, StringPromptValue]) -> str:
        return str(query)

    @RunnableLambda
    def doc_prompt(doc: Document) -> str:
        # The metadata key uses single quotes: reusing the enclosing double
        # quote inside a replacement field is a SyntaxError before Python 3.12.
        doc_format = f"Title: {doc.metadata['title']}\n\n{doc.page_content}"
        # truncate docs to 200 characters
        return doc_format[:200]

    chain = create_stuff_documents_chain(llm, prompt, document_prompt=doc_prompt)

    docs = [
        Document(page_content="foo" * 200, metadata={"title": "foo"}),
        Document(page_content="bar" * 200, metadata={"title": "bar"}),
    ]

    output = chain.invoke({"context": docs})

    # confirm that it successfully truncated the inputs
    # without truncation, we should get at least 600 characters from each of the
    # 2 documents
    assert isinstance(output, str)
    # NOTE(review): 415 presumably covers the two 200-character documents, the
    # separator, and the wrapper text added by str() on the prompt value --
    # confirm if the expected length ever changes.
    assert len(output) == 415
|
||||
|
Loading…
Reference in New Issue
Block a user