mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 13:31:53 +00:00
allow partial prompt formatting
This commit is contained in:
parent
4e43b0efe9
commit
12195ce9cc
@ -151,6 +151,10 @@ class BasePromptTemplate(BaseModel, ABC):
|
|||||||
prompt.format(variable1="foo")
|
prompt.format(variable1="foo")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pformat(self, **kwargs: Any) -> None:
|
||||||
|
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
|
@ -113,6 +113,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
|
|||||||
# Format the template with the input variables.
|
# Format the template with the input variables.
|
||||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||||
|
|
||||||
|
def pformat(self, **kwargs: Any) -> None:
|
||||||
|
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"pformat is currently not supported for FewShotPromptTemplate"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
"""Return the prompt type key."""
|
"""Return the prompt type key."""
|
||||||
|
@ -133,6 +133,12 @@ class FewShotPromptWithTemplates(BasePromptTemplate, BaseModel):
|
|||||||
# Format the template with the input variables.
|
# Format the template with the input variables.
|
||||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||||
|
|
||||||
|
def pformat(self, **kwargs: Any) -> None:
|
||||||
|
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"pformat is not currently implemented for FewShotPromptWithTemplates"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
"""Return the prompt type key."""
|
"""Return the prompt type key."""
|
||||||
|
@ -62,6 +62,17 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
|
|||||||
"""
|
"""
|
||||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
||||||
|
|
||||||
|
def pformat(self, **kwargs: Any) -> None:
|
||||||
|
"""Apply partial formatting to the prompt with the inputs, in place."""
|
||||||
|
missing_variables = set(self.input_variables) - set(kwargs.keys())
|
||||||
|
for var in missing_variables:
|
||||||
|
if self.template_format == "f-string":
|
||||||
|
kwargs[var] = "{" + var + "}"
|
||||||
|
elif self.template_format == "jinja2":
|
||||||
|
kwargs[var] = "{{" + var + "}}"
|
||||||
|
self.template = self.format(**kwargs)
|
||||||
|
self.input_variables = list(missing_variables)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def template_is_valid(cls, values: Dict) -> Dict:
|
def template_is_valid(cls, values: Dict) -> Dict:
|
||||||
"""Check that template and input variables are consistent."""
|
"""Check that template and input variables are consistent."""
|
||||||
|
@ -102,9 +102,57 @@ def test_prompt_invalid_template_format() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_input_variables() -> None:
|
||||||
|
"""Test prompt template with no input variables."""
|
||||||
|
template = "This is a test."
|
||||||
|
input_variables: list = []
|
||||||
|
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||||
|
assert prompt.format() == template
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_from_file() -> None:
|
def test_prompt_from_file() -> None:
|
||||||
"""Test prompt can be successfully constructed from a file."""
|
"""Test prompt can be successfully constructed from a file."""
|
||||||
template_file = "tests/unit_tests/data/prompt_file.txt"
|
template_file = "tests/unit_tests/data/prompt_file.txt"
|
||||||
input_variables = ["question"]
|
input_variables = ["question"]
|
||||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||||
assert prompt.template == "Question: {question}\nAnswer:"
|
assert prompt.template == "Question: {question}\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pformat() -> None:
|
||||||
|
"""Test prompt can be partially formatted."""
|
||||||
|
template = "This is a {foo} {bar} {baz} test."
|
||||||
|
input_variables = ["foo", "bar", "baz"]
|
||||||
|
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
prompt.pformat(foo="foo")
|
||||||
|
assert prompt.template == "This is a foo {bar} {baz} test."
|
||||||
|
assert sorted(prompt.input_variables) == ["bar", "baz"]
|
||||||
|
|
||||||
|
prompt.pformat(bar="bar")
|
||||||
|
assert prompt.template == "This is a foo bar {baz} test."
|
||||||
|
assert prompt.input_variables == ["baz"]
|
||||||
|
|
||||||
|
prompt.pformat(baz="baz")
|
||||||
|
assert prompt.template == "This is a foo bar baz test."
|
||||||
|
assert prompt.input_variables == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_pformat_jinja2() -> None:
|
||||||
|
"""Test prompt can be partially formatted."""
|
||||||
|
template = "This is a {{foo}} {{bar}} {{baz}} test."
|
||||||
|
input_variables = ["foo", "bar", "baz"]
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt.pformat(foo="foo")
|
||||||
|
assert prompt.template == "This is a foo {{bar}} {{baz}} test."
|
||||||
|
assert sorted(prompt.input_variables) == ["bar", "baz"]
|
||||||
|
|
||||||
|
prompt.pformat(bar="bar")
|
||||||
|
assert prompt.template == "This is a foo bar {{baz}} test."
|
||||||
|
assert prompt.input_variables == ["baz"]
|
||||||
|
|
||||||
|
prompt.pformat(baz="baz")
|
||||||
|
assert prompt.template == "This is a foo bar baz test."
|
||||||
|
assert prompt.input_variables == []
|
||||||
|
Loading…
Reference in New Issue
Block a user