diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 30be599a7a1..178c637ea25 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -1,6 +1,7 @@ """Load prompts from disk.""" import importlib import json +import logging from pathlib import Path from typing import Union @@ -12,17 +13,20 @@ from langchain.prompts.prompt import PromptTemplate from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" +logger = logging.getLogger(__file__) def load_prompt_from_config(config: dict) -> BasePromptTemplate: - """Get the right type from the config and load it accordingly.""" - prompt_type = config.pop("_type", "prompt") - if prompt_type == "prompt": - return _load_prompt(config) - elif prompt_type == "few_shot": - return _load_few_shot_prompt(config) - else: - raise ValueError + """Load prompt from Config Dict.""" + if "_type" not in config: + logger.warning("No `_type` key found, defaulting to `prompt`.") + config_type = config.pop("_type", "prompt") + + if config_type not in type_to_loader_dict: + raise ValueError(f"Loading {config_type} prompt not supported") + + prompt_loader = type_to_loader_dict[config_type] + return prompt_loader(config) def _load_template(var_name: str, config: dict) -> dict: @@ -150,3 +154,10 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: raise ValueError(f"Got unsupported file type {file_path.suffix}") # Load the prompt from the config now. return load_prompt_from_config(config) + + +type_to_loader_dict = { + "prompt": _load_prompt, + "few_shot": _load_few_shot_prompt, + # "few_shot_with_templates": _load_few_shot_with_templates_prompt, +}