diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 807f52ae10d..b1df4955232 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,4 +1,3 @@ -import unittest import uuid from typing import Optional, Union @@ -335,54 +334,45 @@ def test_ai_message_chunks() -> None: ) -class TestGetBufferString(unittest.TestCase): - def setUp(self) -> None: - self.human_msg = HumanMessage(content="human") - self.ai_msg = AIMessage(content="ai") - self.sys_msg = SystemMessage(content="system") - self.func_msg = FunctionMessage(name="func", content="function") - self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool") - self.chat_msg = ChatMessage(role="Chat", content="chat") - self.tool_calls_msg = AIMessage(content="tool") +class TestGetBufferString: + _HUMAN_MSG = HumanMessage(content="human") + _AI_MSG = AIMessage(content="ai") def test_empty_input(self) -> None: assert get_buffer_string([]) == "" def test_valid_single_message(self) -> None: - expected_output = f"Human: {self.human_msg.content}" - assert get_buffer_string([self.human_msg]) == expected_output + expected_output = "Human: human" + assert get_buffer_string([self._HUMAN_MSG]) == expected_output def test_custom_human_prefix(self) -> None: - prefix = "H" - expected_output = f"{prefix}: {self.human_msg.content}" - assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output + expected_output = "H: human" + assert get_buffer_string([self._HUMAN_MSG], human_prefix="H") == expected_output def test_custom_ai_prefix(self) -> None: - prefix = "A" - expected_output = f"{prefix}: {self.ai_msg.content}" - assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output + expected_output = "A: ai" + assert get_buffer_string([self._AI_MSG], ai_prefix="A") == expected_output def test_multiple_msg(self) -> None: msgs = [ - self.human_msg, - self.ai_msg, - self.sys_msg, - self.func_msg, - self.tool_msg, - self.chat_msg, - self.tool_calls_msg, + self._HUMAN_MSG, + self._AI_MSG, + SystemMessage(content="system"), + FunctionMessage(name="func", content="function"), + ToolMessage(tool_call_id="tool_id", content="tool"), + ChatMessage(role="Chat", content="chat"), + AIMessage(content="tool"), ] - expected_output = "\n".join( # noqa: FLY002 - [ - "Human: human", - "AI: ai", - "System: system", - "Function: function", - "Tool: tool", - "Chat: chat", - "AI: tool", - ] + expected_output = ( + "Human: human\n" + "AI: ai\n" + "System: system\n" + "Function: function\n" + "Tool: tool\n" + "Chat: chat\n" + "AI: tool" ) + assert get_buffer_string(msgs) == expected_output