diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index 7ac3c7d3e38..ff9834df27b 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -115,6 +115,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { "chat", "SystemMessagePromptTemplate", ), + ("langchain", "prompts", "image", "ImagePromptTemplate"): ( + "langchain_core", + "prompts", + "image", + "ImagePromptTemplate", + ), ("langchain", "schema", "agent", "AgentActionMessageLog"): ( "langchain_core", "agents", @@ -510,6 +516,12 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { "system", "SystemMessage", ), + ("langchain", "schema", "prompt_template", "ImagePromptTemplate"): ( + "langchain_core", + "prompts", + "image", + "ImagePromptTemplate", + ), } # Needed for backwards compatibility for a few versions where we serialized diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index 3a3b16117e4..d320690924a 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate @@ -30,6 +30,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]): """Return the prompt type key.""" return "image-prompt" + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "prompts", "image"] + def format_prompt(self, **kwargs: Any) -> PromptValue: """Create Chat Messages.""" return ImagePromptValue(image_url=self.format(**kwargs)) diff --git a/libs/core/tests/unit_tests/prompts/test_image.py b/libs/core/tests/unit_tests/prompts/test_image.py new file mode 100644 index 00000000000..746c099b803 --- /dev/null +++ b/libs/core/tests/unit_tests/prompts/test_image.py @@ -0,0 +1,109 @@ +import json + +from langchain_core.load import dump, loads +from langchain_core.prompts import ChatPromptTemplate + + +def test_image_prompt_template_deserializable() -> None: + """Test that the image prompt template is serializable.""" + loads( + dump.dumps( + ChatPromptTemplate.from_messages( + [("system", [{"type": "image", "image_url": "{img}"}])] + ) + ) + ) + + +def test_image_prompt_template_deserializable_old() -> None: + """Test that the image prompt template is serializable.""" + loads( + json.dumps( + { + "lc": 1, + "type": "constructor", + "id": ["langchain", "prompts", "chat", "ChatPromptTemplate"], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate", + ], + "kwargs": { + "prompt": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate", + ], + "kwargs": { + "template": "Foo", + "input_variables": [], + "template_format": "f-string", + "partial_variables": {}, + }, + } + ] + }, + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate", + ], + "kwargs": { + "prompt": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "image", + "ImagePromptTemplate", + ], + "kwargs": { + "template": { + "url": "data:image/png;base64,{img}" + }, + "input_variables": ["img"], + }, + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate", + ], + "kwargs": { + "template": "{input}", + "input_variables": ["input"], + "template_format": "f-string", + "partial_variables": {}, + }, + }, + ] + }, + }, + ], + "input_variables": ["img", "input"], + }, + } + ) + ) diff --git a/libs/langchain/tests/unit_tests/load/test_serializable.py b/libs/langchain/tests/unit_tests/load/test_serializable.py index 198d8f1d465..64d762035f0 100644 --- a/libs/langchain/tests/unit_tests/load/test_serializable.py +++ b/libs/langchain/tests/unit_tests/load/test_serializable.py @@ -40,8 +40,21 @@ def import_all_modules(package_name: str) -> dict: def test_serializable_mapping() -> None: + # This should have had a different namespace, as it was never + # exported from the langchain module, but we keep for whoever has + # already serialized it. + to_skip = { + ("langchain", "prompts", "image", "ImagePromptTemplate"): ( + "langchain_core", + "prompts", + "image", + "ImagePromptTemplate", + ), + } serializable_modules = import_all_modules("langchain") - missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules) + missing = set(SERIALIZABLE_MAPPING).difference( + set(serializable_modules).union(to_skip) + ) assert missing == set() extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING) assert extra == set()