This commit is contained in:
Erick Friis 2025-01-08 10:52:03 -08:00
parent 509e9a9821
commit 5b329a1dbe

View File

@ -9,7 +9,9 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable, RunnableSequence
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@ -22,12 +24,23 @@ DOCUMENTS_KEY = "context"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
if document_variable_name not in prompt.input_variables:
raise ValueError(
def _validate_prompt(
prompt: Runnable[dict, LanguageModelInput], document_variable_name: str
) -> None:
# validate that the prompt accepts the document variable if it starts with a prompt
# template
# if it's a generic runnable, don't validate
first_runnable = prompt
while isinstance(first_runnable, RunnableSequence):
first_runnable = first_runnable.first
if isinstance(first_runnable, BasePromptTemplate):
if document_variable_name not in first_runnable.input_variables:
msg = (
f"Prompt must accept {document_variable_name} as an input variable. "
f"Received prompt with input variables: {prompt.input_variables}"
"Received prompt with input variables: "
f"{first_runnable.input_variables}"
)
raise ValueError(msg)
class BaseCombineDocumentsChain(Chain, ABC):