From d1027b1d7fd2cece7097c645bcc34d1d2fd8066b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Fri, 10 Nov 2023 17:58:43 -0800 Subject: [PATCH] cr --- libs/langchain/langchain/prompts/chat.py | 39 ++++++++++++++++++------ 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index 70f70f34d76..ed8b7583d19 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -19,9 +19,11 @@ from typing import ( overload, ) +from typing_extensions import TypedDict + from langchain._api import deprecated from langchain.load.serializable import Serializable -from langchain.prompts.base import StringPromptTemplate +from langchain.prompts.base import StringPromptTemplate, get_template_variables from langchain.prompts.image import ImagePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.pydantic_v1 import Field, root_validator @@ -234,6 +236,14 @@ _StringImageMessagePromptTemplateT = TypeVar( ) +class _TextTemplateParam(TypedDict, total=False): + text: Union[str, Dict] + + +class _ImageTemplateParam(TypedDict, total=False): + image_url: Union[str, Dict] + + class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" @@ -249,7 +259,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template( cls: Type[_StringImageMessagePromptTemplateT], - template: Union[str, List[Union[str, Dict[str, Any]]]], + template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template_format: str = "f-string", **kwargs: Any, ) -> _StringImageMessagePromptTemplateT: @@ -271,19 +281,30 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): elif isinstance(template, list): prompt = [] for tmpl in template: - if isinstance(tmpl, str): + if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if isinstance(tmpl, str): + text: str = tmpl + else: + text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501 prompt.append( PromptTemplate.from_template( - tmpl, template_format=template_format + text, template_format=template_format ) ) - elif isinstance(tmpl, dict) and ( - "image" in tmpl or "image_url" in tmpl - ): - img_template = tmpl.get("image") or tmpl.get("image_url") + elif isinstance(tmpl, dict) and "image_url" in tmpl: + img_template = cast(_ImageTemplateParam, tmpl)["image_url"] if isinstance(img_template, str): + vars = get_template_variables(img_template, "f-string") + if vars: + if len(vars) > 1: + raise ValueError + variable_name = vars[0] + img_template = {} + else: + variable_name = None + img_template = {"url": img_template} img_template_obj = ImagePromptTemplate( - variable_name=img_template + variable_name=variable_name, template=img_template ) elif isinstance(img_template, dict): img_template = dict(img_template)