This commit is contained in:
Bagatur
2023-11-10 17:58:43 -08:00
parent 2afa070d34
commit d1027b1d7f

View File

@@ -19,9 +19,11 @@ from typing import (
overload,
)
from typing_extensions import TypedDict
from langchain._api import deprecated
from langchain.load.serializable import Serializable
from langchain.prompts.base import StringPromptTemplate
from langchain.prompts.base import StringPromptTemplate, get_template_variables
from langchain.prompts.image import ImagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.pydantic_v1 import Field, root_validator
@@ -234,6 +236,14 @@ _StringImageMessagePromptTemplateT = TypeVar(
)
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, Dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
@@ -249,7 +259,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template(
cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, Dict[str, Any]]]],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
@@ -271,19 +281,30 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
elif isinstance(template, list):
prompt = []
for tmpl in template:
if isinstance(tmpl, str):
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
prompt.append(
PromptTemplate.from_template(
tmpl, template_format=template_format
text, template_format=template_format
)
)
elif isinstance(tmpl, dict) and (
"image" in tmpl or "image_url" in tmpl
):
img_template = tmpl.get("image") or tmpl.get("image_url")
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
if vars:
if len(vars) > 1:
raise ValueError
variable_name = vars[0]
img_template = {}
else:
variable_name = None
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
variable_name=img_template
variable_name=variable_name, template=img_template
)
elif isinstance(img_template, dict):
img_template = dict(img_template)