added common prompt load method (#699)

Co-authored-by: scadEfUr
This commit is contained in:
scadEfUr 2023-01-22 23:46:11 -08:00 committed by GitHub
parent 36b6b3cdf6
commit 4aba0abeaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 8 deletions

View File

@ -1,7 +1,7 @@
"""Prompt template classes.""" """Prompt template classes."""
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_from_hub, load_prompt from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import Prompt, PromptTemplate from langchain.prompts.prompt import Prompt, PromptTemplate
__all__ = [ __all__ = [
@ -10,5 +10,4 @@ __all__ = [
"PromptTemplate", "PromptTemplate",
"FewShotPromptTemplate", "FewShotPromptTemplate",
"Prompt", "Prompt",
"load_from_hub",
] ]

View File

@ -12,6 +12,8 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
def load_prompt_from_config(config: dict) -> BasePromptTemplate: def load_prompt_from_config(config: dict) -> BasePromptTemplate:
"""Get the right type from the config and load it accordingly.""" """Get the right type from the config and load it accordingly."""
@ -93,7 +95,16 @@ def _load_prompt(config: dict) -> PromptTemplate:
return PromptTemplate(**config) return PromptTemplate(**config)
def load_prompt(file: Union[str, Path]) -> BasePromptTemplate: def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from LangChainHub or local fs."""
if isinstance(path, str) and path.startswith("lc://prompts"):
path = path.lstrip("lc://prompts/")
return _load_from_hub(path)
else:
return _load_prompt_from_file(path)
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
"""Load prompt from file.""" """Load prompt from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): if isinstance(file, str):
@ -125,10 +136,7 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
return load_prompt_from_config(config) return load_prompt_from_config(config)
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" def _load_from_hub(path: str) -> BasePromptTemplate:
def load_from_hub(path: str) -> BasePromptTemplate:
"""Load prompt from hub.""" """Load prompt from hub."""
suffix = path.split(".")[-1] suffix = path.split(".")[-1]
if suffix not in {"py", "json", "yaml"}: if suffix not in {"py", "json", "yaml"}:
@ -141,4 +149,4 @@ def load_from_hub(path: str) -> BasePromptTemplate:
file = tmpdirname + "/prompt." + suffix file = tmpdirname + "/prompt." + suffix
with open(file, "wb") as f: with open(file, "wb") as f:
f.write(r.content) f.write(r.content)
return load_prompt(file) return _load_prompt_from_file(file)