Files
langchain/tests/unit_tests/test_schema.py
2023-06-09 12:24:19 -07:00

120 lines
3.3 KiB
Python

"""Test formatting functionality."""
import unittest
import langchain.schema as schema
from langchain.schema import __all__ as schema_all
from langchain.schema.base import (
AIMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
class TestGetBufferString(unittest.TestCase):
    """Checks how ``get_buffer_string`` renders human, AI and system messages."""

    # Shared fixtures: one message per role, each with distinct content.
    human_msg: HumanMessage = HumanMessage(content="human")
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_empty_input(self) -> None:
        """An empty message list is expected to render as an empty string."""
        rendered = get_buffer_string([])
        self.assertEqual(rendered, "")

    def test_valid_single_message(self) -> None:
        """A lone human message is expected to carry the default "Human" prefix."""
        rendered = get_buffer_string([self.human_msg])
        self.assertEqual(rendered, f"Human: {self.human_msg.content}")

    def test_custom_human_prefix(self) -> None:
        """A caller-supplied ``human_prefix`` is expected to replace the default."""
        custom = "H"
        rendered = get_buffer_string([self.human_msg], human_prefix=custom)
        self.assertEqual(rendered, f"{custom}: {self.human_msg.content}")

    def test_custom_ai_prefix(self) -> None:
        """A caller-supplied ``ai_prefix`` is expected to replace the default."""
        custom = "A"
        rendered = get_buffer_string([self.ai_msg], ai_prefix=custom)
        self.assertEqual(rendered, f"{custom}: {self.ai_msg.content}")

    def test_multiple_msg(self) -> None:
        """Several messages are expected to render one per line, in input order."""
        conversation = [self.human_msg, self.ai_msg, self.sys_msg]

        # One expected line per role, using the default prefixes.
        human_line = f"Human: {self.human_msg.content}"
        ai_line = f"AI: {self.ai_msg.content}"
        system_line = f"System: {self.sys_msg.content}"
        expected = "\n".join((human_line, ai_line, system_line))

        rendered = get_buffer_string(conversation)
        self.assertEqual(rendered, expected)
class TestMessageDictConversion(unittest.TestCase):
    """Checks that messages survive a dict serialisation round trip."""

    # The human message carries extra kwargs so the round trip covers them too.
    human_msg: HumanMessage = HumanMessage(
        content="human",
        additional_kwargs={"key": "value"},
    )
    ai_msg: AIMessage = AIMessage(content="ai")
    sys_msg: SystemMessage = SystemMessage(content="sys")

    def test_multiple_msg(self) -> None:
        """``messages_from_dict(messages_to_dict(x))`` is expected to equal ``x``."""
        originals = [self.human_msg, self.ai_msg, self.sys_msg]

        as_dicts = messages_to_dict(originals)
        restored = messages_from_dict(as_dicts)

        self.assertEqual(restored, originals)
def test_public_api() -> None:
    """Guard against accidental changes to the names exported by the schema package."""
    # Kept in sorted order: it is compared directly against ``sorted(schema_all)``.
    expected_exports = [
        "AIMessage",
        "AgentAction",
        "AgentFinish",
        "BaseChatMessageHistory",
        "BaseDocumentTransformer",
        "BaseMemory",
        "BaseMessage",
        "BaseOutputParser",
        "BaseRetriever",
        "ChatGeneration",
        "ChatMessage",
        "ChatResult",
        "Document",
        "Generation",
        "HumanMessage",
        "LLMResult",
        "Memory",
        "OutputParserException",
        "PromptValue",
        "RUN_KEY",
        "RunInfo",
        "SystemMessage",
        "T",
        "get_buffer_string",
        "message_from_dict",
        "message_to_dict",
        "messages_from_dict",
        "messages_to_dict",
    ]
    declared_exports = sorted(schema_all)
    assert declared_exports == expected_exports

    # Every advertised name must resolve to a real, non-None attribute of the
    # schema module; a missing attribute falls back to None and fails the check.
    for exported_name in expected_exports:
        assert getattr(schema, exported_name, None) is not None