mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 21:51:25 +00:00
feat: add gemini support (#953)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -202,19 +202,65 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
def _parse_model_messages(
|
||||
def parse_model_messages(
|
||||
messages: List[ModelMessage],
|
||||
) -> Tuple[str, List[str], List[List[str, str]]]:
|
||||
"""
|
||||
Parameters:
|
||||
messages: List of message from base chat.
|
||||
Parse model messages to extract the user prompt, system messages, and a history of conversation.
|
||||
|
||||
This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
|
||||
and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
|
||||
the current user prompt. System messages are extracted separately, and the conversation history is compiled into
|
||||
pairs of human and AI messages.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): List of messages from a chat conversation.
|
||||
|
||||
Returns:
|
||||
A tuple contains user prompt, system message list and history message list
|
||||
str: user prompt
|
||||
List[str]: system messages
|
||||
List[List[str]]: history message of user and assistant
|
||||
tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
|
||||
The conversation history is a list of message pairs, each containing a user message and the corresponding AI response.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# Example 1: Single round of conversation
|
||||
messages = [
|
||||
ModelMessage(role="human", content="Hello"),
|
||||
ModelMessage(role="ai", content="Hi there!"),
|
||||
ModelMessage(role="human", content="How are you?"),
|
||||
]
|
||||
user_prompt, system_messages, history = parse_model_messages(messages)
|
||||
# user_prompt: "How are you?"
|
||||
# system_messages: []
|
||||
# history: [["Hello", "Hi there!"]]
|
||||
|
||||
# Example 2: Conversation with system messages
|
||||
messages = [
|
||||
ModelMessage(role="system", content="System initializing..."),
|
||||
ModelMessage(role="human", content="Is it sunny today?"),
|
||||
ModelMessage(role="ai", content="Yes, it's sunny."),
|
||||
ModelMessage(role="human", content="Great!"),
|
||||
]
|
||||
user_prompt, system_messages, history = parse_model_messages(messages)
|
||||
# user_prompt: "Great!"
|
||||
# system_messages: ["System initializing..."]
|
||||
# history: [["Is it sunny today?", "Yes, it's sunny."]]
|
||||
|
||||
# Example 3: Multiple rounds with system message
|
||||
messages = [
|
||||
ModelMessage(role="human", content="Hi"),
|
||||
ModelMessage(role="ai", content="Hello!"),
|
||||
ModelMessage(role="system", content="Error 404"),
|
||||
ModelMessage(role="human", content="What's the error?"),
|
||||
ModelMessage(role="ai", content="Just a joke."),
|
||||
ModelMessage(role="human", content="Funny!"),
|
||||
]
|
||||
user_prompt, system_messages, history = parse_model_messages(messages)
|
||||
# user_prompt: "Funny!"
|
||||
# system_messages: ["Error 404"]
|
||||
# history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
|
||||
"""
|
||||
user_prompt = ""
|
||||
|
||||
system_messages: List[str] = []
|
||||
history_messages: List[List[str]] = [[]]
|
||||
|
||||
|
@@ -324,6 +324,71 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
|
||||
assert isinstance(new_conversation.messages[1], AIMessage)
|
||||
|
||||
|
||||
def test_parse_model_messages_no_history_messages():
    """A lone human message becomes the user prompt; no system or history entries."""
    conversation = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
    ]
    prompt, sys_msgs, history = parse_model_messages(conversation)
    assert prompt == "Hello"
    assert sys_msgs == []
    assert history == []
|
||||
|
||||
|
||||
def test_parse_model_messages_single_round_conversation():
    """One completed human/AI round plus a trailing human turn as the new prompt."""
    human_role = ModelMessageRoleType.HUMAN
    ai_role = ModelMessageRoleType.AI
    conversation = [
        ModelMessage(role=human_role, content="Hello"),
        ModelMessage(role=ai_role, content="Hi there!"),
        ModelMessage(role=human_role, content="Hello again"),
    ]
    prompt, sys_msgs, history = parse_model_messages(conversation)
    assert prompt == "Hello again"
    assert sys_msgs == []
    assert history == [["Hello", "Hi there!"]]
|
||||
|
||||
|
||||
def test_parse_model_messages_two_round_conversation_with_system_message():
    """A leading system message is extracted separately from the chat history."""
    conversation = [
        ModelMessage(
            role=ModelMessageRoleType.SYSTEM, content="System initializing..."
        ),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How's the weather?"),
        ModelMessage(role=ModelMessageRoleType.AI, content="It's sunny!"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Great to hear!"),
    ]
    prompt, sys_msgs, history = parse_model_messages(conversation)
    assert prompt == "Great to hear!"
    assert sys_msgs == ["System initializing..."]
    assert history == [["How's the weather?", "It's sunny!"]]
|
||||
|
||||
|
||||
def test_parse_model_messages_three_round_conversation():
    """Two completed rounds are paired into history; the last human turn is the prompt."""
    exchanges = [
        (ModelMessageRoleType.HUMAN, "Hi"),
        (ModelMessageRoleType.AI, "Hello!"),
        (ModelMessageRoleType.HUMAN, "What's up?"),
        (ModelMessageRoleType.AI, "Not much, you?"),
        (ModelMessageRoleType.HUMAN, "Same here."),
    ]
    conversation = [
        ModelMessage(role=role, content=text) for role, text in exchanges
    ]
    prompt, sys_msgs, history = parse_model_messages(conversation)
    assert prompt == "Same here."
    assert sys_msgs == []
    assert history == [["Hi", "Hello!"], ["What's up?", "Not much, you?"]]
|
||||
|
||||
|
||||
def test_parse_model_messages_multiple_system_messages():
    """System messages anywhere in the stream are all collected, in order."""
    conversation = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System start"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hey"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System check"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
    ]
    prompt, sys_msgs, history = parse_model_messages(conversation)
    assert prompt == "How are you?"
    assert sys_msgs == ["System start", "System check"]
    assert history == [["Hey", "Hello!"]]
|
||||
|
||||
|
||||
def test_to_openai_messages(
|
||||
human_model_message, ai_model_message, system_model_message
|
||||
):
|
||||
|
Reference in New Issue
Block a user