mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 01:37:59 +00:00
core[minor]: Implement aformat for FewShotPromptWithTemplates (#20039)
This commit is contained in:
parent
855ba46f80
commit
19001e6cb9
@ -101,6 +101,14 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return await self.example_selector.aselect_examples(kwargs)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
@ -149,6 +157,42 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
async def aformat(self, **kwargs: Any) -> str:
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
# Get the examples to use.
|
||||
examples = await self._aget_examples(**kwargs)
|
||||
# Format the examples.
|
||||
example_strings = [
|
||||
# We can use the sync method here as PromptTemplate doesn't block
|
||||
self.example_prompt.format(**example)
|
||||
for example in examples
|
||||
]
|
||||
# Create the overall prefix.
|
||||
if self.prefix is None:
|
||||
prefix = ""
|
||||
else:
|
||||
prefix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
|
||||
}
|
||||
for k in prefix_kwargs.keys():
|
||||
kwargs.pop(k)
|
||||
prefix = await self.prefix.aformat(**prefix_kwargs)
|
||||
|
||||
# Create the overall suffix
|
||||
suffix_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
|
||||
}
|
||||
for k in suffix_kwargs.keys():
|
||||
kwargs.pop(k)
|
||||
suffix = await self.suffix.aformat(
|
||||
**suffix_kwargs,
|
||||
)
|
||||
|
||||
pieces = [prefix, *example_strings, suffix]
|
||||
template = self.example_separator.join([piece for piece in pieces if piece])
|
||||
# Format the template with the input variables.
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
|
@ -10,7 +10,7 @@ EXAMPLE_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
def test_prompttemplate_prefix_suffix() -> None:
|
||||
async def test_prompttemplate_prefix_suffix() -> None:
|
||||
"""Test that few shot works when prefix and suffix are PromptTemplates."""
|
||||
prefix = PromptTemplate(
|
||||
input_variables=["content"], template="This is a test about {content}."
|
||||
@ -32,13 +32,15 @@ def test_prompttemplate_prefix_suffix() -> None:
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
example_separator="\n",
|
||||
)
|
||||
output = prompt.format(content="animals", new_content="party")
|
||||
expected_output = (
|
||||
"This is a test about animals.\n"
|
||||
"foo: bar\n"
|
||||
"baz: foo\n"
|
||||
"Now you try to talk about party."
|
||||
)
|
||||
output = prompt.format(content="animals", new_content="party")
|
||||
assert output == expected_output
|
||||
output = await prompt.aformat(content="animals", new_content="party")
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user