mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
BUGFIX: handle tool message type when converting to string (#13626)
**Description:** Currently, if we pass in a ToolMessage back to the chain, it crashes with error `Got unsupported message type: ` This fixes it. Tested locally --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
143049c90f
commit
5064890fcf
@ -54,6 +54,8 @@ def get_buffer_string(
|
|||||||
role = "System"
|
role = "System"
|
||||||
elif isinstance(m, FunctionMessage):
|
elif isinstance(m, FunctionMessage):
|
||||||
role = "Function"
|
role = "Function"
|
||||||
|
elif isinstance(m, ToolMessage):
|
||||||
|
role = "Tool"
|
||||||
elif isinstance(m, ChatMessage):
|
elif isinstance(m, ChatMessage):
|
||||||
role = m.role
|
role = m.role
|
||||||
else:
|
else:
|
||||||
|
@ -1,10 +1,21 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
|
ChatMessage,
|
||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
|
FunctionMessage,
|
||||||
FunctionMessageChunk,
|
FunctionMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
|
get_buffer_string,
|
||||||
|
messages_from_dict,
|
||||||
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -100,3 +111,76 @@ def test_ani_message_chunks() -> None:
|
|||||||
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
|
||||||
example=False, content=" indeed."
|
example=False, content=" indeed."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetBufferString(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.human_msg = HumanMessage(content="human")
|
||||||
|
self.ai_msg = AIMessage(content="ai")
|
||||||
|
self.sys_msg = SystemMessage(content="system")
|
||||||
|
self.func_msg = FunctionMessage(name="func", content="function")
|
||||||
|
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
|
||||||
|
self.chat_msg = ChatMessage(role="Chat", content="chat")
|
||||||
|
|
||||||
|
def test_empty_input(self) -> None:
|
||||||
|
self.assertEqual(get_buffer_string([]), "")
|
||||||
|
|
||||||
|
def test_valid_single_message(self) -> None:
|
||||||
|
expected_output = f"Human: {self.human_msg.content}"
|
||||||
|
self.assertEqual(
|
||||||
|
get_buffer_string([self.human_msg]),
|
||||||
|
expected_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_human_prefix(self) -> None:
|
||||||
|
prefix = "H"
|
||||||
|
expected_output = f"{prefix}: {self.human_msg.content}"
|
||||||
|
self.assertEqual(
|
||||||
|
get_buffer_string([self.human_msg], human_prefix="H"),
|
||||||
|
expected_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_ai_prefix(self) -> None:
|
||||||
|
prefix = "A"
|
||||||
|
expected_output = f"{prefix}: {self.ai_msg.content}"
|
||||||
|
self.assertEqual(
|
||||||
|
get_buffer_string([self.ai_msg], ai_prefix="A"),
|
||||||
|
expected_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multiple_msg(self) -> None:
|
||||||
|
msgs = [
|
||||||
|
self.human_msg,
|
||||||
|
self.ai_msg,
|
||||||
|
self.sys_msg,
|
||||||
|
self.func_msg,
|
||||||
|
self.tool_msg,
|
||||||
|
self.chat_msg,
|
||||||
|
]
|
||||||
|
expected_output = "\n".join(
|
||||||
|
[
|
||||||
|
"Human: human",
|
||||||
|
"AI: ai",
|
||||||
|
"System: system",
|
||||||
|
"Function: function",
|
||||||
|
"Tool: tool",
|
||||||
|
"Chat: chat",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
get_buffer_string(msgs),
|
||||||
|
expected_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_msg() -> None:
|
||||||
|
human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"})
|
||||||
|
ai_msg = AIMessage(content="ai")
|
||||||
|
sys_msg = SystemMessage(content="sys")
|
||||||
|
|
||||||
|
msgs = [
|
||||||
|
human_msg,
|
||||||
|
ai_msg,
|
||||||
|
sys_msg,
|
||||||
|
]
|
||||||
|
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
"""Test formatting functionality."""
|
"""Test formatting functionality."""
|
||||||
import unittest
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -16,75 +15,12 @@ from langchain_core.messages import (
|
|||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
get_buffer_string,
|
|
||||||
messages_from_dict,
|
|
||||||
messages_to_dict,
|
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
|
||||||
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
|
|
||||||
|
|
||||||
class TestGetBufferString(unittest.TestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.human_msg = HumanMessage(content="human")
|
|
||||||
self.ai_msg = AIMessage(content="ai")
|
|
||||||
self.sys_msg = SystemMessage(content="sys")
|
|
||||||
|
|
||||||
def test_empty_input(self) -> None:
|
|
||||||
self.assertEqual(get_buffer_string([]), "")
|
|
||||||
|
|
||||||
def test_valid_single_message(self) -> None:
|
|
||||||
expected_output = f"Human: {self.human_msg.content}"
|
|
||||||
self.assertEqual(
|
|
||||||
get_buffer_string([self.human_msg]),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_custom_human_prefix(self) -> None:
|
|
||||||
prefix = "H"
|
|
||||||
expected_output = f"{prefix}: {self.human_msg.content}"
|
|
||||||
self.assertEqual(
|
|
||||||
get_buffer_string([self.human_msg], human_prefix="H"),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_custom_ai_prefix(self) -> None:
|
|
||||||
prefix = "A"
|
|
||||||
expected_output = f"{prefix}: {self.ai_msg.content}"
|
|
||||||
self.assertEqual(
|
|
||||||
get_buffer_string([self.ai_msg], ai_prefix="A"),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_multiple_msg(self) -> None:
|
|
||||||
msgs = [self.human_msg, self.ai_msg, self.sys_msg]
|
|
||||||
expected_output = "\n".join(
|
|
||||||
[
|
|
||||||
f"Human: {self.human_msg.content}",
|
|
||||||
f"AI: {self.ai_msg.content}",
|
|
||||||
f"System: {self.sys_msg.content}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
get_buffer_string(msgs),
|
|
||||||
expected_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_msg() -> None:
|
|
||||||
human_msg = HumanMessage(content="human", additional_kwargs={"key": "value"})
|
|
||||||
ai_msg = AIMessage(content="ai")
|
|
||||||
sys_msg = SystemMessage(content="sys")
|
|
||||||
|
|
||||||
msgs = [
|
|
||||||
human_msg,
|
|
||||||
ai_msg,
|
|
||||||
sys_msg,
|
|
||||||
]
|
|
||||||
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialization_of_wellknown_objects() -> None:
|
def test_serialization_of_wellknown_objects() -> None:
|
||||||
"""Test that pydantic is able to serialize and deserialize well known objects."""
|
"""Test that pydantic is able to serialize and deserialize well known objects."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user