diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 2406cd4215f..298d9b64346 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -9,7 +9,9 @@ from langchain_core.callbacks import ( CallbackManagerForChainRun, ) from langchain_core.documents import Document +from langchain_core.language_models import LanguageModelInput from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.runnables import Runnable, RunnableSequence from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import create_model from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter @@ -22,12 +24,23 @@ DOCUMENTS_KEY = "context" DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}") -def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None: - if document_variable_name not in prompt.input_variables: - raise ValueError( - f"Prompt must accept {document_variable_name} as an input variable. " - f"Received prompt with input variables: {prompt.input_variables}" - ) +def _validate_prompt( + prompt: Runnable[dict, LanguageModelInput], document_variable_name: str +) -> None: + # validate that the prompt accepts the document variable if it starts with a prompt + # template + # if it's a generic runnable, don't validate + first_runnable = prompt + while isinstance(first_runnable, RunnableSequence): + first_runnable = first_runnable.first + if isinstance(first_runnable, BasePromptTemplate): + if document_variable_name not in first_runnable.input_variables: + msg = ( + f"Prompt must accept {document_variable_name} as an input variable. " + "Received prompt with input variables: " + f"{first_runnable.input_variables}" + ) + raise ValueError(msg) class BaseCombineDocumentsChain(Chain, ABC):