mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
parent
36b6b3cdf6
commit
4aba0abeaa
@ -1,7 +1,7 @@
|
|||||||
"""Prompt template classes."""
|
"""Prompt template classes."""
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.loading import load_from_hub, load_prompt
|
from langchain.prompts.loading import load_prompt
|
||||||
from langchain.prompts.prompt import Prompt, PromptTemplate
|
from langchain.prompts.prompt import Prompt, PromptTemplate
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -10,5 +10,4 @@ __all__ = [
|
|||||||
"PromptTemplate",
|
"PromptTemplate",
|
||||||
"FewShotPromptTemplate",
|
"FewShotPromptTemplate",
|
||||||
"Prompt",
|
"Prompt",
|
||||||
"load_from_hub",
|
|
||||||
]
|
]
|
||||||
|
@ -12,6 +12,8 @@ from langchain.prompts.base import BasePromptTemplate
|
|||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||||
"""Get the right type from the config and load it accordingly."""
|
"""Get the right type from the config and load it accordingly."""
|
||||||
@ -93,7 +95,16 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
|||||||
return PromptTemplate(**config)
|
return PromptTemplate(**config)
|
||||||
|
|
||||||
|
|
||||||
def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
|
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||||
|
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
||||||
|
if isinstance(path, str) and path.startswith("lc://prompts"):
|
||||||
|
path = path.lstrip("lc://prompts/")
|
||||||
|
return _load_from_hub(path)
|
||||||
|
else:
|
||||||
|
return _load_prompt_from_file(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||||
"""Load prompt from file."""
|
"""Load prompt from file."""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file, str):
|
if isinstance(file, str):
|
||||||
@ -125,10 +136,7 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
|
|||||||
return load_prompt_from_config(config)
|
return load_prompt_from_config(config)
|
||||||
|
|
||||||
|
|
||||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
def _load_from_hub(path: str) -> BasePromptTemplate:
|
||||||
|
|
||||||
|
|
||||||
def load_from_hub(path: str) -> BasePromptTemplate:
|
|
||||||
"""Load prompt from hub."""
|
"""Load prompt from hub."""
|
||||||
suffix = path.split(".")[-1]
|
suffix = path.split(".")[-1]
|
||||||
if suffix not in {"py", "json", "yaml"}:
|
if suffix not in {"py", "json", "yaml"}:
|
||||||
@ -141,4 +149,4 @@ def load_from_hub(path: str) -> BasePromptTemplate:
|
|||||||
file = tmpdirname + "/prompt." + suffix
|
file = tmpdirname + "/prompt." + suffix
|
||||||
with open(file, "wb") as f:
|
with open(file, "wb") as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
return load_prompt(file)
|
return _load_prompt_from_file(file)
|
||||||
|
Loading…
Reference in New Issue
Block a user