diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 90e965996de..00b6002da0e 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -22,11 +22,11 @@ DOCUMENTS_KEY = "context" DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}") -def _validate_prompt(prompt: BasePromptTemplate) -> None: - if DOCUMENTS_KEY not in prompt.input_variables: +def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None: + if document_variable_name not in prompt.input_variables: raise ValueError( - f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt " - f"with input variables: {prompt.input_variables}" + f"Prompt must accept {document_variable_name} as an input variable. " + f"Received prompt with input variables: {prompt.input_variables}" ) diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 5ffd86c9718..cdecec0f40b 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -76,7 +76,7 @@ def create_stuff_documents_chain( chain.invoke({"context": docs}) """ # noqa: E501 - _validate_prompt(prompt) + _validate_prompt(prompt, document_variable_name) _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT _output_parser = output_parser or StrOutputParser()