mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
x
This commit is contained in:
parent
509e9a9821
commit
5b329a1dbe
@ -9,7 +9,9 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable, RunnableSequence
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.runnables.utils import create_model
|
from langchain_core.runnables.utils import create_model
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
@ -22,12 +24,23 @@ DOCUMENTS_KEY = "context"
|
|||||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_prompt(prompt: BasePromptTemplate, document_variable_name: str) -> None:
|
def _validate_prompt(
|
||||||
if document_variable_name not in prompt.input_variables:
|
prompt: Runnable[dict, LanguageModelInput], document_variable_name: str
|
||||||
raise ValueError(
|
) -> None:
|
||||||
f"Prompt must accept {document_variable_name} as an input variable. "
|
# validate that the prompt accepts the document variable if it starts with a prompt
|
||||||
f"Received prompt with input variables: {prompt.input_variables}"
|
# template
|
||||||
)
|
# if it's a generic runnable, don't validate
|
||||||
|
first_runnable = prompt
|
||||||
|
while isinstance(first_runnable, RunnableSequence):
|
||||||
|
first_runnable = first_runnable.first
|
||||||
|
if isinstance(first_runnable, BasePromptTemplate):
|
||||||
|
if document_variable_name not in first_runnable.input_variables:
|
||||||
|
msg = (
|
||||||
|
f"Prompt must accept {document_variable_name} as an input variable. "
|
||||||
|
"Received prompt with input variables: "
|
||||||
|
f"{first_runnable.input_variables}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
class BaseCombineDocumentsChain(Chain, ABC):
|
class BaseCombineDocumentsChain(Chain, ABC):
|
||||||
|
Loading…
Reference in New Issue
Block a user