mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
120 lines
3.3 KiB
Python
120 lines
3.3 KiB
Python
"""Test formatting functionality."""
|
|
|
|
import unittest
|
|
|
|
import langchain.schema as schema
|
|
from langchain.schema import __all__ as schema_all
|
|
from langchain.schema.base import (
|
|
AIMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
get_buffer_string,
|
|
messages_from_dict,
|
|
messages_to_dict,
|
|
)
|
|
|
|
|
|
class TestGetBufferString(unittest.TestCase):
    """Exercise ``get_buffer_string`` with empty, single and mixed inputs."""

    # Shared fixtures: one message per role used by the tests below.
    human_msg: HumanMessage = HumanMessage(content="human")
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_empty_input(self) -> None:
        """An empty message list is expected to produce an empty string."""
        self.assertEqual(get_buffer_string([]), "")

    def test_valid_single_message(self) -> None:
        """A lone human message is expected to render as ``Human: <content>``."""
        expected_output = f"Human: {self.human_msg.content}"
        self.assertEqual(
            get_buffer_string([self.human_msg]),
            expected_output,
        )

    def test_custom_human_prefix(self) -> None:
        """``human_prefix`` is expected to replace the label on human messages."""
        prefix = "H"
        expected_output = f"{prefix}: {self.human_msg.content}"
        self.assertEqual(
            # Reuse ``prefix`` rather than repeating the literal so the call
            # and the expectation cannot drift apart.
            get_buffer_string([self.human_msg], human_prefix=prefix),
            expected_output,
        )

    def test_custom_ai_prefix(self) -> None:
        """``ai_prefix`` is expected to replace the label on AI messages."""
        prefix = "A"
        expected_output = f"{prefix}: {self.ai_msg.content}"
        self.assertEqual(
            # Reuse ``prefix`` rather than repeating the literal so the call
            # and the expectation cannot drift apart.
            get_buffer_string([self.ai_msg], ai_prefix=prefix),
            expected_output,
        )

    def test_multiple_msg(self) -> None:
        """Several messages are expected as one newline-joined line per message."""
        msgs = [self.human_msg, self.ai_msg, self.sys_msg]
        expected_output = "\n".join(
            [
                f"Human: {self.human_msg.content}",
                f"AI: {self.ai_msg.content}",
                f"System: {self.sys_msg.content}",
            ]
        )
        self.assertEqual(
            get_buffer_string(msgs),
            expected_output,
        )
|
|
|
|
|
|
class TestMessageDictConversion(unittest.TestCase):
    """Round-trip messages through ``messages_to_dict``/``messages_from_dict``."""

    # The human fixture carries ``additional_kwargs`` so the round trip is
    # checked on a message with extra fields, not just ``content``.
    human_msg: HumanMessage = HumanMessage(
        content="human", additional_kwargs={"key": "value"}
    )
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_multiple_msg(self) -> None:
        """Converting to dicts and back is expected to give equal messages."""
        originals = [self.human_msg, self.ai_msg, self.sys_msg]
        as_dicts = messages_to_dict(originals)
        restored = messages_from_dict(as_dicts)
        self.assertEqual(restored, originals)
|
|
|
|
|
|
def test_public_api() -> None:
    """Detect unintended changes to the names exported by ``langchain.schema``."""
    # Must stay in the order ``sorted`` produces: it is compared directly
    # against the sorted ``__all__`` below.
    expected_names = [
        "AIMessage",
        "AgentAction",
        "AgentFinish",
        "BaseChatMessageHistory",
        "BaseDocumentTransformer",
        "BaseMemory",
        "BaseMessage",
        "BaseOutputParser",
        "BaseRetriever",
        "ChatGeneration",
        "ChatMessage",
        "ChatResult",
        "Document",
        "Generation",
        "HumanMessage",
        "LLMResult",
        "Memory",
        "OutputParserException",
        "PromptValue",
        "RUN_KEY",
        "RunInfo",
        "SystemMessage",
        "T",
        "get_buffer_string",
        "message_from_dict",
        "message_to_dict",
        "messages_from_dict",
        "messages_to_dict",
    ]

    exported_names = sorted(schema_all)
    assert exported_names == expected_names

    # Each advertised name must exist on the schema module and not be None.
    for attr_name in expected_names:
        assert hasattr(schema, attr_name)
        assert getattr(schema, attr_name) is not None
|