mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
Save Prompts (#194)
This commit is contained in:
parent
b90e25f786
commit
ae72cf84b8
@ -1,7 +1,10 @@
|
|||||||
"""BasePrompt schema definition."""
|
"""BasePrompt schema definition."""
|
||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
from langchain.formatting import formatter
|
from langchain.formatting import formatter
|
||||||
@ -61,3 +64,39 @@ class BasePromptTemplate(BaseModel, ABC):
|
|||||||
|
|
||||||
prompt.format(variable1="foo")
|
prompt.format(variable1="foo")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _prompt_dict(self) -> Dict:
|
||||||
|
"""Return a dictionary of the prompt."""
|
||||||
|
return self.dict()
|
||||||
|
|
||||||
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
|
"""Save the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to directory to save prompt to.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompt.save(file_path="path/prompt.yaml")
|
||||||
|
"""
|
||||||
|
# Convert file to Path object.
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
save_path = Path(file_path)
|
||||||
|
else:
|
||||||
|
save_path = file_path
|
||||||
|
|
||||||
|
directory_path = save_path.parent
|
||||||
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Fetch dictionary to save
|
||||||
|
prompt_dict = self._prompt_dict()
|
||||||
|
|
||||||
|
if save_path.suffix == ".json":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
f.write(json.dumps(prompt_dict, indent=4))
|
||||||
|
elif save_path.suffix == ".yaml":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
@ -108,3 +108,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
|
|||||||
template = self.example_separator.join([piece for piece in pieces if piece])
|
template = self.example_separator.join([piece for piece in pieces if piece])
|
||||||
# Format the template with the input variables.
|
# Format the template with the input variables.
|
||||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||||
|
|
||||||
|
def _prompt_dict(self) -> Dict:
|
||||||
|
"""Return a dictionary of the prompt."""
|
||||||
|
if self.example_selector:
|
||||||
|
raise ValueError("Saving an example selector is not currently supported")
|
||||||
|
|
||||||
|
prompt_dict = self.dict()
|
||||||
|
prompt_dict["_type"] = "few_shot"
|
||||||
|
return prompt_dict
|
||||||
|
@ -43,6 +43,34 @@ def test_loading_from_JSON() -> None:
|
|||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||||
|
"""Test equality when saving and loading a prompt."""
|
||||||
|
simple_prompt = PromptTemplate(
|
||||||
|
input_variables=["adjective", "content"],
|
||||||
|
template="Tell me a {adjective} joke about {content}.",
|
||||||
|
)
|
||||||
|
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||||
|
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||||
|
assert loaded_prompt == simple_prompt
|
||||||
|
|
||||||
|
few_shot_prompt = FewShotPromptTemplate(
|
||||||
|
input_variables=["adjective"],
|
||||||
|
prefix="Write antonyms for the following words.",
|
||||||
|
example_prompt=PromptTemplate(
|
||||||
|
input_variables=["input", "output"],
|
||||||
|
template="Input: {input}\nOutput: {output}",
|
||||||
|
),
|
||||||
|
examples=[
|
||||||
|
{"input": "happy", "output": "sad"},
|
||||||
|
{"input": "tall", "output": "short"},
|
||||||
|
],
|
||||||
|
suffix="Input: {adjective}\nOutput:",
|
||||||
|
)
|
||||||
|
few_shot_prompt.save(file_path=tmp_path / "few_shot.yaml")
|
||||||
|
loaded_prompt = load_prompt(tmp_path / "few_shot.yaml")
|
||||||
|
assert loaded_prompt == few_shot_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_loading_with_template_as_file() -> None:
|
def test_loading_with_template_as_file() -> None:
|
||||||
"""Test loading when the template is a file."""
|
"""Test loading when the template is a file."""
|
||||||
with change_directory():
|
with change_directory():
|
||||||
|
Loading…
Reference in New Issue
Block a user