This commit is contained in:
Erick Friis 2025-01-08 10:52:03 -08:00
parent 509e9a9821
commit 5b329a1dbe

View File

@ -9,7 +9,9 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable, RunnableSequence
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model from langchain_core.runnables.utils import create_model
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@ -22,12 +24,23 @@ DOCUMENTS_KEY = "context"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}") DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None: def _validate_prompt(
if document_variable_name not in prompt.input_variables: prompt: Runnable[dict, LanguageModelInput], document_variable_name: str
raise ValueError( ) -> None:
f"Prompt must accept {document_variable_name} as an input variable. " # validate that the prompt accepts the document variable if it starts with a prompt
f"Received prompt with input variables: {prompt.input_variables}" # template
) # if it's a generic runnable, don't validate
first_runnable = prompt
while isinstance(first_runnable, RunnableSequence):
first_runnable = first_runnable.first
if isinstance(first_runnable, BasePromptTemplate):
if document_variable_name not in first_runnable.input_variables:
msg = (
f"Prompt must accept {document_variable_name} as an input variable. "
"Received prompt with input variables: "
f"{first_runnable.input_variables}"
)
raise ValueError(msg)
class BaseCombineDocumentsChain(Chain, ABC): class BaseCombineDocumentsChain(Chain, ABC):