diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index bb93ca9bdd6..58d8a47998f 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -151,6 +151,10 @@ class BasePromptTemplate(BaseModel, ABC): prompt.format(variable1="foo") """ + @abstractmethod + def pformat(self, **kwargs: Any) -> None: + """Apply partial formatting to the prompt with the inputs, in place.""" + @property @abstractmethod def _prompt_type(self) -> str: diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 5a7b0c54fbd..b2ca354690c 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -113,6 +113,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): # Format the template with the input variables. return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + def pformat(self, **kwargs: Any) -> None: + """Apply partial formatting to the prompt with the inputs, in place.""" + raise NotImplementedError( + "pformat is currently not supported for FewShotPromptTemplate" + ) + @property def _prompt_type(self) -> str: """Return the prompt type key.""" diff --git a/langchain/prompts/few_shot_with_templates.py b/langchain/prompts/few_shot_with_templates.py index cba4f6d024f..74cf45f35c4 100644 --- a/langchain/prompts/few_shot_with_templates.py +++ b/langchain/prompts/few_shot_with_templates.py @@ -133,6 +133,12 @@ class FewShotPromptWithTemplates(BasePromptTemplate, BaseModel): # Format the template with the input variables. return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) + def pformat(self, **kwargs: Any) -> None: + """Apply partial formatting to the prompt with the inputs, in place.""" + raise NotImplementedError( + "pformat is not currently implemented for FewShotPromptWithTemplates" + ) + @property def _prompt_type(self) -> str: """Return the prompt type key.""" diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 0b2dcfddae3..7317f7fc961 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -62,6 +62,17 @@ class PromptTemplate(BasePromptTemplate, BaseModel): """ return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) + def pformat(self, **kwargs: Any) -> None: + """Apply partial formatting to the prompt with the inputs, in place.""" + missing_variables = set(self.input_variables) - set(kwargs.keys()) + for var in missing_variables: + if self.template_format == "f-string": + kwargs[var] = "{" + var + "}" + elif self.template_format == "jinja2": + kwargs[var] = "{{" + var + "}}" + self.template = self.format(**kwargs) + self.input_variables = list(missing_variables) + @root_validator() def template_is_valid(cls, values: Dict) -> Dict: """Check that template and input variables are consistent.""" diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index 1789963e242..98f17f6248a 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -102,9 +102,57 @@ def test_prompt_invalid_template_format() -> None: ) +def test_no_input_variables() -> None: + """Test prompt template with no input variables.""" + template = "This is a test." + input_variables: list = [] + prompt = PromptTemplate(input_variables=input_variables, template=template) + assert prompt.format() == template + + def test_prompt_from_file() -> None: """Test prompt can be successfully constructed from a file.""" template_file = "tests/unit_tests/data/prompt_file.txt" input_variables = ["question"] prompt = PromptTemplate.from_file(template_file, input_variables) assert prompt.template == "Question: {question}\nAnswer:" + + +def test_pformat() -> None: + """Test prompt can be partially formatted.""" + template = "This is a {foo} {bar} {baz} test." + input_variables = ["foo", "bar", "baz"] + prompt = PromptTemplate(input_variables=input_variables, template=template) + + prompt.pformat(foo="foo") + assert prompt.template == "This is a foo {bar} {baz} test." + assert sorted(prompt.input_variables) == ["bar", "baz"] + + prompt.pformat(bar="bar") + assert prompt.template == "This is a foo bar {baz} test." + assert prompt.input_variables == ["baz"] + + prompt.pformat(baz="baz") + assert prompt.template == "This is a foo bar baz test." + assert prompt.input_variables == [] + + +def test_pformat_jinja2() -> None: + """Test prompt can be partially formatted.""" + template = "This is a {{foo}} {{bar}} {{baz}} test." + input_variables = ["foo", "bar", "baz"] + prompt = PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ) + + prompt.pformat(foo="foo") + assert prompt.template == "This is a foo {{bar}} {{baz}} test." + assert sorted(prompt.input_variables) == ["bar", "baz"] + + prompt.pformat(bar="bar") + assert prompt.template == "This is a foo bar {{baz}} test." + assert prompt.input_variables == ["baz"] + + prompt.pformat(baz="baz") + assert prompt.template == "This is a foo bar baz test." + assert prompt.input_variables == []