diff --git a/libs/community/tests/unit_tests/load/test_serializable.py b/libs/community/tests/unit_tests/load/test_serializable.py index 613e3f01d8a..7ca4d27da91 100644 --- a/libs/community/tests/unit_tests/load/test_serializable.py +++ b/libs/community/tests/unit_tests/load/test_serializable.py @@ -95,6 +95,13 @@ def test_serializable_mapping() -> None: "structured", "StructuredPrompt", ), + # This is not exported from langchain, only langchain_core + ("langchain", "schema", "messages", "RemoveMessage"): ( + "langchain_core", + "messages", + "modifier", + "RemoveMessage", + ), } serializable_modules = import_all_modules("langchain") diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index 01e070895e7..69f0371289a 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -76,6 +76,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { "tool", "ToolMessage", ), + ("langchain", "schema", "messages", "RemoveMessage"): ( + "langchain_core", + "messages", + "modifier", + "RemoveMessage", + ), ("langchain", "schema", "agent", "AgentAction"): ( "langchain_core", "agents", diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index 0d6d7e7d97b..4a28feedd8f 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -29,6 +29,7 @@ from langchain_core.messages.base import ( from langchain_core.messages.chat import ChatMessage, ChatMessageChunk from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk from langchain_core.messages.human import HumanMessage, HumanMessageChunk +from langchain_core.messages.modifier import RemoveMessage from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.tool import ( InvalidToolCall, @@ -70,6 +71,7 @@ __all__ = [ "ToolCallChunk", "ToolMessage", "ToolMessageChunk", + "RemoveMessage", "_message_from_dict", "convert_to_messages", "get_buffer_string", diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py new file mode 100644 index 00000000000..d68b4751eb2 --- /dev/null +++ b/libs/core/langchain_core/messages/modifier.py @@ -0,0 +1,23 @@ +from typing import Any, List, Literal + +from langchain_core.messages.base import BaseMessage + + +class RemoveMessage(BaseMessage): + """Message responsible for deleting other messages.""" + + type: Literal["remove"] = "remove" + + def __init__(self, id: str, **kwargs: Any) -> None: + if kwargs.pop("content", None): + raise ValueError("RemoveMessage does not support 'content' field.") + + return super().__init__("", id=id, **kwargs) + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "messages"] + + +RemoveMessage.update_forward_refs() diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py index e866e71ab24..531409a4261 100644 --- a/libs/core/tests/unit_tests/messages/test_imports.py +++ b/libs/core/tests/unit_tests/messages/test_imports.py @@ -21,6 +21,7 @@ EXPECTED_ALL = [ "ToolCallChunk", "ToolMessage", "ToolMessageChunk", + "RemoveMessage", "convert_to_messages", "get_buffer_string", "merge_content",