core[minor]: update conversion utils to handle RemoveMessage (#23840)

This commit is contained in:
Vadym Barda 2024-07-03 16:13:31 -04:00 committed by GitHub
parent 4ab78572e7
commit 9bb623381b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 1 deletions

View File

@ -32,6 +32,7 @@ from langchain_core.messages.base import BaseMessage, BaseMessageChunk
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolMessage, ToolMessageChunk from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
@ -42,7 +43,12 @@ if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
AnyMessage = Union[ AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage AIMessage,
HumanMessage,
ChatMessage,
SystemMessage,
FunctionMessage,
ToolMessage,
] ]
@ -113,6 +119,8 @@ def _message_from_dict(message: dict) -> BaseMessage:
return FunctionMessage(**message["data"]) return FunctionMessage(**message["data"])
elif _type == "tool": elif _type == "tool":
return ToolMessage(**message["data"]) return ToolMessage(**message["data"])
elif _type == "remove":
return RemoveMessage(**message["data"])
elif _type == "AIMessageChunk": elif _type == "AIMessageChunk":
return AIMessageChunk(**message["data"]) return AIMessageChunk(**message["data"])
elif _type == "HumanMessageChunk": elif _type == "HumanMessageChunk":
@ -214,6 +222,8 @@ def _create_message_from_message_type(
message = FunctionMessage(content=content, **kwargs) message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool": elif message_type == "tool":
message = ToolMessage(content=content, **kwargs) message = ToolMessage(content=content, **kwargs)
elif message_type == "remove":
message = RemoveMessage(**kwargs)
else: else:
raise ValueError( raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human'," f"Unexpected message type: {message_type}. Use one of 'human',"

View File

@ -12,6 +12,7 @@ from langchain_core.messages import (
FunctionMessageChunk, FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
RemoveMessage,
SystemMessage, SystemMessage,
ToolCall, ToolCall,
ToolCallChunk, ToolCallChunk,
@ -649,6 +650,7 @@ def test_convert_to_messages() -> None:
], ],
}, },
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"}, {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
{"role": "remove", "id": "message_to_remove", "content": ""},
] ]
) == [ ) == [
SystemMessage(content="You are a helpful assistant."), SystemMessage(content="You are a helpful assistant."),
@ -668,6 +670,7 @@ def test_convert_to_messages() -> None:
tool_calls=[ToolCall(name="greet", args={"name": "Jane"}, id="tool_id")], tool_calls=[ToolCall(name="greet", args={"name": "Jane"}, id="tool_id")],
), ),
ToolMessage(tool_call_id="tool_id", content="Hi!"), ToolMessage(tool_call_id="tool_id", content="Hi!"),
RemoveMessage(id="message_to_remove"),
] ]
# tuples # tuples