mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 01:49:58 +00:00
365 lines
12 KiB
Python
Executable File
import pytest
|
|
|
|
from dbgpt.core.interface.tests.conftest import in_memory_storage
|
|
from dbgpt.core.interface.message import *
|
|
|
|
|
|
@pytest.fixture
def basic_conversation():
    """Provide an empty ``OnceConversation`` for user1 (chat_normal, sys1)."""
    conversation = OnceConversation(
        chat_mode="chat_normal", user_name="user1", sys_code="sys1"
    )
    return conversation
|
|
|
|
|
|
@pytest.fixture
def human_message():
    """Provide a ``HumanMessage`` whose content is ``"Hello"``."""
    message = HumanMessage(content="Hello")
    return message
|
|
|
|
|
|
@pytest.fixture
def ai_message():
    """Provide an ``AIMessage`` whose content is ``"Hi there"``."""
    message = AIMessage(content="Hi there")
    return message
|
|
|
|
|
|
@pytest.fixture
def system_message():
    """Provide a ``SystemMessage`` whose content is ``"System update"``."""
    message = SystemMessage(content="System update")
    return message
|
|
|
|
|
|
@pytest.fixture
def view_message():
    """Provide a ``ViewMessage`` whose content is ``"View this"``."""
    message = ViewMessage(content="View this")
    return message
|
|
|
|
|
|
@pytest.fixture
def conversation_identifier():
    """Provide a ``ConversationIdentifier`` for conversation ``"conv1"``."""
    identifier = ConversationIdentifier("conv1")
    return identifier
|
|
|
|
|
|
@pytest.fixture
def message_identifier():
    """Provide a ``MessageIdentifier`` for index 1 of conversation ``"conv1"``."""
    identifier = MessageIdentifier("conv1", 1)
    return identifier
|
|
|
|
|
|
@pytest.fixture
def message_storage_item():
    """Provide a ``MessageStorageItem`` wrapping a serialized HumanMessage.

    The stored detail is the ``to_dict()`` form of ``HumanMessage("Hello")``
    at index 1, filed under conversation ``"conv1"`` and index 1.
    """
    detail = HumanMessage(content="Hello", index=1).to_dict()
    return MessageStorageItem("conv1", 1, detail)
|
|
|
|
|
|
@pytest.fixture
def storage_conversation():
    """Provide a ``StorageConversation`` "conv1" created without storage backends.

    Tests that need persistence assign ``conv_storage`` / ``message_storage``
    themselves.
    """
    conversation = StorageConversation(
        "conv1", chat_mode="chat_normal", user_name="user1"
    )
    return conversation
|
|
|
|
|
|
@pytest.fixture
def conversation_with_messages():
    """Provide a ``OnceConversation`` holding two completed user/AI rounds.

    Round one is "Hello" / "Hi"; round two is "How are you?" /
    "I'm good, thanks".
    """
    conv = OnceConversation(chat_mode="chat_normal", user_name="user1")
    exchanges = (
        ("Hello", "Hi"),
        ("How are you?", "I'm good, thanks"),
    )
    for user_text, ai_text in exchanges:
        # Each exchange is wrapped in its own round.
        conv.start_new_round()
        conv.add_user_message(user_text)
        conv.add_ai_message(ai_text)
        conv.end_current_round()
    return conv
|
|
|
|
|
|
@pytest.fixture
def human_model_message():
    """Provide a HUMAN-role ``ModelMessage`` with content ``"Hello"``."""
    message = ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")
    return message
|
|
|
|
|
|
@pytest.fixture
def ai_model_message():
    """Provide an AI-role ``ModelMessage`` with content ``"Hi there"``."""
    message = ModelMessage(role=ModelMessageRoleType.AI, content="Hi there")
    return message
|
|
|
|
|
|
@pytest.fixture
def system_model_message():
    """Provide a SYSTEM-role ``ModelMessage`` carrying a chatbot prompt."""
    prompt = "You are a helpful chatbot!"
    message = ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)
    return message
|
|
|
|
|
|
def test_init(basic_conversation):
    """A new conversation keeps its constructor args and starts with empty defaults."""
    expected_attributes = {
        "chat_mode": "chat_normal",
        "user_name": "user1",
        "sys_code": "sys1",
        "messages": [],
        "start_date": "",
        "chat_order": 0,
        "model_name": "",
        "param_type": "",
        "param_value": "",
        "cost": 0,
        "tokens": 0,
        "_message_index": 0,
    }
    # Checked in declaration order so the first mismatch is reported.
    for attribute, expected in expected_attributes.items():
        assert getattr(basic_conversation, attribute) == expected, attribute
|
|
|
|
|
|
def test_add_user_message(basic_conversation, human_message):
    """add_user_message appends exactly one HumanMessage."""
    basic_conversation.add_user_message(human_message.content)
    stored = basic_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], HumanMessage)
|
|
|
|
|
|
def test_add_ai_message(basic_conversation, ai_message):
    """add_ai_message appends exactly one AIMessage."""
    basic_conversation.add_ai_message(ai_message.content)
    stored = basic_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], AIMessage)
|
|
|
|
|
|
def test_add_system_message(basic_conversation, system_message):
    """add_system_message appends exactly one SystemMessage."""
    basic_conversation.add_system_message(system_message.content)
    stored = basic_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], SystemMessage)
|
|
|
|
|
|
def test_add_view_message(basic_conversation, view_message):
    """add_view_message appends exactly one ViewMessage."""
    basic_conversation.add_view_message(view_message.content)
    stored = basic_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], ViewMessage)
|
|
|
|
|
|
def test_set_start_time(basic_conversation):
    """set_start_time stores the timestamp formatted as ``%Y-%m-%d %H:%M:%S``."""
    # Import explicitly: the module's only other source of `datetime` is the
    # wildcard import from dbgpt.core.interface.message, which provides it
    # only as an accidental re-export.
    from datetime import datetime

    now = datetime.now()
    basic_conversation.set_start_time(now)
    assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
|
def test_clear_messages(basic_conversation, human_message):
    """clear() removes previously added messages."""
    conversation = basic_conversation
    conversation.add_user_message(human_message.content)
    conversation.clear()
    remaining = conversation.messages
    assert len(remaining) == 0
|
|
|
|
|
|
def test_get_latest_user_message(basic_conversation, human_message):
    """The latest user message compares equal to the fixture HumanMessage."""
    basic_conversation.add_user_message(human_message.content)
    assert basic_conversation.get_latest_user_message() == human_message
|
|
|
|
|
|
def test_get_system_messages(basic_conversation, system_message):
    """get_system_messages returns the single added system message."""
    basic_conversation.add_system_message(system_message.content)
    found = basic_conversation.get_system_messages()
    assert len(found) == 1
    assert found[0] == system_message
|
|
|
|
|
|
def test_from_conversation(basic_conversation):
    """from_conversation copies chat_mode and user_name from the source."""
    source = OnceConversation(chat_mode="chat_advanced", user_name="user2")
    basic_conversation.from_conversation(source)
    copied = (basic_conversation.chat_mode, basic_conversation.user_name)
    assert copied == ("chat_advanced", "user2")
|
|
|
|
|
|
def test_get_messages_by_round(conversation_with_messages):
    """get_messages_by_round returns only the messages of the requested round."""
    # Test first round
    round1_messages = conversation_with_messages.get_messages_by_round(1)
    assert len(round1_messages) == 2
    assert round1_messages[0].content == "Hello"
    assert round1_messages[1].content == "Hi"

    # Test second round: it must not include messages from the first round
    round2_messages = conversation_with_messages.get_messages_by_round(2)
    assert len(round2_messages) == 2
    assert round2_messages[0].content == "How are you?"
    assert round2_messages[1].content == "I'm good, thanks"

    # Test not existing round
    no_messages = conversation_with_messages.get_messages_by_round(3)
    assert len(no_messages) == 0
|
|
|
|
|
|
def test_get_latest_round(conversation_with_messages):
    """get_latest_round returns the two messages of the final round, in order."""
    latest = conversation_with_messages.get_latest_round()
    assert len(latest) == 2
    assert [msg.content for msg in latest] == ["How are you?", "I'm good, thanks"]
|
|
|
|
|
|
def test_get_messages_with_round(conversation_with_messages):
    """get_messages_with_round(n) returns the messages of the last n rounds in order."""
    # Test last round
    last_round_messages = conversation_with_messages.get_messages_with_round(1)
    assert len(last_round_messages) == 2
    assert last_round_messages[0].content == "How are you?"
    assert last_round_messages[1].content == "I'm good, thanks"

    # Test last two rounds: check every returned message, not only the first
    # round, so misordering or a wrong second round is also caught.
    last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2)
    assert len(last_two_rounds_messages) == 4
    assert last_two_rounds_messages[0].content == "Hello"
    assert last_two_rounds_messages[1].content == "Hi"
    assert last_two_rounds_messages[2].content == "How are you?"
    assert last_two_rounds_messages[3].content == "I'm good, thanks"
|
|
|
|
|
|
def test_get_model_messages(conversation_with_messages):
    """get_model_messages converts every stored message to a ModelMessage, in order."""
    converted = conversation_with_messages.get_model_messages()
    assert len(converted) == 4
    assert all(isinstance(item, ModelMessage) for item in converted)
    expected_contents = ["Hello", "Hi", "How are you?", "I'm good, thanks"]
    assert [item.content for item in converted] == expected_contents
|
|
|
|
|
|
def test_conversation_identifier(conversation_identifier):
    """ConversationIdentifier exposes its uid, type, string form and dict form."""
    ident = conversation_identifier
    assert ident.conv_uid == "conv1"
    assert ident.identifier_type == "conversation"
    assert ident.str_identifier == "conversation:conv1"
    expected_dict = {"conv_uid": "conv1", "identifier_type": "conversation"}
    assert ident.to_dict() == expected_dict
|
|
|
|
|
|
def test_message_identifier(message_identifier):
    """MessageIdentifier exposes uid, index, type, string form and dict form."""
    ident = message_identifier
    assert ident.conv_uid == "conv1"
    assert ident.index == 1
    assert ident.identifier_type == "message"
    # The string form joins type, uid and index with triple underscores.
    assert ident.str_identifier == "message___conv1___1"
    expected_dict = {
        "conv_uid": "conv1",
        "index": 1,
        "identifier_type": "message",
    }
    assert ident.to_dict() == expected_dict
|
|
|
|
|
|
def test_message_storage_item(message_storage_item):
    """MessageStorageItem exposes its fields, serializes, and rebuilds a BaseMessage."""
    item = message_storage_item
    # Serialized form of HumanMessage(content="Hello", index=1).
    expected_detail = {
        "type": "human",
        "data": {
            "content": "Hello",
            "index": 1,
            "round_index": 0,
            "additional_kwargs": {},
            "example": False,
        },
        "index": 1,
        "round_index": 0,
    }

    assert item.conv_uid == "conv1"
    assert item.index == 1
    assert item.message_detail == expected_detail

    assert isinstance(item.identifier, MessageIdentifier)
    assert item.to_dict() == {
        "conv_uid": "conv1",
        "index": 1,
        "message_detail": expected_detail,
    }

    assert isinstance(item.to_message(), BaseMessage)
|
|
|
|
|
|
def test_storage_conversation_init(storage_conversation):
    """StorageConversation keeps its uid, chat mode and user name."""
    conv = storage_conversation
    observed = (conv.conv_uid, conv.chat_mode, conv.user_name)
    assert observed == ("conv1", "chat_normal", "user1")
|
|
|
|
|
|
def test_storage_conversation_add_user_message(storage_conversation):
    """add_user_message on a StorageConversation appends one HumanMessage."""
    storage_conversation.add_user_message("Hi")
    stored = storage_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], HumanMessage)
|
|
|
|
|
|
def test_storage_conversation_add_ai_message(storage_conversation):
    """add_ai_message on a StorageConversation appends one AIMessage."""
    storage_conversation.add_ai_message("Hello")
    stored = storage_conversation.messages
    assert len(stored) == 1
    assert isinstance(stored[0], AIMessage)
|
|
|
|
|
|
def test_save_to_storage(storage_conversation, in_memory_storage):
    """Messages saved via save_to_storage are seen by a fresh StorageConversation."""
    storage = in_memory_storage
    # The same in-memory backend serves both conversations and messages.
    storage_conversation.conv_storage = storage
    storage_conversation.message_storage = storage

    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")
    storage_conversation.save_to_storage()

    # Re-open the conversation by uid through a brand-new instance.
    reloaded = StorageConversation(
        storage_conversation.conv_uid,
        conv_storage=storage,
        message_storage=storage,
    )

    assert reloaded.conv_uid == storage_conversation.conv_uid
    assert len(reloaded.messages) == 2
    first, second = reloaded.messages
    assert isinstance(first, HumanMessage)
    assert isinstance(second, AIMessage)
|
|
|
|
|
|
def test_load_from_storage(storage_conversation, in_memory_storage):
    """A StorageConversation built for "conv1" loads the saved messages intact."""
    storage = in_memory_storage
    storage_conversation.conv_storage = storage
    storage_conversation.message_storage = storage

    # Populate and persist the source conversation.
    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")
    storage_conversation.save_to_storage()

    loaded = StorageConversation(
        "conv1", conv_storage=storage, message_storage=storage
    )

    # Verify uid, message count, contents and concrete message types.
    assert loaded.conv_uid == storage_conversation.conv_uid
    assert len(loaded.messages) == 2
    user_msg, ai_msg = loaded.messages
    assert user_msg.content == "User message"
    assert ai_msg.content == "AI response"
    assert isinstance(user_msg, HumanMessage)
    assert isinstance(ai_msg, AIMessage)
|
|
|
|
|
|
def test_to_openai_messages(
    human_model_message, ai_model_message, system_model_message
):
    """ModelMessage.to_openai_messages maps roles to OpenAI chat dicts.

    HUMAN becomes "user", AI becomes "assistant" and SYSTEM stays "system".
    """
    convert = ModelMessage.to_openai_messages

    def entry(role, message):
        # Expected OpenAI-style dict for one message.
        return {"role": role, "content": message.content}

    # Empty input -> empty output.
    assert convert([]) == []

    # A single human message.
    assert convert([human_model_message]) == [entry("user", human_model_message)]

    # An alternating conversation keeps its order.
    assert convert(
        [
            system_model_message,
            human_model_message,
            ai_model_message,
            human_model_message,
        ]
    ) == [
        entry("system", system_model_message),
        entry("user", human_model_message),
        entry("assistant", ai_model_message),
        entry("user", human_model_message),
    ]

    # Two consecutive human messages followed by an AI reply: the expected
    # output still ends on a user turn.
    # NOTE(review): this pins reordering done inside to_openai_messages —
    # confirm against its implementation.
    assert convert(
        [
            system_model_message,
            human_model_message,
            human_model_message,
            ai_model_message,
        ]
    ) == [
        entry("system", system_model_message),
        entry("user", human_model_message),
        entry("assistant", ai_model_message),
        entry("user", human_model_message),
    ]
|