diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index 4961519717b..dd6e0c35478 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -8,6 +8,7 @@ import yaml from langchain_core.output_parsers.string import StrOutputParser from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.utils import try_load_from_hub @@ -154,7 +155,21 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: return load_prompt_from_config(config) +def _load_chat_prompt(config: Dict) -> ChatPromptTemplate: + """Load chat prompt from config""" + + messages = config.pop("messages") + template = messages[0]["prompt"].pop("template") if messages else None + config.pop("input_variables") + + if not template: + raise ValueError("Can't load chat prompt without template") + + return ChatPromptTemplate.from_template(template=template, **config) + + type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { "prompt": _load_prompt, "few_shot": _load_few_shot_prompt, + "chat": _load_chat_prompt, }