mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
refactor(core): use pytest style in TestGetBufferString
(#32786)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
5840dad40b
commit
f589168411
@@ -1,4 +1,3 @@
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -335,54 +334,45 @@ def test_ai_message_chunks() -> None:
|
||||
)
|
||||
|
||||
|
||||
class TestGetBufferString(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.human_msg = HumanMessage(content="human")
|
||||
self.ai_msg = AIMessage(content="ai")
|
||||
self.sys_msg = SystemMessage(content="system")
|
||||
self.func_msg = FunctionMessage(name="func", content="function")
|
||||
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
|
||||
self.chat_msg = ChatMessage(role="Chat", content="chat")
|
||||
self.tool_calls_msg = AIMessage(content="tool")
|
||||
class TestGetBufferString:
|
||||
_HUMAN_MSG = HumanMessage(content="human")
|
||||
_AI_MSG = AIMessage(content="ai")
|
||||
|
||||
def test_empty_input(self) -> None:
|
||||
assert get_buffer_string([]) == ""
|
||||
|
||||
def test_valid_single_message(self) -> None:
|
||||
expected_output = f"Human: {self.human_msg.content}"
|
||||
assert get_buffer_string([self.human_msg]) == expected_output
|
||||
expected_output = "Human: human"
|
||||
assert get_buffer_string([self._HUMAN_MSG]) == expected_output
|
||||
|
||||
def test_custom_human_prefix(self) -> None:
|
||||
prefix = "H"
|
||||
expected_output = f"{prefix}: {self.human_msg.content}"
|
||||
assert get_buffer_string([self.human_msg], human_prefix="H") == expected_output
|
||||
expected_output = "H: human"
|
||||
assert get_buffer_string([self._HUMAN_MSG], human_prefix="H") == expected_output
|
||||
|
||||
def test_custom_ai_prefix(self) -> None:
|
||||
prefix = "A"
|
||||
expected_output = f"{prefix}: {self.ai_msg.content}"
|
||||
assert get_buffer_string([self.ai_msg], ai_prefix="A") == expected_output
|
||||
expected_output = "A: ai"
|
||||
assert get_buffer_string([self._AI_MSG], ai_prefix="A") == expected_output
|
||||
|
||||
def test_multiple_msg(self) -> None:
|
||||
msgs = [
|
||||
self.human_msg,
|
||||
self.ai_msg,
|
||||
self.sys_msg,
|
||||
self.func_msg,
|
||||
self.tool_msg,
|
||||
self.chat_msg,
|
||||
self.tool_calls_msg,
|
||||
self._HUMAN_MSG,
|
||||
self._AI_MSG,
|
||||
SystemMessage(content="system"),
|
||||
FunctionMessage(name="func", content="function"),
|
||||
ToolMessage(tool_call_id="tool_id", content="tool"),
|
||||
ChatMessage(role="Chat", content="chat"),
|
||||
AIMessage(content="tool"),
|
||||
]
|
||||
expected_output = "\n".join( # noqa: FLY002
|
||||
[
|
||||
"Human: human",
|
||||
"AI: ai",
|
||||
"System: system",
|
||||
"Function: function",
|
||||
"Tool: tool",
|
||||
"Chat: chat",
|
||||
"AI: tool",
|
||||
]
|
||||
expected_output = (
|
||||
"Human: human\n"
|
||||
"AI: ai\n"
|
||||
"System: system\n"
|
||||
"Function: function\n"
|
||||
"Tool: tool\n"
|
||||
"Chat: chat\n"
|
||||
"AI: tool"
|
||||
)
|
||||
|
||||
assert get_buffer_string(msgs) == expected_output
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user