diff --git a/libs/community/tests/unit_tests/load/test_serializable.py b/libs/community/tests/unit_tests/load/test_serializable.py
deleted file mode 100644
index 2d436fc53be..00000000000
--- a/libs/community/tests/unit_tests/load/test_serializable.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import importlib
-import inspect
-import pkgutil
-from types import ModuleType
-
-from langchain_core.load.mapping import SERIALIZABLE_MAPPING
-
-
-def import_all_modules(package_name: str) -> dict:
-    package = importlib.import_module(package_name)
-    classes: dict = {}
-
-    def _handle_module(module: ModuleType) -> None:
-        # Iterate over all members of the module
-
-        names = dir(module)
-
-        if hasattr(module, "__all__"):
-            names += list(module.__all__)
-
-        names = sorted(set(names))
-
-        for name in names:
-            # Check if it's a class or function
-            attr = getattr(module, name)
-
-            if not inspect.isclass(attr):
-                continue
-
-            if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
-                continue
-
-            if (
-                isinstance(attr.is_lc_serializable(), bool)
-                and attr.is_lc_serializable()
-            ):
-                key = tuple(attr.lc_id())
-                value = tuple(attr.__module__.split(".") + [attr.__name__])
-                if key in classes and classes[key] != value:
-                    raise ValueError
-                classes[key] = value
-
-    _handle_module(package)
-
-    for importer, modname, ispkg in pkgutil.walk_packages(
-        package.__path__, package.__name__ + "."
-    ):
-        try:
-            module = importlib.import_module(modname)
-        except ModuleNotFoundError:
-            continue
-        _handle_module(module)
-
-    return classes
-
-
-def test_import_all_modules() -> None:
-    """Test import all modules works as expected"""
-    all_modules = import_all_modules("langchain")
-    filtered_modules = [
-        k
-        for k in all_modules
-        if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
-    ]
-    # This test will need to be updated if new serializable classes are added
-    # to community
-    assert sorted(filtered_modules) == sorted(
-        [
-            ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
-            ("langchain", "chat_models", "bedrock", "BedrockChat"),
-            ("langchain", "chat_models", "anthropic", "ChatAnthropic"),
-            ("langchain", "chat_models", "fireworks", "ChatFireworks"),
-            ("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
-            ("langchain", "chat_models", "openai", "ChatOpenAI"),
-            ("langchain", "chat_models", "vertexai", "ChatVertexAI"),
-        ]
-    )
-
-
-def test_serializable_mapping() -> None:
-    to_skip = {
-        # This should have had a different namespace, as it was never
-        # exported from the langchain module, but we keep for whoever has
-        # already serialized it.
- ("langchain", "prompts", "image", "ImagePromptTemplate"): ( - "langchain_core", - "prompts", - "image", - "ImagePromptTemplate", - ), - # This is not exported from langchain, only langchain_core - ("langchain_core", "prompts", "structured", "StructuredPrompt"): ( - "langchain_core", - "prompts", - "structured", - "StructuredPrompt", - ), - # This is not exported from langchain, only langchain_core - ("langchain", "schema", "messages", "RemoveMessage"): ( - "langchain_core", - "messages", - "modifier", - "RemoveMessage", - ), - ("langchain", "chat_models", "mistralai", "ChatMistralAI"): ( - "langchain_mistralai", - "chat_models", - "ChatMistralAI", - ), - ("langchain_groq", "chat_models", "ChatGroq"): ( - "langchain_groq", - "chat_models", - "ChatGroq", - ), - ("langchain_sambanova", "chat_models", "ChatSambaNovaCloud"): ( - "langchain_sambanova", - "chat_models", - "ChatSambaNovaCloud", - ), - ("langchain_sambanova", "chat_models", "ChatSambaStudio"): ( - "langchain_sambanova", - "chat_models", - "ChatSambaStudio", - ), - # TODO(0.3): For now we're skipping the below two tests. Need to fix - # so that it only runs when langchain-aws, langchain-google-genai - # are installed. - ("langchain", "chat_models", "bedrock", "ChatBedrock"): ( - "langchain_aws", - "chat_models", - "bedrock", - "ChatBedrock", - ), - ("langchain_google_genai", "chat_models", "ChatGoogleGenerativeAI"): ( - "langchain_google_genai", - "chat_models", - "ChatGoogleGenerativeAI", - ), - } - serializable_modules = import_all_modules("langchain") - - missing = set(SERIALIZABLE_MAPPING).difference( - set(serializable_modules).union(to_skip) - ) - assert missing == set() - extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING) - assert extra == set() - - for k, import_path in serializable_modules.items(): - import_dir, import_obj = import_path[:-1], import_path[-1] - # Import module - mod = importlib.import_module(".".join(import_dir)) - # Import class - cls = getattr(mod, import_obj) - assert list(k) == cls.lc_id() diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index b4201c18b30..aee837b539d 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -540,6 +540,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { "chat_models", "ChatSambaStudio", ), + ("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): ( + "langchain_core", + "prompts", + "dict", + "DictPromptTemplate", + ), } # Needed for backwards compatibility for old versions of LangChain where things diff --git a/libs/core/langchain_core/prompts/__init__.py b/libs/core/langchain_core/prompts/__init__.py index 706b3b6f8fd..61f1ceb7f9f 100644 --- a/libs/core/langchain_core/prompts/__init__.py +++ b/libs/core/langchain_core/prompts/__init__.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: MessagesPlaceholder, SystemMessagePromptTemplate, ) + from langchain_core.prompts.dict import DictPromptTemplate from langchain_core.prompts.few_shot import ( FewShotChatMessagePromptTemplate, FewShotPromptTemplate, @@ -68,6 +69,7 @@ __all__ = ( "BasePromptTemplate", "ChatMessagePromptTemplate", "ChatPromptTemplate", + "DictPromptTemplate", "FewShotPromptTemplate", "FewShotPromptWithTemplates", "FewShotChatMessagePromptTemplate", @@ -94,6 +96,7 @@ _dynamic_imports = { "BaseChatPromptTemplate": "chat", "ChatMessagePromptTemplate": "chat", "ChatPromptTemplate": "chat", + "DictPromptTemplate": "dict", "HumanMessagePromptTemplate": "chat", 
"MessagesPlaceholder": "chat", "SystemMessagePromptTemplate": "chat", diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 9ed0de4beff..b80954336c5 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -37,10 +37,10 @@ from langchain_core.messages import ( from langchain_core.messages.base import get_msg_title_repr from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.dict import DictPromptTemplate from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.message import ( BaseMessagePromptTemplate, - _DictMessagePromptTemplate, ) from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import ( @@ -396,9 +396,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): prompt: Union[ StringPromptTemplate, - list[ - Union[StringPromptTemplate, ImagePromptTemplate, _DictMessagePromptTemplate] - ], + list[Union[StringPromptTemplate, ImagePromptTemplate, DictPromptTemplate]], ] """Prompt template.""" additional_kwargs: dict = Field(default_factory=dict) @@ -447,7 +445,12 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): raise ValueError(msg) prompt = [] for tmpl in template: - if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if ( + isinstance(tmpl, str) + or isinstance(tmpl, dict) + and "text" in tmpl + and set(tmpl.keys()) <= {"type", "text"} + ): if isinstance(tmpl, str): text: str = tmpl else: @@ -457,7 +460,15 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): text, template_format=template_format ) ) - elif isinstance(tmpl, dict) and "image_url" in tmpl: + elif ( + isinstance(tmpl, dict) + and "image_url" in tmpl + and set(tmpl.keys()) + <= { + "type", + "image_url", + } + ): img_template = cast("_ImageTemplateParam", tmpl)["image_url"] input_variables = [] if isinstance(img_template, str): @@ -503,7 +514,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): "format." 
                     )
                     raise ValueError(msg)
-                data_template_obj = _DictMessagePromptTemplate(
+                data_template_obj = DictPromptTemplate(
                     template=cast("dict[str, Any]", tmpl),
                     template_format=template_format,
                 )
@@ -592,7 +603,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
             elif isinstance(prompt, ImagePromptTemplate):
                 formatted = prompt.format(**inputs)
                 content.append({"type": "image_url", "image_url": formatted})
-            elif isinstance(prompt, _DictMessagePromptTemplate):
+            elif isinstance(prompt, DictPromptTemplate):
                 formatted = prompt.format(**inputs)
                 content.append(formatted)
         return self._msg_class(
@@ -624,7 +635,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
             elif isinstance(prompt, ImagePromptTemplate):
                 formatted = await prompt.aformat(**inputs)
                 content.append({"type": "image_url", "image_url": formatted})
-            elif isinstance(prompt, _DictMessagePromptTemplate):
+            elif isinstance(prompt, DictPromptTemplate):
                 formatted = prompt.format(**inputs)
                 content.append(formatted)
         return self._msg_class(
diff --git a/libs/core/langchain_core/prompts/dict.py b/libs/core/langchain_core/prompts/dict.py
new file mode 100644
index 00000000000..0ccdf7a64a0
--- /dev/null
+++ b/libs/core/langchain_core/prompts/dict.py
@@ -0,0 +1,137 @@
+"""Dict prompt template."""
+
+import warnings
+from functools import cached_property
+from typing import Any, Literal, Optional
+
+from langchain_core.load import dumpd
+from langchain_core.prompts.string import (
+    DEFAULT_FORMATTER_MAPPING,
+    get_template_variables,
+)
+from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables.config import ensure_config
+
+
+class DictPromptTemplate(RunnableSerializable[dict, dict]):
+    """Template represented by a dict.
+
+    Recognizes variables in f-string or mustache formatted string dict values. Does NOT
+    recognize variables in dict keys. Applies recursively.
+    """
+
+    template: dict[str, Any]
+    template_format: Literal["f-string", "mustache"]
+
+    @property
+    def input_variables(self) -> list[str]:
+        """Template input variables."""
+        return _get_input_variables(self.template, self.template_format)
+
+    def format(self, **kwargs: Any) -> dict[str, Any]:
+        """Format the prompt with the inputs."""
+        return _insert_input_variables(self.template, kwargs, self.template_format)
+
+    async def aformat(self, **kwargs: Any) -> dict[str, Any]:
+        """Format the prompt with the inputs."""
+        return self.format(**kwargs)
+
+    def invoke(
+        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
+    ) -> dict:
+        """Invoke the prompt."""
+        return self._call_with_config(
+            lambda x: self.format(**x),
+            input,
+            ensure_config(config),
+            run_type="prompt",
+            serialized=self._serialized,
+            **kwargs,
+        )
+
+    @property
+    def _prompt_type(self) -> str:
+        return "dict-prompt"
+
+    @cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether or not the class is serializable.
+
+        Returns: True.
+        """
+        return True
+
+    @classmethod
+    def get_lc_namespace(cls) -> list[str]:
+        """Serialization namespace."""
+        return ["langchain_core", "prompts", "dict"]
+
+    def pretty_repr(self, *, html: bool = False) -> str:
+        """Human-readable representation.
+
+        Args:
+            html: Whether to format as HTML. Defaults to False.
+
+        Returns:
+            Human-readable representation.
+ """ + raise NotImplementedError + + +def _get_input_variables( + template: dict, template_format: Literal["f-string", "mustache"] +) -> list[str]: + input_variables = [] + for v in template.values(): + if isinstance(v, str): + input_variables += get_template_variables(v, template_format) + elif isinstance(v, dict): + input_variables += _get_input_variables(v, template_format) + elif isinstance(v, (list, tuple)): + for x in v: + if isinstance(x, str): + input_variables += get_template_variables(x, template_format) + elif isinstance(x, dict): + input_variables += _get_input_variables(x, template_format) + else: + pass + return list(set(input_variables)) + + +def _insert_input_variables( + template: dict[str, Any], + inputs: dict[str, Any], + template_format: Literal["f-string", "mustache"], +) -> dict[str, Any]: + formatted = {} + formatter = DEFAULT_FORMATTER_MAPPING[template_format] + for k, v in template.items(): + if isinstance(v, str): + formatted[k] = formatter(v, **inputs) + elif isinstance(v, dict): + if k == "image_url" and "path" in v: + msg = ( + "Specifying image inputs via file path in environments with " + "user-input paths is a security vulnerability. Out of an abundance " + "of caution, the utility has been removed to prevent possible " + "misuse." + ) + warnings.warn(msg, stacklevel=2) + formatted[k] = _insert_input_variables(v, inputs, template_format) + elif isinstance(v, (list, tuple)): + formatted_v = [] + for x in v: + if isinstance(x, str): + formatted_v.append(formatter(x, **inputs)) + elif isinstance(x, dict): + formatted_v.append( + _insert_input_variables(x, inputs, template_format) + ) + formatted[k] = type(v)(formatted_v) + else: + formatted[k] = v + return formatted diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py index 523a302e1c8..668374a19f0 100644 --- a/libs/core/langchain_core/prompts/message.py +++ b/libs/core/langchain_core/prompts/message.py @@ -3,14 +3,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from langchain_core.load import Serializable -from langchain_core.messages import BaseMessage, convert_to_messages -from langchain_core.prompts.string import ( - DEFAULT_FORMATTER_MAPPING, - get_template_variables, -) +from langchain_core.messages import BaseMessage from langchain_core.utils.interactive_env import is_interactive_env if TYPE_CHECKING: @@ -98,89 +94,3 @@ class BaseMessagePromptTemplate(Serializable, ABC): prompt = ChatPromptTemplate(messages=[self]) return prompt + other - - -class _DictMessagePromptTemplate(BaseMessagePromptTemplate): - """Template represented by a dict that recursively fills input vars in string vals. - - Special handling of image_url dicts to load local paths. 
-    ``{"type": "image_url", "image_url": {"path": "..."}}``
-    """
-
-    template: dict[str, Any]
-    template_format: Literal["f-string", "mustache"]
-
-    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
-        msg_dict = _insert_input_variables(self.template, kwargs, self.template_format)
-        return convert_to_messages([msg_dict])
-
-    @property
-    def input_variables(self) -> list[str]:
-        return _get_input_variables(self.template, self.template_format)
-
-    @property
-    def _prompt_type(self) -> str:
-        return "message-dict-prompt"
-
-    @classmethod
-    def get_lc_namespace(cls) -> list[str]:
-        return ["langchain_core", "prompts", "message"]
-
-    def format(
-        self,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
-        """Format the prompt with the inputs."""
-        return _insert_input_variables(self.template, kwargs, self.template_format)
-
-
-def _get_input_variables(
-    template: dict, template_format: Literal["f-string", "mustache"]
-) -> list[str]:
-    input_variables = []
-    for v in template.values():
-        if isinstance(v, str):
-            input_variables += get_template_variables(v, template_format)
-        elif isinstance(v, dict):
-            input_variables += _get_input_variables(v, template_format)
-        elif isinstance(v, (list, tuple)):
-            for x in v:
-                if isinstance(x, str):
-                    input_variables += get_template_variables(x, template_format)
-                elif isinstance(x, dict):
-                    input_variables += _get_input_variables(x, template_format)
-    return list(set(input_variables))
-
-
-def _insert_input_variables(
-    template: dict[str, Any],
-    inputs: dict[str, Any],
-    template_format: Literal["f-string", "mustache"],
-) -> dict[str, Any]:
-    formatted = {}
-    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
-    for k, v in template.items():
-        if isinstance(v, str):
-            formatted[k] = formatter(v, **inputs)
-        elif isinstance(v, dict):
-            # No longer support loading local images.
-            if k == "image_url" and "path" in v:
-                msg = (
-                    "Specifying image inputs via file path in environments with "
-                    "user-input paths is a security vulnerability. Out of an abundance "
-                    "of caution, the utility has been removed to prevent possible "
-                    "misuse."
-                )
-                raise ValueError(msg)
-            formatted[k] = _insert_input_variables(v, inputs, template_format)
-        elif isinstance(v, (list, tuple)):
-            formatted_v = []
-            for x in v:
-                if isinstance(x, str):
-                    formatted_v.append(formatter(x, **inputs))
-                elif isinstance(x, dict):
-                    formatted_v.append(
-                        _insert_input_variables(x, inputs, template_format)
-                    )
-            formatted[k] = type(v)(formatted_v)
-    return formatted
diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
index 7a24cf9dc65..23bce1d65fb 100644
--- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
+++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
@@ -3135,6 +3135,27 @@
         'name': 'PromptTemplate',
         'type': 'constructor',
       }),
+      dict({
+        'id': list([
+          'langchain_core',
+          'prompts',
+          'dict',
+          'DictPromptTemplate',
+        ]),
+        'kwargs': dict({
+          'template': dict({
+            'cache_control': dict({
+              'type': '{foo}',
+            }),
+            'text': "What's in this image?",
+            'type': 'text',
+          }),
+          'template_format': 'f-string',
+        }),
+        'lc': 1,
+        'name': 'DictPromptTemplate',
+        'type': 'constructor',
+      }),
       dict({
         'id': list([
           'langchain',
diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py
index b63c05a0e14..8730ee1bb69 100644
--- a/libs/core/tests/unit_tests/prompts/test_chat.py
+++ b/libs/core/tests/unit_tests/prompts/test_chat.py
@@ -973,6 +973,11 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
                     "hello",
                     {"text": "What's in this image?"},
                     {"type": "text", "text": "What's in this image?"},
+                    {
+                        "type": "text",
+                        "text": "What's in this image?",
+                        "cache_control": {"type": "{foo}"},
+                    },
                     {
                         "type": "image_url",
                         "image_url": "data:image/jpeg;base64,{my_image}",
@@ -1012,7 +1017,7 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
 @pytest.mark.xfail(
     reason=(
         "In a breaking release, we can update `_convert_to_message_template` to use "
-        "_DictMessagePromptTemplate for all `dict` inputs, allowing for templatization "
+        "DictPromptTemplate for all `dict` inputs, allowing for templatization "
         "of message attributes outside content blocks. That would enable the below "
        "test to pass."
     )
diff --git a/libs/core/tests/unit_tests/prompts/test_dict.py b/libs/core/tests/unit_tests/prompts/test_dict.py
new file mode 100644
index 00000000000..581e418b6b5
--- /dev/null
+++ b/libs/core/tests/unit_tests/prompts/test_dict.py
@@ -0,0 +1,34 @@
+from langchain_core.load import load
+from langchain_core.prompts.dict import DictPromptTemplate
+
+
+def test__dict_message_prompt_template_fstring() -> None:
+    template = {
+        "type": "text",
+        "text": "{text1}",
+        "cache_control": {"type": "{cache_type}"},
+    }
+    prompt = DictPromptTemplate(template=template, template_format="f-string")
+    expected = {
+        "type": "text",
+        "text": "important message",
+        "cache_control": {"type": "ephemeral"},
+    }
+    actual = prompt.format(text1="important message", cache_type="ephemeral")
+    assert actual == expected
+
+
+def test_deserialize_legacy() -> None:
+    ser = {
+        "type": "constructor",
+        "lc": 1,
+        "id": ["langchain_core", "prompts", "message", "_DictMessagePromptTemplate"],
+        "kwargs": {
+            "template_format": "f-string",
+            "template": {"type": "audio", "audio": "{audio_data}"},
+        },
+    }
+    expected = DictPromptTemplate(
+        template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
+    )
+    assert load(ser) == expected
diff --git a/libs/core/tests/unit_tests/prompts/test_imports.py b/libs/core/tests/unit_tests/prompts/test_imports.py
index a3a43f8957b..be33b06338f 100644
--- a/libs/core/tests/unit_tests/prompts/test_imports.py
+++ b/libs/core/tests/unit_tests/prompts/test_imports.py
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
     "BasePromptTemplate",
     "ChatMessagePromptTemplate",
     "ChatPromptTemplate",
+    "DictPromptTemplate",
     "FewShotPromptTemplate",
     "FewShotPromptWithTemplates",
     "FewShotChatMessagePromptTemplate",
diff --git a/libs/core/tests/unit_tests/prompts/test_message.py b/libs/core/tests/unit_tests/prompts/test_message.py
deleted file mode 100644
index 2479e31fe33..00000000000
--- a/libs/core/tests/unit_tests/prompts/test_message.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from pathlib import Path
-
-from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
-from langchain_core.prompts.message import _DictMessagePromptTemplate
-
-CUR_DIR = Path(__file__).parent.absolute().resolve()
-
-
-def test__dict_message_prompt_template_fstring() -> None:
-    template = {
-        "role": "assistant",
-        "content": [
-            {"type": "text", "text": "{text1}", "cache_control": {"type": "ephemeral"}},
-        ],
-        "name": "{name1}",
-        "tool_calls": [
-            {
-                "name": "{tool_name1}",
-                "args": {"arg1": "{tool_arg1}"},
-                "id": "1",
-                "type": "tool_call",
-            }
-        ],
-    }
-    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
-    expected: BaseMessage = AIMessage(
-        [
-            {
-                "type": "text",
-                "text": "important message",
-                "cache_control": {"type": "ephemeral"},
-            },
-        ],
-        name="foo",
-        tool_calls=[
-            {
-                "name": "do_stuff",
-                "args": {"arg1": "important arg1"},
-                "id": "1",
-                "type": "tool_call",
-            }
-        ],
-    )
-    actual = prompt.format_messages(
-        text1="important message",
-        name1="foo",
-        tool_arg1="important arg1",
-        tool_name1="do_stuff",
-    )[0]
-    assert actual == expected
-
-    template = {
-        "role": "tool",
-        "content": "{content1}",
-        "tool_call_id": "1",
-        "name": "{name1}",
-    }
-    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
-    expected = ToolMessage("foo", name="bar", tool_call_id="1")
-    actual = prompt.format_messages(content1="foo", name1="bar")[0]
-    assert actual == expected
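
Reviewer note (not part of the patch): a minimal usage sketch of the now-public DictPromptTemplate, assuming langchain-core with this change installed; the variable names (text, cache_type, foo) are illustrative only. It exercises the behavior covered by the tests above: recursive substitution into string values of (nested) dicts, Runnable invocation, and the routing of content blocks with extra keys to DictPromptTemplate inside ChatPromptTemplate.

```python
from langchain_core.prompts import ChatPromptTemplate, DictPromptTemplate

# Standalone: variables are recognized in (nested) string values, never in keys.
prompt = DictPromptTemplate(
    template={
        "type": "text",
        "text": "{text}",
        "cache_control": {"type": "{cache_type}"},
    },
    template_format="f-string",
)
assert set(prompt.input_variables) == {"text", "cache_type"}

formatted = prompt.format(text="What's in this image?", cache_type="ephemeral")
assert formatted == {
    "type": "text",
    "text": "What's in this image?",
    "cache_control": {"type": "ephemeral"},
}

# As a Runnable it can sit in a chain and emits runs of type "prompt".
assert prompt.invoke({"text": "hi", "cache_type": "ephemeral"})["text"] == "hi"

# Inside a chat template: a text block whose keys go beyond {"type", "text"}
# (here, "cache_control") is wrapped in DictPromptTemplate rather than the
# former private _DictMessagePromptTemplate, and serializes under the new
# ("langchain_core", "prompts", "dict") namespace shown in the snapshot above.
chat = ChatPromptTemplate.from_messages(
    [
        (
            "human",
            [
                {
                    "type": "text",
                    "text": "What's in this image?",
                    "cache_control": {"type": "{foo}"},
                },
            ],
        ),
    ]
)
message = chat.invoke({"foo": "ephemeral"}).messages[0]
assert message.content[0]["cache_control"] == {"type": "ephemeral"}
```

Design note: since the class is now a RunnableSerializable[dict, dict] rather than a BaseMessagePromptTemplate, format() returns a plain dict (a content block or other dict payload), not a message; the legacy format_messages behavior of _DictMessagePromptTemplate is not carried over, while old serialized payloads still load via the SERIALIZABLE_MAPPING entry added in this diff.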