mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
Save Prompts (#194)
This commit is contained in:
parent
b90e25f786
commit
ae72cf84b8
@ -1,7 +1,10 @@
|
||||
"""BasePrompt schema definition."""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.formatting import formatter
|
||||
@ -61,3 +64,39 @@ class BasePromptTemplate(BaseModel, ABC):
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
|
||||
def _prompt_dict(self) -> Dict:
|
||||
"""Return a dictionary of the prompt."""
|
||||
return self.dict()
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the prompt.
|
||||
|
||||
Args:
|
||||
file_path: Path to directory to save prompt to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
prompt.save(file_path="path/prompt.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self._prompt_dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
f.write(json.dumps(prompt_dict, indent=4))
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
@ -108,3 +108,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
|
||||
template = self.example_separator.join([piece for piece in pieces if piece])
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
def _prompt_dict(self) -> Dict:
|
||||
"""Return a dictionary of the prompt."""
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
|
||||
prompt_dict = self.dict()
|
||||
prompt_dict["_type"] = "few_shot"
|
||||
return prompt_dict
|
||||
|
@ -43,6 +43,34 @@ def test_loading_from_JSON() -> None:
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
"""Test equality when saving and loading a prompt."""
|
||||
simple_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
)
|
||||
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||
assert loaded_prompt == simple_prompt
|
||||
|
||||
few_shot_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
example_prompt=PromptTemplate(
|
||||
input_variables=["input", "output"],
|
||||
template="Input: {input}\nOutput: {output}",
|
||||
),
|
||||
examples=[
|
||||
{"input": "happy", "output": "sad"},
|
||||
{"input": "tall", "output": "short"},
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
few_shot_prompt.save(file_path=tmp_path / "few_shot.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "few_shot.yaml")
|
||||
assert loaded_prompt == few_shot_prompt
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory():
|
||||
|
Loading…
Reference in New Issue
Block a user