mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
fmt
This commit is contained in:
@@ -450,7 +450,7 @@ _StringImageMessagePromptTemplateT = TypeVar(
|
||||
|
||||
|
||||
class _TextTemplateParam(TypedDict, total=False):
|
||||
text: Union[str, Dict]
|
||||
text: str
|
||||
|
||||
|
||||
class _ImageTemplateParam(TypedDict, total=False):
|
||||
@@ -462,7 +462,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
|
||||
prompt: Union[
|
||||
StringPromptTemplate,
|
||||
List[Union[StringPromptTemplate, ContentBlockPromptTemplate]],
|
||||
List[
|
||||
Union[StringPromptTemplate, ImagePromptTemplate, ContentBlockPromptTemplate]
|
||||
],
|
||||
]
|
||||
"""Prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
@@ -520,6 +522,29 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
tmpl, template_format=template_format
|
||||
)
|
||||
)
|
||||
# For backwards compatible ser/des.
|
||||
# TODO: Refactor in 1.0 so this just uses a ContentBlockPromptTemplate.
|
||||
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
|
||||
"type",
|
||||
"text",
|
||||
}:
|
||||
prompt.append(
|
||||
PromptTemplate.from_template(
|
||||
cast(_TextTemplateParam, tmpl)["text"],
|
||||
template_format=template_format,
|
||||
)
|
||||
)
|
||||
# For backwards compatible ser/des.
|
||||
# TODO: Refactor in 1.0 so this just use a ContentBlockPromptTemplate.
|
||||
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
|
||||
"type",
|
||||
"image_url",
|
||||
}:
|
||||
prompt.append(
|
||||
ImagePromptTemplate(
|
||||
template=cast(dict, tmpl), template_format=template_format
|
||||
)
|
||||
)
|
||||
elif isinstance(tmpl, dict):
|
||||
prompt.append(
|
||||
ContentBlockPromptTemplate(
|
||||
@@ -615,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
formatted: Union[str, dict] = prompt.format(**inputs)
|
||||
content.append({"type": "text", "text": formatted})
|
||||
elif isinstance(prompt, ImagePromptTemplate):
|
||||
formatted = prompt.format(**inputs)
|
||||
formatted = cast(dict, prompt.format(**inputs))
|
||||
content.append({"type": "image_url", "image_url": formatted})
|
||||
elif isinstance(prompt, ContentBlockPromptTemplate):
|
||||
formatted = prompt.format(**inputs)
|
||||
@@ -654,7 +679,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
formatted: Union[str, dict] = await prompt.aformat(**inputs)
|
||||
content.append({"type": "text", "text": formatted})
|
||||
elif isinstance(prompt, ImagePromptTemplate):
|
||||
formatted = await prompt.aformat(**inputs)
|
||||
formatted = cast(dict, await prompt.aformat(**inputs))
|
||||
content.append({"type": "image_url", "image_url": formatted})
|
||||
elif isinstance(prompt, ContentBlockPromptTemplate):
|
||||
formatted = await prompt.aformat(**inputs)
|
||||
|
||||
Reference in New Issue
Block a user