This commit is contained in:
Bagatur
2024-08-19 15:02:52 -07:00
parent 644f338a10
commit cb168ef981

View File

@@ -450,7 +450,7 @@ _StringImageMessagePromptTemplateT = TypeVar(
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
text: str
class _ImageTemplateParam(TypedDict, total=False):
@@ -462,7 +462,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
prompt: Union[
StringPromptTemplate,
List[Union[StringPromptTemplate, ContentBlockPromptTemplate]],
List[
Union[StringPromptTemplate, ImagePromptTemplate, ContentBlockPromptTemplate]
],
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
@@ -520,6 +522,29 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
tmpl, template_format=template_format
)
)
# For backwards compatible ser/des.
# TODO: Refactor in 1.0 so this just uses a ContentBlockPromptTemplate.
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
"type",
"text",
}:
prompt.append(
PromptTemplate.from_template(
cast(_TextTemplateParam, tmpl)["text"],
template_format=template_format,
)
)
# For backwards compatible ser/des.
# TODO: Refactor in 1.0 so this just use a ContentBlockPromptTemplate.
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
"type",
"image_url",
}:
prompt.append(
ImagePromptTemplate(
template=cast(dict, tmpl), template_format=template_format
)
)
elif isinstance(tmpl, dict):
prompt.append(
ContentBlockPromptTemplate(
@@ -615,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
formatted: Union[str, dict] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
formatted = cast(dict, prompt.format(**inputs))
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, ContentBlockPromptTemplate):
formatted = prompt.format(**inputs)
@@ -654,7 +679,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
formatted: Union[str, dict] = await prompt.aformat(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
formatted = cast(dict, await prompt.aformat(**inputs))
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, ContentBlockPromptTemplate):
formatted = await prompt.aformat(**inputs)