mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
Harrison/standarize prompt loading (#1036)
Co-authored-by: Ibis Prevedello <ibiscp@gmail.com>
This commit is contained in:
parent
f30dcc6359
commit
8c45f06d58
@ -1,6 +1,7 @@
|
|||||||
"""Load prompts from disk."""
|
"""Load prompts from disk."""
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -12,17 +13,20 @@ from langchain.prompts.prompt import PromptTemplate
|
|||||||
from langchain.utilities.loading import try_load_from_hub
|
from langchain.utilities.loading import try_load_from_hub
|
||||||
|
|
||||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||||
"""Get the right type from the config and load it accordingly."""
|
"""Load prompt from Config Dict."""
|
||||||
prompt_type = config.pop("_type", "prompt")
|
if "_type" not in config:
|
||||||
if prompt_type == "prompt":
|
logger.warning("No `_type` key found, defaulting to `prompt`.")
|
||||||
return _load_prompt(config)
|
config_type = config.pop("_type", "prompt")
|
||||||
elif prompt_type == "few_shot":
|
|
||||||
return _load_few_shot_prompt(config)
|
if config_type not in type_to_loader_dict:
|
||||||
else:
|
raise ValueError(f"Loading {config_type} prompt not supported")
|
||||||
raise ValueError
|
|
||||||
|
prompt_loader = type_to_loader_dict[config_type]
|
||||||
|
return prompt_loader(config)
|
||||||
|
|
||||||
|
|
||||||
def _load_template(var_name: str, config: dict) -> dict:
|
def _load_template(var_name: str, config: dict) -> dict:
|
||||||
@ -150,3 +154,10 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
|||||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||||
# Load the prompt from the config now.
|
# Load the prompt from the config now.
|
||||||
return load_prompt_from_config(config)
|
return load_prompt_from_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
type_to_loader_dict = {
|
||||||
|
"prompt": _load_prompt,
|
||||||
|
"few_shot": _load_few_shot_prompt,
|
||||||
|
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user