mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
x
This commit is contained in:
parent
5b329a1dbe
commit
27f450f23f
@ -1,6 +1,6 @@
|
|||||||
"""Chain that combines documents by stuffing into context."""
|
"""Chain that combines documents by stuffing into context."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
@ -26,7 +26,9 @@ def create_stuff_documents_chain(
|
|||||||
prompt: Runnable[dict, LanguageModelInput],
|
prompt: Runnable[dict, LanguageModelInput],
|
||||||
*,
|
*,
|
||||||
output_parser: Optional[BaseOutputParser] = None,
|
output_parser: Optional[BaseOutputParser] = None,
|
||||||
document_prompt: Optional[BasePromptTemplate] = None,
|
document_prompt: Optional[
|
||||||
|
Union[BasePromptTemplate, Runnable[Document, str]]
|
||||||
|
] = None,
|
||||||
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
||||||
document_variable_name: str = DOCUMENTS_KEY,
|
document_variable_name: str = DOCUMENTS_KEY,
|
||||||
) -> Runnable[Dict[str, Any], Any]:
|
) -> Runnable[Dict[str, Any], Any]:
|
||||||
@ -36,6 +38,8 @@ def create_stuff_documents_chain(
|
|||||||
llm: Language model.
|
llm: Language model.
|
||||||
prompt: Prompt template. Must contain input variable "context" (override by
|
prompt: Prompt template. Must contain input variable "context" (override by
|
||||||
setting document_variable_name), which will be used for passing in the formatted documents.
|
setting document_variable_name), which will be used for passing in the formatted documents.
|
||||||
|
Can also be a Runnable that takes a dictionary with the "context" key and returns
|
||||||
|
a valid language model input.
|
||||||
output_parser: Output parser. Defaults to StrOutputParser.
|
output_parser: Output parser. Defaults to StrOutputParser.
|
||||||
document_prompt: Prompt used for formatting each document into a string. Input
|
document_prompt: Prompt used for formatting each document into a string. Input
|
||||||
variables can be "page_content" or any metadata keys that are in all
|
variables can be "page_content" or any metadata keys that are in all
|
||||||
@ -43,6 +47,7 @@ def create_stuff_documents_chain(
|
|||||||
`Document.page_content`, and all other inputs variables will be
|
`Document.page_content`, and all other inputs variables will be
|
||||||
automatically retrieved from the `Document.metadata` dictionary. Default to
|
automatically retrieved from the `Document.metadata` dictionary. Default to
|
||||||
a prompt that only contains `Document.page_content`.
|
a prompt that only contains `Document.page_content`.
|
||||||
|
Can also be a Runnable that takes a `Document` and returns a string.
|
||||||
document_separator: String separator to use between formatted document strings.
|
document_separator: String separator to use between formatted document strings.
|
||||||
document_variable_name: Variable name to use for the formatted documents in the prompt.
|
document_variable_name: Variable name to use for the formatted documents in the prompt.
|
||||||
Defaults to "context".
|
Defaults to "context".
|
||||||
@ -83,6 +88,8 @@ def create_stuff_documents_chain(
|
|||||||
def format_docs(inputs: dict) -> str:
|
def format_docs(inputs: dict) -> str:
|
||||||
return document_separator.join(
|
return document_separator.join(
|
||||||
format_document(doc, _document_prompt)
|
format_document(doc, _document_prompt)
|
||||||
|
if isinstance(_document_prompt, BasePromptTemplate)
|
||||||
|
else _document_prompt.invoke(doc)
|
||||||
for doc in inputs[document_variable_name]
|
for doc in inputs[document_variable_name]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,11 +1,18 @@
|
|||||||
"""Test functionality related to combining documents."""
|
"""Test functionality related to combining documents."""
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any, List, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.prompts import PromptTemplate, aformat_document, format_document
|
from langchain_core.prompt_values import StringPromptValue
|
||||||
|
from langchain_core.prompts import (
|
||||||
|
PromptTemplate,
|
||||||
|
aformat_document,
|
||||||
|
format_document,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables import RunnableLambda
|
||||||
|
|
||||||
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||||
from langchain.chains.combine_documents.reduce import (
|
from langchain.chains.combine_documents.reduce import (
|
||||||
collapse_docs,
|
collapse_docs,
|
||||||
split_list_of_docs,
|
split_list_of_docs,
|
||||||
@ -142,3 +149,80 @@ async def test_format_doc_missing_metadata() -> None:
|
|||||||
format_document(doc, prompt)
|
format_document(doc, prompt)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await aformat_document(doc, prompt)
|
await aformat_document(doc, prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_init_prompt_good() -> None:
|
||||||
|
"""Test that initializing the chain with a prompt works."""
|
||||||
|
prompt = PromptTemplate.from_template("{context}")
|
||||||
|
llm = FakeLLM()
|
||||||
|
create_stuff_documents_chain(llm, prompt)
|
||||||
|
|
||||||
|
# confirm also works with custom document_variable_name
|
||||||
|
prompt = PromptTemplate.from_template("{docs}")
|
||||||
|
create_stuff_documents_chain(llm, prompt, document_variable_name="docs")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_init_prompt_bad() -> None:
|
||||||
|
"""Test that initializing the chain with a bad prompt fails."""
|
||||||
|
prompt = PromptTemplate.from_template("{foo}")
|
||||||
|
llm = FakeLLM()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
create_stuff_documents_chain(llm, prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_init_sequence_prompt_good() -> None:
|
||||||
|
"""Test that initializing the chain with a sequence starting with a prompt
|
||||||
|
works."""
|
||||||
|
prompt = PromptTemplate.from_template("{context}")
|
||||||
|
llm = FakeLLM()
|
||||||
|
prompt_runnable = prompt | (lambda x: x)
|
||||||
|
create_stuff_documents_chain(llm, prompt_runnable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_init_sequence_prompt_bad() -> None:
|
||||||
|
"""Test that initializing the chain with a sequence starting with a bad prompt
|
||||||
|
fails."""
|
||||||
|
prompt = PromptTemplate.from_template("{foo}")
|
||||||
|
llm = FakeLLM()
|
||||||
|
prompt_runnable = prompt | (lambda x: x)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
create_stuff_documents_chain(llm, prompt_runnable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_init_sequence_no_prompt() -> None:
|
||||||
|
"""Test that initializing the chain with a sequence not starting with a prompt
|
||||||
|
succeeds without validation."""
|
||||||
|
llm = FakeLLM()
|
||||||
|
create_stuff_documents_chain(llm, lambda x: "{context}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_stuff_documents_chain_truncate_docs() -> None:
|
||||||
|
"""Test a full chain that truncates the context variable."""
|
||||||
|
prompt = PromptTemplate.from_template("{context}")
|
||||||
|
|
||||||
|
# For the sake of this test, we will use a fake LLM that just returns the query.
|
||||||
|
# if doing this in practice, you would use an actual llm.
|
||||||
|
@RunnableLambda
|
||||||
|
def llm(query: Union[str, StringPromptValue]) -> str:
|
||||||
|
return str(query)
|
||||||
|
|
||||||
|
@RunnableLambda
|
||||||
|
def doc_prompt(doc: Document) -> str:
|
||||||
|
doc_format = f"Title: {doc.metadata["title"]}\n\n{doc.page_content}"
|
||||||
|
# truncate docs to 200 characters
|
||||||
|
return doc_format[:200]
|
||||||
|
|
||||||
|
chain = create_stuff_documents_chain(llm, prompt, document_prompt=doc_prompt)
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
Document(page_content="foo" * 200, metadata={"title": "foo"}),
|
||||||
|
Document(page_content="bar" * 200, metadata={"title": "bar"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
output = chain.invoke({"context": docs})
|
||||||
|
|
||||||
|
# confirm that it successfully truncated the inputs
|
||||||
|
# without truncation, we should get at least 600 characters from each of the
|
||||||
|
# 2 documents
|
||||||
|
assert isinstance(output, str)
|
||||||
|
assert len(output) == 415
|
||||||
|
Loading…
Reference in New Issue
Block a user