mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
x
This commit is contained in:
parent
509e9a9821
commit
5b329a1dbe
@ -9,7 +9,9 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableSequence
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.utils import create_model
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
@ -22,12 +24,23 @@ DOCUMENTS_KEY = "context"
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
||||
|
||||
|
||||
def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
|
||||
if document_variable_name not in prompt.input_variables:
|
||||
raise ValueError(
|
||||
f"Prompt must accept {document_variable_name} as an input variable. "
|
||||
f"Received prompt with input variables: {prompt.input_variables}"
|
||||
)
|
||||
def _validate_prompt(
|
||||
prompt: Runnable[dict, LanguageModelInput], document_variable_name: str
|
||||
) -> None:
|
||||
# validate that the prompt accepts the document variable if it starts with a prompt
|
||||
# template
|
||||
# if it's a generic runnable, don't validate
|
||||
first_runnable = prompt
|
||||
while isinstance(first_runnable, RunnableSequence):
|
||||
first_runnable = first_runnable.first
|
||||
if isinstance(first_runnable, BasePromptTemplate):
|
||||
if document_variable_name not in first_runnable.input_variables:
|
||||
msg = (
|
||||
f"Prompt must accept {document_variable_name} as an input variable. "
|
||||
"Received prompt with input variables: "
|
||||
f"{first_runnable.input_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
|
Loading…
Reference in New Issue
Block a user