refactor(core): use pytest style in TestGetBufferString (#32786)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-09-08 17:16:13 +02:00
committed by GitHub
parent 5840dad40b
commit f589168411

View File

@@ -1,4 +1,3 @@
import unittest
import uuid
from typing import Optional, Union
@@ -335,54 +334,45 @@ def test_ai_message_chunks() -> None:
)
class TestGetBufferString(unittest.TestCase):
def setUp(self) -> None:
self.human_msg = HumanMessage(content="human")
self.ai_msg = AIMessage(content="ai")
self.sys_msg = SystemMessage(content="system")
self.func_msg = FunctionMessage(name="func", content="function")
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
self.chat_msg = ChatMessage(role="Chat", content="chat")
self.tool_calls_msg = AIMessage(content="tool")
class TestGetBufferString:
_HUMAN_MSG = HumanMessage(content="human")
_AI_MSG = AIMessage(content="ai")
def test_empty_input(self) -> None:
assert get_buffer_string([]) == ""
def test_valid_single_message(self) -> None:
expected_output = f"Human: {self.human_msg.content}"
assert get_buffer_string([self.human_msg]) == expected_output
expected_output = "Human: human"
assert get_buffer_string([self._HUMAN_MSG]) == expected_output
def test_custom_human_prefix(self) -> None:
prefix = "H"
expected_output = f"{prefix}: {self.human_msg.content}"
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
expected_output = "H: human"
assert get_buffer_string([self._HUMAN_MSG], human_prefix="H") == expected_output
def test_custom_ai_prefix(self) -> None:
prefix = "A"
expected_output = f"{prefix}: {self.ai_msg.content}"
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
expected_output = "A: ai"
assert get_buffer_string([self._AI_MSG], ai_prefix="A") == expected_output
def test_multiple_msg(self) -> None:
msgs = [
self.human_msg,
self.ai_msg,
self.sys_msg,
self.func_msg,
self.tool_msg,
self.chat_msg,
self.tool_calls_msg,
self._HUMAN_MSG,
self._AI_MSG,
SystemMessage(content="system"),
FunctionMessage(name="func", content="function"),
ToolMessage(tool_call_id="tool_id", content="tool"),
ChatMessage(role="Chat", content="chat"),
AIMessage(content="tool"),
]
expected_output = "\n".join( # noqa: FLY002
[
"Human: human",
"AI: ai",
"System: system",
"Function: function",
"Tool: tool",
"Chat: chat",
"AI: tool",
]
expected_output = (
"Human: human\n"
"AI: ai\n"
"System: system\n"
"Function: function\n"
"Tool: tool\n"
"Chat: chat\n"
"AI: tool"
)
assert get_buffer_string(msgs) == expected_output