core[minor]: Implement aformat_messages for _StringImageMessagePromptTemplate (#20036)

This commit is contained in:
Christophe Bornet
2024-04-09 21:59:39 +02:00
committed by GitHub
parent 19001e6cb9
commit f43b48aebc
2 changed files with 78 additions and 16 deletions

View File

@@ -506,6 +506,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""
return [self.format(**kwargs)]
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
return [await self.aformat(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
@@ -546,6 +549,34 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
content=content, additional_kwargs=self.additional_kwargs
)
async def aformat(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
if isinstance(self.prompt, StringPromptTemplate):
text = await self.prompt.aformat(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content: List = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
def pretty_repr(self, html: bool = False) -> str:
# TODO: Handle partials
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")