allow partial prompt formatting

This commit is contained in:
Ankush Gola 2023-02-21 17:31:08 -08:00
parent 4e43b0efe9
commit 12195ce9cc
5 changed files with 75 additions and 0 deletions

View File

@ -151,6 +151,10 @@ class BasePromptTemplate(BaseModel, ABC):
prompt.format(variable1="foo")
"""
@abstractmethod
def pformat(self, **kwargs: Any) -> None:
"""Apply partial formatting to the prompt with the inputs, in place."""
@property
@abstractmethod
def _prompt_type(self) -> str:

View File

@ -113,6 +113,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
def pformat(self, **kwargs: Any) -> None:
"""Apply partial formatting to the prompt with the inputs, in place."""
raise NotImplementedError(
"pformat is currently not supported for FewShotPromptTemplate"
)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

View File

@ -133,6 +133,12 @@ class FewShotPromptWithTemplates(BasePromptTemplate, BaseModel):
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
def pformat(self, **kwargs: Any) -> None:
"""Apply partial formatting to the prompt with the inputs, in place."""
raise NotImplementedError(
"pformat is not currently implemented for FewShotPromptWithTemplates"
)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

View File

@ -62,6 +62,17 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
"""
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
def pformat(self, **kwargs: Any) -> None:
"""Apply partial formatting to the prompt with the inputs, in place."""
missing_variables = set(self.input_variables) - set(kwargs.keys())
for var in missing_variables:
if self.template_format == "f-string":
kwargs[var] = "{" + var + "}"
elif self.template_format == "jinja2":
kwargs[var] = "{{" + var + "}}"
self.template = self.format(**kwargs)
self.input_variables = list(missing_variables)
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""

View File

@ -102,9 +102,57 @@ def test_prompt_invalid_template_format() -> None:
)
def test_no_input_variables() -> None:
"""Test prompt template with no input variables."""
template = "This is a test."
input_variables: list = []
prompt = PromptTemplate(input_variables=input_variables, template=template)
assert prompt.format() == template
def test_prompt_from_file() -> None:
"""Test prompt can be successfully constructed from a file."""
template_file = "tests/unit_tests/data/prompt_file.txt"
input_variables = ["question"]
prompt = PromptTemplate.from_file(template_file, input_variables)
assert prompt.template == "Question: {question}\nAnswer:"
def test_pformat() -> None:
"""Test prompt can be partially formatted."""
template = "This is a {foo} {bar} {baz} test."
input_variables = ["foo", "bar", "baz"]
prompt = PromptTemplate(input_variables=input_variables, template=template)
prompt.pformat(foo="foo")
assert prompt.template == "This is a foo {bar} {baz} test."
assert sorted(prompt.input_variables) == ["bar", "baz"]
prompt.pformat(bar="bar")
assert prompt.template == "This is a foo bar {baz} test."
assert prompt.input_variables == ["baz"]
prompt.pformat(baz="baz")
assert prompt.template == "This is a foo bar baz test."
assert prompt.input_variables == []
def test_pformat_jinja2() -> None:
"""Test prompt can be partially formatted."""
template = "This is a {{foo}} {{bar}} {{baz}} test."
input_variables = ["foo", "bar", "baz"]
prompt = PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
prompt.pformat(foo="foo")
assert prompt.template == "This is a foo {{bar}} {{baz}} test."
assert sorted(prompt.input_variables) == ["bar", "baz"]
prompt.pformat(bar="bar")
assert prompt.template == "This is a foo bar {{baz}} test."
assert prompt.input_variables == ["baz"]
prompt.pformat(baz="baz")
assert prompt.template == "This is a foo bar baz test."
assert prompt.input_variables == []