mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
cr
This commit is contained in:
@@ -19,9 +19,11 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain._api import deprecated
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.prompts.base import StringPromptTemplate
|
||||
from langchain.prompts.base import StringPromptTemplate, get_template_variables
|
||||
from langchain.prompts.image import ImagePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
@@ -234,6 +236,14 @@ _StringImageMessagePromptTemplateT = TypeVar(
|
||||
)
|
||||
|
||||
|
||||
class _TextTemplateParam(TypedDict, total=False):
|
||||
text: Union[str, Dict]
|
||||
|
||||
|
||||
class _ImageTemplateParam(TypedDict, total=False):
|
||||
image_url: Union[str, Dict]
|
||||
|
||||
|
||||
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
@@ -249,7 +259,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[_StringImageMessagePromptTemplateT],
|
||||
template: Union[str, List[Union[str, Dict[str, Any]]]],
|
||||
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
template_format: str = "f-string",
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
@@ -271,19 +281,30 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
elif isinstance(template, list):
|
||||
prompt = []
|
||||
for tmpl in template:
|
||||
if isinstance(tmpl, str):
|
||||
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
|
||||
if isinstance(tmpl, str):
|
||||
text: str = tmpl
|
||||
else:
|
||||
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
|
||||
prompt.append(
|
||||
PromptTemplate.from_template(
|
||||
tmpl, template_format=template_format
|
||||
text, template_format=template_format
|
||||
)
|
||||
)
|
||||
elif isinstance(tmpl, dict) and (
|
||||
"image" in tmpl or "image_url" in tmpl
|
||||
):
|
||||
img_template = tmpl.get("image") or tmpl.get("image_url")
|
||||
elif isinstance(tmpl, dict) and "image_url" in tmpl:
|
||||
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
|
||||
if isinstance(img_template, str):
|
||||
vars = get_template_variables(img_template, "f-string")
|
||||
if vars:
|
||||
if len(vars) > 1:
|
||||
raise ValueError
|
||||
variable_name = vars[0]
|
||||
img_template = {}
|
||||
else:
|
||||
variable_name = None
|
||||
img_template = {"url": img_template}
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
variable_name=img_template
|
||||
variable_name=variable_name, template=img_template
|
||||
)
|
||||
elif isinstance(img_template, dict):
|
||||
img_template = dict(img_template)
|
||||
|
||||
Reference in New Issue
Block a user