BUGFIX: handle tool message type when converting to string (#13626)

**Description:** Currently, if we pass in a ToolMessage back to the
chain, it crashes with error

`Got unsupported message type: `

This fixes it. 

Tested locally

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
tanujtiwari-at
2023-11-21 18:20:58 -08:00
committed by GitHub
parent 143049c90f
commit 5064890fcf
3 changed files with 86 additions and 64 deletions

View File

@@ -54,6 +54,8 @@ def get_buffer_string(
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ToolMessage):
role = "Tool"
elif isinstance(m, ChatMessage):
role = m.role
else:

View File

@@ -1,10 +1,21 @@
import unittest
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessage,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
@@ -100,3 +111,76 @@ def test_ani_message_chunks() -> None:
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed."
)
class TestGetBufferString(unittest.TestCase):
def setUp(self) -> None:
self.human_msg = HumanMessage(content="human")
self.ai_msg = AIMessage(content="ai")
self.sys_msg = SystemMessage(content="system")
self.func_msg = FunctionMessage(name="func", content="function")
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
self.chat_msg = ChatMessage(role="Chat", content="chat")
def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "")
def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg]),
expected_output,
)
def test_custom_human_prefix(self) -> None:
prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}"
self.assertEqual(
get_buffer_string([self.human_msg], human_prefix="H"),
expected_output,
)
def test_custom_ai_prefix(self) -> None:
prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}"
self.assertEqual(
get_buffer_string([self.ai_msg], ai_prefix="A"),
expected_output,
)
def test_multiple_msg(self) -> None:
msgs = [
self.human_msg,
self.ai_msg,
self.sys_msg,
self.func_msg,
self.tool_msg,
self.chat_msg,
]
expected_output = "\n".join(
[
"Human: human",
"AI: ai",
"System: system",
"Function: function",
"Tool: tool",
"Chat: chat",
]
)
self.assertEqual(
get_buffer_string(msgs),
expected_output,
)
def test_multiple_msg() -> None:
human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"})
ai_msg = AIMessage(content="ai")
sys_msg = SystemMessage(content="sys")
msgs = [
human_msg,
ai_msg,
sys_msg,
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs