mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 18:53:10 +00:00
core[patch]: update dict prompt template (#30967)
Align with JS changes made in https://github.com/langchain-ai/langchainjs/pull/8043
This commit is contained in:
parent
4bc70766b5
commit
d4fc734250
@ -1,155 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import pkgutil
|
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
|
|
||||||
|
|
||||||
|
|
||||||
def import_all_modules(package_name: str) -> dict:
|
|
||||||
package = importlib.import_module(package_name)
|
|
||||||
classes: dict = {}
|
|
||||||
|
|
||||||
def _handle_module(module: ModuleType) -> None:
|
|
||||||
# Iterate over all members of the module
|
|
||||||
|
|
||||||
names = dir(module)
|
|
||||||
|
|
||||||
if hasattr(module, "__all__"):
|
|
||||||
names += list(module.__all__)
|
|
||||||
|
|
||||||
names = sorted(set(names))
|
|
||||||
|
|
||||||
for name in names:
|
|
||||||
# Check if it's a class or function
|
|
||||||
attr = getattr(module, name)
|
|
||||||
|
|
||||||
if not inspect.isclass(attr):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(attr.is_lc_serializable(), bool)
|
|
||||||
and attr.is_lc_serializable()
|
|
||||||
):
|
|
||||||
key = tuple(attr.lc_id())
|
|
||||||
value = tuple(attr.__module__.split(".") + [attr.__name__])
|
|
||||||
if key in classes and classes[key] != value:
|
|
||||||
raise ValueError
|
|
||||||
classes[key] = value
|
|
||||||
|
|
||||||
_handle_module(package)
|
|
||||||
|
|
||||||
for importer, modname, ispkg in pkgutil.walk_packages(
|
|
||||||
package.__path__, package.__name__ + "."
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(modname)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
continue
|
|
||||||
_handle_module(module)
|
|
||||||
|
|
||||||
return classes
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_all_modules() -> None:
|
|
||||||
"""Test import all modules works as expected"""
|
|
||||||
all_modules = import_all_modules("langchain")
|
|
||||||
filtered_modules = [
|
|
||||||
k
|
|
||||||
for k in all_modules
|
|
||||||
if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
|
|
||||||
]
|
|
||||||
# This test will need to be updated if new serializable classes are added
|
|
||||||
# to community
|
|
||||||
assert sorted(filtered_modules) == sorted(
|
|
||||||
[
|
|
||||||
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
|
|
||||||
("langchain", "chat_models", "bedrock", "BedrockChat"),
|
|
||||||
("langchain", "chat_models", "anthropic", "ChatAnthropic"),
|
|
||||||
("langchain", "chat_models", "fireworks", "ChatFireworks"),
|
|
||||||
("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
|
|
||||||
("langchain", "chat_models", "openai", "ChatOpenAI"),
|
|
||||||
("langchain", "chat_models", "vertexai", "ChatVertexAI"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_serializable_mapping() -> None:
|
|
||||||
to_skip = {
|
|
||||||
# This should have had a different namespace, as it was never
|
|
||||||
# exported from the langchain module, but we keep for whoever has
|
|
||||||
# already serialized it.
|
|
||||||
("langchain", "prompts", "image", "ImagePromptTemplate"): (
|
|
||||||
"langchain_core",
|
|
||||||
"prompts",
|
|
||||||
"image",
|
|
||||||
"ImagePromptTemplate",
|
|
||||||
),
|
|
||||||
# This is not exported from langchain, only langchain_core
|
|
||||||
("langchain_core", "prompts", "structured", "StructuredPrompt"): (
|
|
||||||
"langchain_core",
|
|
||||||
"prompts",
|
|
||||||
"structured",
|
|
||||||
"StructuredPrompt",
|
|
||||||
),
|
|
||||||
# This is not exported from langchain, only langchain_core
|
|
||||||
("langchain", "schema", "messages", "RemoveMessage"): (
|
|
||||||
"langchain_core",
|
|
||||||
"messages",
|
|
||||||
"modifier",
|
|
||||||
"RemoveMessage",
|
|
||||||
),
|
|
||||||
("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
|
|
||||||
"langchain_mistralai",
|
|
||||||
"chat_models",
|
|
||||||
"ChatMistralAI",
|
|
||||||
),
|
|
||||||
("langchain_groq", "chat_models", "ChatGroq"): (
|
|
||||||
"langchain_groq",
|
|
||||||
"chat_models",
|
|
||||||
"ChatGroq",
|
|
||||||
),
|
|
||||||
("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): (
|
|
||||||
"langchain_sambanova",
|
|
||||||
"chat_models",
|
|
||||||
"ChatSambaNovaCloud",
|
|
||||||
),
|
|
||||||
("langchain_sambanova", "chat_models", "ChatSambaStudio"): (
|
|
||||||
"langchain_sambanova",
|
|
||||||
"chat_models",
|
|
||||||
"ChatSambaStudio",
|
|
||||||
),
|
|
||||||
# TODO(0.3): For now we're skipping the below two tests. Need to fix
|
|
||||||
# so that it only runs when langchain-aws, langchain-google-genai
|
|
||||||
# are installed.
|
|
||||||
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
|
|
||||||
"langchain_aws",
|
|
||||||
"chat_models",
|
|
||||||
"bedrock",
|
|
||||||
"ChatBedrock",
|
|
||||||
),
|
|
||||||
("langchain_google_genai", "chat_models", "ChatGoogleGenerativeAI"): (
|
|
||||||
"langchain_google_genai",
|
|
||||||
"chat_models",
|
|
||||||
"ChatGoogleGenerativeAI",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
serializable_modules = import_all_modules("langchain")
|
|
||||||
|
|
||||||
missing = set(SERIALIZABLE_MAPPING).difference(
|
|
||||||
set(serializable_modules).union(to_skip)
|
|
||||||
)
|
|
||||||
assert missing == set()
|
|
||||||
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
|
|
||||||
assert extra == set()
|
|
||||||
|
|
||||||
for k, import_path in serializable_modules.items():
|
|
||||||
import_dir, import_obj = import_path[:-1], import_path[-1]
|
|
||||||
# Import module
|
|
||||||
mod = importlib.import_module(".".join(import_dir))
|
|
||||||
# Import class
|
|
||||||
cls = getattr(mod, import_obj)
|
|
||||||
assert list(k) == cls.lc_id()
|
|
@ -540,6 +540,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
|||||||
"chat_models",
|
"chat_models",
|
||||||
"ChatSambaStudio",
|
"ChatSambaStudio",
|
||||||
),
|
),
|
||||||
|
("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): (
|
||||||
|
"langchain_core",
|
||||||
|
"prompts",
|
||||||
|
"dict",
|
||||||
|
"DictPromptTemplate",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Needed for backwards compatibility for old versions of LangChain where things
|
# Needed for backwards compatibility for old versions of LangChain where things
|
||||||
|
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
|||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
from langchain_core.prompts.dict import DictPromptTemplate
|
||||||
from langchain_core.prompts.few_shot import (
|
from langchain_core.prompts.few_shot import (
|
||||||
FewShotChatMessagePromptTemplate,
|
FewShotChatMessagePromptTemplate,
|
||||||
FewShotPromptTemplate,
|
FewShotPromptTemplate,
|
||||||
@ -68,6 +69,7 @@ __all__ = (
|
|||||||
"BasePromptTemplate",
|
"BasePromptTemplate",
|
||||||
"ChatMessagePromptTemplate",
|
"ChatMessagePromptTemplate",
|
||||||
"ChatPromptTemplate",
|
"ChatPromptTemplate",
|
||||||
|
"DictPromptTemplate",
|
||||||
"FewShotPromptTemplate",
|
"FewShotPromptTemplate",
|
||||||
"FewShotPromptWithTemplates",
|
"FewShotPromptWithTemplates",
|
||||||
"FewShotChatMessagePromptTemplate",
|
"FewShotChatMessagePromptTemplate",
|
||||||
@ -94,6 +96,7 @@ _dynamic_imports = {
|
|||||||
"BaseChatPromptTemplate": "chat",
|
"BaseChatPromptTemplate": "chat",
|
||||||
"ChatMessagePromptTemplate": "chat",
|
"ChatMessagePromptTemplate": "chat",
|
||||||
"ChatPromptTemplate": "chat",
|
"ChatPromptTemplate": "chat",
|
||||||
|
"DictPromptTemplate": "dict",
|
||||||
"HumanMessagePromptTemplate": "chat",
|
"HumanMessagePromptTemplate": "chat",
|
||||||
"MessagesPlaceholder": "chat",
|
"MessagesPlaceholder": "chat",
|
||||||
"SystemMessagePromptTemplate": "chat",
|
"SystemMessagePromptTemplate": "chat",
|
||||||
|
@ -37,10 +37,10 @@ from langchain_core.messages import (
|
|||||||
from langchain_core.messages.base import get_msg_title_repr
|
from langchain_core.messages.base import get_msg_title_repr
|
||||||
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.prompts.dict import DictPromptTemplate
|
||||||
from langchain_core.prompts.image import ImagePromptTemplate
|
from langchain_core.prompts.image import ImagePromptTemplate
|
||||||
from langchain_core.prompts.message import (
|
from langchain_core.prompts.message import (
|
||||||
BaseMessagePromptTemplate,
|
BaseMessagePromptTemplate,
|
||||||
_DictMessagePromptTemplate,
|
|
||||||
)
|
)
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import (
|
from langchain_core.prompts.string import (
|
||||||
@ -396,9 +396,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
|
|
||||||
prompt: Union[
|
prompt: Union[
|
||||||
StringPromptTemplate,
|
StringPromptTemplate,
|
||||||
list[
|
list[Union[StringPromptTemplate, ImagePromptTemplate, DictPromptTemplate]],
|
||||||
Union[StringPromptTemplate, ImagePromptTemplate, _DictMessagePromptTemplate]
|
|
||||||
],
|
|
||||||
]
|
]
|
||||||
"""Prompt template."""
|
"""Prompt template."""
|
||||||
additional_kwargs: dict = Field(default_factory=dict)
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
@ -447,7 +445,12 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
prompt = []
|
prompt = []
|
||||||
for tmpl in template:
|
for tmpl in template:
|
||||||
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
|
if (
|
||||||
|
isinstance(tmpl, str)
|
||||||
|
or isinstance(tmpl, dict)
|
||||||
|
and "text" in tmpl
|
||||||
|
and set(tmpl.keys()) <= {"type", "text"}
|
||||||
|
):
|
||||||
if isinstance(tmpl, str):
|
if isinstance(tmpl, str):
|
||||||
text: str = tmpl
|
text: str = tmpl
|
||||||
else:
|
else:
|
||||||
@ -457,7 +460,15 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
text, template_format=template_format
|
text, template_format=template_format
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(tmpl, dict) and "image_url" in tmpl:
|
elif (
|
||||||
|
isinstance(tmpl, dict)
|
||||||
|
and "image_url" in tmpl
|
||||||
|
and set(tmpl.keys())
|
||||||
|
<= {
|
||||||
|
"type",
|
||||||
|
"image_url",
|
||||||
|
}
|
||||||
|
):
|
||||||
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
|
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
|
||||||
input_variables = []
|
input_variables = []
|
||||||
if isinstance(img_template, str):
|
if isinstance(img_template, str):
|
||||||
@ -503,7 +514,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
"format."
|
"format."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
data_template_obj = _DictMessagePromptTemplate(
|
data_template_obj = DictPromptTemplate(
|
||||||
template=cast("dict[str, Any]", tmpl),
|
template=cast("dict[str, Any]", tmpl),
|
||||||
template_format=template_format,
|
template_format=template_format,
|
||||||
)
|
)
|
||||||
@ -592,7 +603,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
elif isinstance(prompt, ImagePromptTemplate):
|
elif isinstance(prompt, ImagePromptTemplate):
|
||||||
formatted = prompt.format(**inputs)
|
formatted = prompt.format(**inputs)
|
||||||
content.append({"type": "image_url", "image_url": formatted})
|
content.append({"type": "image_url", "image_url": formatted})
|
||||||
elif isinstance(prompt, _DictMessagePromptTemplate):
|
elif isinstance(prompt, DictPromptTemplate):
|
||||||
formatted = prompt.format(**inputs)
|
formatted = prompt.format(**inputs)
|
||||||
content.append(formatted)
|
content.append(formatted)
|
||||||
return self._msg_class(
|
return self._msg_class(
|
||||||
@ -624,7 +635,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
elif isinstance(prompt, ImagePromptTemplate):
|
elif isinstance(prompt, ImagePromptTemplate):
|
||||||
formatted = await prompt.aformat(**inputs)
|
formatted = await prompt.aformat(**inputs)
|
||||||
content.append({"type": "image_url", "image_url": formatted})
|
content.append({"type": "image_url", "image_url": formatted})
|
||||||
elif isinstance(prompt, _DictMessagePromptTemplate):
|
elif isinstance(prompt, DictPromptTemplate):
|
||||||
formatted = prompt.format(**inputs)
|
formatted = prompt.format(**inputs)
|
||||||
content.append(formatted)
|
content.append(formatted)
|
||||||
return self._msg_class(
|
return self._msg_class(
|
||||||
|
137
libs/core/langchain_core/prompts/dict.py
Normal file
137
libs/core/langchain_core/prompts/dict.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Dict prompt template."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from langchain_core.load import dumpd
|
||||||
|
from langchain_core.prompts.string import (
|
||||||
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
|
get_template_variables,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||||
|
from langchain_core.runnables.config import ensure_config
|
||||||
|
|
||||||
|
|
||||||
|
class DictPromptTemplate(RunnableSerializable[dict, dict]):
|
||||||
|
"""Template represented by a dict.
|
||||||
|
|
||||||
|
Recognizes variables in f-string or mustache formatted string dict values. Does NOT
|
||||||
|
recognize variables in dict keys. Applies recursively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
template: dict[str, Any]
|
||||||
|
template_format: Literal["f-string", "mustache"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_variables(self) -> list[str]:
|
||||||
|
"""Template input variables."""
|
||||||
|
return _get_input_variables(self.template, self.template_format)
|
||||||
|
|
||||||
|
def format(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
|
"""Format the prompt with the inputs."""
|
||||||
|
return _insert_input_variables(self.template, kwargs, self.template_format)
|
||||||
|
|
||||||
|
async def aformat(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
|
"""Format the prompt with the inputs."""
|
||||||
|
return self.format(**kwargs)
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> dict:
|
||||||
|
"""Invoke the prompt."""
|
||||||
|
return self._call_with_config(
|
||||||
|
lambda x: self.format(**x),
|
||||||
|
input,
|
||||||
|
ensure_config(config),
|
||||||
|
run_type="prompt",
|
||||||
|
serialized=self._serialized,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _prompt_type(self) -> str:
|
||||||
|
return "dict-prompt"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _serialized(self) -> dict[str, Any]:
|
||||||
|
return dumpd(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether or not the class is serializable.
|
||||||
|
|
||||||
|
Returns: True.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
|
"""Serialization namespace."""
|
||||||
|
return ["langchain_core", "prompts", "dict"]
|
||||||
|
|
||||||
|
def pretty_repr(self, *, html: bool = False) -> str:
|
||||||
|
"""Human-readable representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
html: Whether to format as HTML. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Human-readable representation.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _get_input_variables(
|
||||||
|
template: dict, template_format: Literal["f-string", "mustache"]
|
||||||
|
) -> list[str]:
|
||||||
|
input_variables = []
|
||||||
|
for v in template.values():
|
||||||
|
if isinstance(v, str):
|
||||||
|
input_variables += get_template_variables(v, template_format)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
input_variables += _get_input_variables(v, template_format)
|
||||||
|
elif isinstance(v, (list, tuple)):
|
||||||
|
for x in v:
|
||||||
|
if isinstance(x, str):
|
||||||
|
input_variables += get_template_variables(x, template_format)
|
||||||
|
elif isinstance(x, dict):
|
||||||
|
input_variables += _get_input_variables(x, template_format)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return list(set(input_variables))
|
||||||
|
|
||||||
|
|
||||||
|
def _insert_input_variables(
|
||||||
|
template: dict[str, Any],
|
||||||
|
inputs: dict[str, Any],
|
||||||
|
template_format: Literal["f-string", "mustache"],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
formatted = {}
|
||||||
|
formatter = DEFAULT_FORMATTER_MAPPING[template_format]
|
||||||
|
for k, v in template.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
formatted[k] = formatter(v, **inputs)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
if k == "image_url" and "path" in v:
|
||||||
|
msg = (
|
||||||
|
"Specifying image inputs via file path in environments with "
|
||||||
|
"user-input paths is a security vulnerability. Out of an abundance "
|
||||||
|
"of caution, the utility has been removed to prevent possible "
|
||||||
|
"misuse."
|
||||||
|
)
|
||||||
|
warnings.warn(msg, stacklevel=2)
|
||||||
|
formatted[k] = _insert_input_variables(v, inputs, template_format)
|
||||||
|
elif isinstance(v, (list, tuple)):
|
||||||
|
formatted_v = []
|
||||||
|
for x in v:
|
||||||
|
if isinstance(x, str):
|
||||||
|
formatted_v.append(formatter(x, **inputs))
|
||||||
|
elif isinstance(x, dict):
|
||||||
|
formatted_v.append(
|
||||||
|
_insert_input_variables(x, inputs, template_format)
|
||||||
|
)
|
||||||
|
formatted[k] = type(v)(formatted_v)
|
||||||
|
else:
|
||||||
|
formatted[k] = v
|
||||||
|
return formatted
|
@ -3,14 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain_core.load import Serializable
|
from langchain_core.load import Serializable
|
||||||
from langchain_core.messages import BaseMessage, convert_to_messages
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.prompts.string import (
|
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
|
||||||
get_template_variables,
|
|
||||||
)
|
|
||||||
from langchain_core.utils.interactive_env import is_interactive_env
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -98,89 +94,3 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
|||||||
|
|
||||||
prompt = ChatPromptTemplate(messages=[self])
|
prompt = ChatPromptTemplate(messages=[self])
|
||||||
return prompt + other
|
return prompt + other
|
||||||
|
|
||||||
|
|
||||||
class _DictMessagePromptTemplate(BaseMessagePromptTemplate):
|
|
||||||
"""Template represented by a dict that recursively fills input vars in string vals.
|
|
||||||
|
|
||||||
Special handling of image_url dicts to load local paths. These look like:
|
|
||||||
``{"type": "image_url", "image_url": {"path": "..."}}``
|
|
||||||
"""
|
|
||||||
|
|
||||||
template: dict[str, Any]
|
|
||||||
template_format: Literal["f-string", "mustache"]
|
|
||||||
|
|
||||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
|
||||||
msg_dict = _insert_input_variables(self.template, kwargs, self.template_format)
|
|
||||||
return convert_to_messages([msg_dict])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_variables(self) -> list[str]:
|
|
||||||
return _get_input_variables(self.template, self.template_format)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _prompt_type(self) -> str:
|
|
||||||
return "message-dict-prompt"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lc_namespace(cls) -> list[str]:
|
|
||||||
return ["langchain_core", "prompts", "message"]
|
|
||||||
|
|
||||||
def format(
|
|
||||||
self,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Format the prompt with the inputs."""
|
|
||||||
return _insert_input_variables(self.template, kwargs, self.template_format)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_input_variables(
|
|
||||||
template: dict, template_format: Literal["f-string", "mustache"]
|
|
||||||
) -> list[str]:
|
|
||||||
input_variables = []
|
|
||||||
for v in template.values():
|
|
||||||
if isinstance(v, str):
|
|
||||||
input_variables += get_template_variables(v, template_format)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
input_variables += _get_input_variables(v, template_format)
|
|
||||||
elif isinstance(v, (list, tuple)):
|
|
||||||
for x in v:
|
|
||||||
if isinstance(x, str):
|
|
||||||
input_variables += get_template_variables(x, template_format)
|
|
||||||
elif isinstance(x, dict):
|
|
||||||
input_variables += _get_input_variables(x, template_format)
|
|
||||||
return list(set(input_variables))
|
|
||||||
|
|
||||||
|
|
||||||
def _insert_input_variables(
|
|
||||||
template: dict[str, Any],
|
|
||||||
inputs: dict[str, Any],
|
|
||||||
template_format: Literal["f-string", "mustache"],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
formatted = {}
|
|
||||||
formatter = DEFAULT_FORMATTER_MAPPING[template_format]
|
|
||||||
for k, v in template.items():
|
|
||||||
if isinstance(v, str):
|
|
||||||
formatted[k] = formatter(v, **inputs)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
# No longer support loading local images.
|
|
||||||
if k == "image_url" and "path" in v:
|
|
||||||
msg = (
|
|
||||||
"Specifying image inputs via file path in environments with "
|
|
||||||
"user-input paths is a security vulnerability. Out of an abundance "
|
|
||||||
"of caution, the utility has been removed to prevent possible "
|
|
||||||
"misuse."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
formatted[k] = _insert_input_variables(v, inputs, template_format)
|
|
||||||
elif isinstance(v, (list, tuple)):
|
|
||||||
formatted_v = []
|
|
||||||
for x in v:
|
|
||||||
if isinstance(x, str):
|
|
||||||
formatted_v.append(formatter(x, **inputs))
|
|
||||||
elif isinstance(x, dict):
|
|
||||||
formatted_v.append(
|
|
||||||
_insert_input_variables(x, inputs, template_format)
|
|
||||||
)
|
|
||||||
formatted[k] = type(v)(formatted_v)
|
|
||||||
return formatted
|
|
||||||
|
@ -3135,6 +3135,27 @@
|
|||||||
'name': 'PromptTemplate',
|
'name': 'PromptTemplate',
|
||||||
'type': 'constructor',
|
'type': 'constructor',
|
||||||
}),
|
}),
|
||||||
|
dict({
|
||||||
|
'id': list([
|
||||||
|
'langchain_core',
|
||||||
|
'prompts',
|
||||||
|
'dict',
|
||||||
|
'DictPromptTemplate',
|
||||||
|
]),
|
||||||
|
'kwargs': dict({
|
||||||
|
'template': dict({
|
||||||
|
'cache_control': dict({
|
||||||
|
'type': '{foo}',
|
||||||
|
}),
|
||||||
|
'text': "What's in this image?",
|
||||||
|
'type': 'text',
|
||||||
|
}),
|
||||||
|
'template_format': 'f-string',
|
||||||
|
}),
|
||||||
|
'lc': 1,
|
||||||
|
'name': 'DictPromptTemplate',
|
||||||
|
'type': 'constructor',
|
||||||
|
}),
|
||||||
dict({
|
dict({
|
||||||
'id': list([
|
'id': list([
|
||||||
'langchain',
|
'langchain',
|
||||||
|
@ -973,6 +973,11 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
|
|||||||
"hello",
|
"hello",
|
||||||
{"text": "What's in this image?"},
|
{"text": "What's in this image?"},
|
||||||
{"type": "text", "text": "What's in this image?"},
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?",
|
||||||
|
"cache_control": {"type": "{foo}"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": "data:image/jpeg;base64,{my_image}",
|
"image_url": "data:image/jpeg;base64,{my_image}",
|
||||||
@ -1012,7 +1017,7 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
|
|||||||
@pytest.mark.xfail(
|
@pytest.mark.xfail(
|
||||||
reason=(
|
reason=(
|
||||||
"In a breaking release, we can update `_convert_to_message_template` to use "
|
"In a breaking release, we can update `_convert_to_message_template` to use "
|
||||||
"_DictMessagePromptTemplate for all `dict` inputs, allowing for templatization "
|
"DictPromptTemplate for all `dict` inputs, allowing for templatization "
|
||||||
"of message attributes outside content blocks. That would enable the below "
|
"of message attributes outside content blocks. That would enable the below "
|
||||||
"test to pass."
|
"test to pass."
|
||||||
)
|
)
|
||||||
|
34
libs/core/tests/unit_tests/prompts/test_dict.py
Normal file
34
libs/core/tests/unit_tests/prompts/test_dict.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from langchain_core.load import load
|
||||||
|
from langchain_core.prompts.dict import DictPromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
def test__dict_message_prompt_template_fstring() -> None:
|
||||||
|
template = {
|
||||||
|
"type": "text",
|
||||||
|
"text": "{text1}",
|
||||||
|
"cache_control": {"type": "{cache_type}"},
|
||||||
|
}
|
||||||
|
prompt = DictPromptTemplate(template=template, template_format="f-string")
|
||||||
|
expected = {
|
||||||
|
"type": "text",
|
||||||
|
"text": "important message",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
}
|
||||||
|
actual = prompt.format(text1="important message", cache_type="ephemeral")
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_deserialize_legacy() -> None:
|
||||||
|
ser = {
|
||||||
|
"type": "constructor",
|
||||||
|
"lc": 1,
|
||||||
|
"id": ["langchain_core", "prompts", "message", "_DictMessagePromptTemplate"],
|
||||||
|
"kwargs": {
|
||||||
|
"template_format": "f-string",
|
||||||
|
"template": {"type": "audio", "audio": "{audio_data}"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expected = DictPromptTemplate(
|
||||||
|
template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
|
||||||
|
)
|
||||||
|
assert load(ser) == expected
|
@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
|||||||
"BasePromptTemplate",
|
"BasePromptTemplate",
|
||||||
"ChatMessagePromptTemplate",
|
"ChatMessagePromptTemplate",
|
||||||
"ChatPromptTemplate",
|
"ChatPromptTemplate",
|
||||||
|
"DictPromptTemplate",
|
||||||
"FewShotPromptTemplate",
|
"FewShotPromptTemplate",
|
||||||
"FewShotPromptWithTemplates",
|
"FewShotPromptWithTemplates",
|
||||||
"FewShotChatMessagePromptTemplate",
|
"FewShotChatMessagePromptTemplate",
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
||||||
from langchain_core.prompts.message import _DictMessagePromptTemplate
|
|
||||||
|
|
||||||
CUR_DIR = Path(__file__).parent.absolute().resolve()
|
|
||||||
|
|
||||||
|
|
||||||
def test__dict_message_prompt_template_fstring() -> None:
|
|
||||||
template = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "{text1}", "cache_control": {"type": "ephemeral"}},
|
|
||||||
],
|
|
||||||
"name": "{name1}",
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"name": "{tool_name1}",
|
|
||||||
"args": {"arg1": "{tool_arg1}"},
|
|
||||||
"id": "1",
|
|
||||||
"type": "tool_call",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
|
|
||||||
expected: BaseMessage = AIMessage(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "important message",
|
|
||||||
"cache_control": {"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
name="foo",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"name": "do_stuff",
|
|
||||||
"args": {"arg1": "important arg1"},
|
|
||||||
"id": "1",
|
|
||||||
"type": "tool_call",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
actual = prompt.format_messages(
|
|
||||||
text1="important message",
|
|
||||||
name1="foo",
|
|
||||||
tool_arg1="important arg1",
|
|
||||||
tool_name1="do_stuff",
|
|
||||||
)[0]
|
|
||||||
assert actual == expected
|
|
||||||
|
|
||||||
template = {
|
|
||||||
"role": "tool",
|
|
||||||
"content": "{content1}",
|
|
||||||
"tool_call_id": "1",
|
|
||||||
"name": "{name1}",
|
|
||||||
}
|
|
||||||
prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
|
|
||||||
expected = ToolMessage("foo", name="bar", tool_call_id="1")
|
|
||||||
actual = prompt.format_messages(content1="foo", name1="bar")[0]
|
|
||||||
assert actual == expected
|
|
Loading…
Reference in New Issue
Block a user