diff --git a/langchain/prompts/__init__.py b/langchain/prompts/__init__.py index efe15ce2128..72d0f438685 100644 --- a/langchain/prompts/__init__.py +++ b/langchain/prompts/__init__.py @@ -1,7 +1,7 @@ """Prompt template classes.""" from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.loading import load_prompt +from langchain.prompts.loading import load_from_hub, load_prompt from langchain.prompts.prompt import Prompt, PromptTemplate __all__ = [ @@ -10,4 +10,5 @@ __all__ = [ "PromptTemplate", "FewShotPromptTemplate", "Prompt", + "load_from_hub", ] diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index d99e51eff73..e7ceb6feb01 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -1,8 +1,11 @@ """Load prompts from disk.""" +import importlib import json +import tempfile from pathlib import Path from typing import Union +import requests import yaml from langchain.prompts.base import BasePromptTemplate @@ -97,7 +100,38 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate: elif file_path.suffix == ".yaml": with open(file_path, "r") as f: config = yaml.safe_load(f) + elif file_path.suffix == ".py": + spec = importlib.util.spec_from_loader( + "prompt", loader=None, origin=str(file_path) + ) + if spec is None: + raise ValueError("could not load spec") + helper = importlib.util.module_from_spec(spec) + with open(file_path, "rb") as f: + exec(f.read(), helper.__dict__) + if not isinstance(helper.PROMPT, BasePromptTemplate): + raise ValueError("Did not get object of type BasePromptTemplate.") + return helper.PROMPT else: - raise ValueError + raise ValueError(f"Got unsupported file type {file_path.suffix}") # Load the prompt from the config now. return load_prompt_from_config(config) + + +URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" + + +def load_from_hub(path: str) -> BasePromptTemplate: + """Load prompt from hub.""" + suffix = path.split(".")[-1] + if suffix not in {"py", "json", "yaml"}: + raise ValueError("Unsupported file type.") + full_url = URL_BASE + path + r = requests.get(full_url) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = tmpdirname + "/prompt." + suffix + with open(file, "wb") as f: + f.write(r.content) + return load_prompt(file)