mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 13:31:53 +00:00
allow partial prompt formatting
This commit is contained in:
parent
4e43b0efe9
commit
12195ce9cc
@ -151,6 +151,10 @@ class BasePromptTemplate(BaseModel, ABC):
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pformat(self, **kwargs: Any) -> None:
|
||||
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _prompt_type(self) -> str:
|
||||
|
@ -113,6 +113,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
def pformat(self, **kwargs: Any) -> None:
|
||||
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||
raise NotImplementedError(
|
||||
"pformat is currently not supported for FewShotPromptTemplate"
|
||||
)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
|
@ -133,6 +133,12 @@ class FewShotPromptWithTemplates(BasePromptTemplate, BaseModel):
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
def pformat(self, **kwargs: Any) -> None:
|
||||
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||
raise NotImplementedError(
|
||||
"pformat is not currently implemented for FewShotPromptWithTemplates"
|
||||
)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
|
@ -62,6 +62,17 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
|
||||
"""
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
||||
|
||||
def pformat(self, **kwargs: Any) -> None:
|
||||
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||
missing_variables = set(self.input_variables) - set(kwargs.keys())
|
||||
for var in missing_variables:
|
||||
if self.template_format == "f-string":
|
||||
kwargs[var] = "{" + var + "}"
|
||||
elif self.template_format == "jinja2":
|
||||
kwargs[var] = "{{" + var + "}}"
|
||||
self.template = self.format(**kwargs)
|
||||
self.input_variables = list(missing_variables)
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that template and input variables are consistent."""
|
||||
|
@ -102,9 +102,57 @@ def test_prompt_invalid_template_format() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_no_input_variables() -> None:
|
||||
"""Test prompt template with no input variables."""
|
||||
template = "This is a test."
|
||||
input_variables: list = []
|
||||
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||
assert prompt.format() == template
|
||||
|
||||
|
||||
def test_prompt_from_file() -> None:
|
||||
"""Test prompt can be successfully constructed from a file."""
|
||||
template_file = "tests/unit_tests/data/prompt_file.txt"
|
||||
input_variables = ["question"]
|
||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||
assert prompt.template == "Question: {question}\nAnswer:"
|
||||
|
||||
|
||||
def test_pformat() -> None:
|
||||
"""Test prompt can be partially formatted."""
|
||||
template = "This is a {foo} {bar} {baz} test."
|
||||
input_variables = ["foo", "bar", "baz"]
|
||||
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
prompt.pformat(foo="foo")
|
||||
assert prompt.template == "This is a foo {bar} {baz} test."
|
||||
assert sorted(prompt.input_variables) == ["bar", "baz"]
|
||||
|
||||
prompt.pformat(bar="bar")
|
||||
assert prompt.template == "This is a foo bar {baz} test."
|
||||
assert prompt.input_variables == ["baz"]
|
||||
|
||||
prompt.pformat(baz="baz")
|
||||
assert prompt.template == "This is a foo bar baz test."
|
||||
assert prompt.input_variables == []
|
||||
|
||||
|
||||
def test_pformat_jinja2() -> None:
|
||||
"""Test prompt can be partially formatted."""
|
||||
template = "This is a {{foo}} {{bar}} {{baz}} test."
|
||||
input_variables = ["foo", "bar", "baz"]
|
||||
prompt = PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
|
||||
prompt.pformat(foo="foo")
|
||||
assert prompt.template == "This is a foo {{bar}} {{baz}} test."
|
||||
assert sorted(prompt.input_variables) == ["bar", "baz"]
|
||||
|
||||
prompt.pformat(bar="bar")
|
||||
assert prompt.template == "This is a foo bar {{baz}} test."
|
||||
assert prompt.input_variables == ["baz"]
|
||||
|
||||
prompt.pformat(baz="baz")
|
||||
assert prompt.template == "This is a foo bar baz test."
|
||||
assert prompt.input_variables == []
|
||||
|
Loading…
Reference in New Issue
Block a user