From 27f450f23f137f0dedb2e6eca10883772a109233 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Wed, 8 Jan 2025 11:25:46 -0800 Subject: [PATCH] x --- .../chains/combine_documents/stuff.py | 11 ++- .../chains/test_combine_documents.py | 88 ++++++++++++++++++- 2 files changed, 95 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 86ff6bf3630..694d50b19ae 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -1,6 +1,6 @@ """Chain that combines documents by stuffing into context.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core._api import deprecated from langchain_core.callbacks import Callbacks @@ -26,7 +26,9 @@ def create_stuff_documents_chain( prompt: Runnable[dict, LanguageModelInput], *, output_parser: Optional[BaseOutputParser] = None, - document_prompt: Optional[BasePromptTemplate] = None, + document_prompt: Optional[ + Union[BasePromptTemplate, Runnable[Document, str]] + ] = None, document_separator: str = DEFAULT_DOCUMENT_SEPARATOR, document_variable_name: str = DOCUMENTS_KEY, ) -> Runnable[Dict[str, Any], Any]: @@ -36,6 +38,8 @@ def create_stuff_documents_chain( llm: Language model. prompt: Prompt template. Must contain input variable "context" (override by setting document_variable), which will be used for passing in the formatted documents. + Can also be a Runnable that takes a dictionary with the "context" key and returns + a valid language model input. output_parser: Output parser. Defaults to StrOutputParser. document_prompt: Prompt used for formatting each document into a string. 
Input variables can be "page_content" or any metadata keys that are in all @@ -43,6 +47,7 @@ def create_stuff_documents_chain( `Document.page_content`, and all other inputs variables will be automatically retrieved from the `Document.metadata` dictionary. Default to a prompt that only contains `Document.page_content`. + Can also be a Runnable that takes a Document and returns a string. document_separator: String separator to use between formatted document strings. document_variable_name: Variable name to use for the formatted documents in the prompt. Defaults to "context". @@ -83,6 +88,8 @@ def create_stuff_documents_chain( def format_docs(inputs: dict) -> str: return document_separator.join( format_document(doc, _document_prompt) + if isinstance(_document_prompt, BasePromptTemplate) + else _document_prompt.invoke(doc) for doc in inputs[document_variable_name] ) diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 8de556bb8b9..25f2aa496fa 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -1,11 +1,18 @@ """Test functionality related to combining documents.""" -from typing import Any, List +from typing import Any, List, Union import pytest from langchain_core.documents import Document -from langchain_core.prompts import PromptTemplate, aformat_document, format_document +from langchain_core.prompt_values import StringPromptValue +from langchain_core.prompts import ( + PromptTemplate, + aformat_document, + format_document, +) +from langchain_core.runnables import RunnableLambda + +from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.chains.combine_documents.reduce import ( collapse_docs, split_list_of_docs, @@ -142,3 +149,80 @@ async def test_format_doc_missing_metadata() -> None: format_document(doc, prompt) with pytest.raises(ValueError): await aformat_document(doc, prompt) + + +def 
test_create_stuff_documents_chain_init_prompt_good() -> None: + """Test that initializing the chain with a prompt works.""" + prompt = PromptTemplate.from_template("{context}") + llm = FakeLLM() + create_stuff_documents_chain(llm, prompt) + + # confirm also works with custom document_variable_name + prompt = PromptTemplate.from_template("{docs}") + create_stuff_documents_chain(llm, prompt, document_variable_name="docs") + + +def test_create_stuff_documents_chain_init_prompt_bad() -> None: + """Test that initializing the chain with a bad prompt fails.""" + prompt = PromptTemplate.from_template("{foo}") + llm = FakeLLM() + with pytest.raises(ValueError): + create_stuff_documents_chain(llm, prompt) + + +def test_create_stuff_documents_chain_init_sequence_prompt_good() -> None: + """Test that initializing the chain with a sequence starting with a prompt + works.""" + prompt = PromptTemplate.from_template("{context}") + llm = FakeLLM() + prompt_runnable = prompt | (lambda x: x) + create_stuff_documents_chain(llm, prompt_runnable) + + +def test_create_stuff_documents_chain_init_sequence_prompt_bad() -> None: + """Test that initializing the chain with a sequence starting with a bad prompt + fails.""" + prompt = PromptTemplate.from_template("{foo}") + llm = FakeLLM() + prompt_runnable = prompt | (lambda x: x) + with pytest.raises(ValueError): + create_stuff_documents_chain(llm, prompt_runnable) + + +def test_create_stuff_documents_chain_init_sequence_no_prompt() -> None: + """Test that initializing the chain with a sequence not starting with a prompt + succeeds, since no validation is performed.""" + llm = FakeLLM() + create_stuff_documents_chain(llm, lambda x: "{context}") + + +def test_create_stuff_documents_chain_truncate_docs() -> None: + """Test a full chain that truncates the context variable.""" + prompt = PromptTemplate.from_template("{context}") + + # For the sake of this test, we will use a fake LLM that just returns the query. 
+ # if doing this in practice, you would use an actual llm. + @RunnableLambda + def llm(query: Union[str, StringPromptValue]) -> str: + return str(query) + + @RunnableLambda + def doc_prompt(doc: Document) -> str: + doc_format = f"Title: {doc.metadata['title']}\n\n{doc.page_content}" + # truncate docs to 200 characters + return doc_format[:200] + + chain = create_stuff_documents_chain(llm, prompt, document_prompt=doc_prompt) + + docs = [ + Document(page_content="foo" * 200, metadata={"title": "foo"}), + Document(page_content="bar" * 200, metadata={"title": "bar"}), + ] + + output = chain.invoke({"context": docs}) + + # confirm that it successfully truncated the inputs + # without truncation, we should get at least 600 characters from each of the + # 2 documents + assert isinstance(output, str) + assert len(output) == 415