feat: add gemini support (#953)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
@@ -123,6 +123,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
   - [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
   - [x] [讯飞·星火](https://xinghuo.xfyun.cn/)
   - [x] [Google·Bard](https://bard.google.com/)
+  - [x] [Google·Gemini](https://makersuite.google.com/app/apikey)
 
 - **隐私安全**
@@ -61,7 +61,7 @@ class Config(metaclass=Singleton):
         if self.zhipu_proxy_api_key:
             os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
             os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
-                "ZHIPU_MODEL_VERSION", "chatglm_pro"
+                "ZHIPU_MODEL_VERSION"
             )
 
         # wenxin
@@ -95,6 +95,14 @@ class Config(metaclass=Singleton):
             os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
             os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
 
+        # gemini proxy
+        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
+        if self.gemini_proxy_api_key:
+            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
+            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
+                "GEMINI_MODEL_VERSION", "gemini-pro"
+            )
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
@@ -60,6 +60,7 @@ LLM_MODEL_CONFIG = {
     "wenxin_proxyllm": "wenxin_proxyllm",
     "tongyi_proxyllm": "tongyi_proxyllm",
     "zhipu_proxyllm": "zhipu_proxyllm",
+    "gemini_proxyllm": "gemini_proxyllm",
     "bc_proxyllm": "bc_proxyllm",
     "spark_proxyllm": "spark_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
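Taken together, the configuration and model-registry changes mean the new proxy can be switched on through environment variables alone. A rough sketch of the required environment, assuming DB-GPT's usual LLM_MODEL selection convention (the exact .env wiring is not shown in this diff):

# Hypothetical environment setup; GEMINI_PROXY_API_KEY and GEMINI_MODEL_VERSION
# come from the Config change above, LLM_MODEL is an assumption.
import os

os.environ["LLM_MODEL"] = "gemini_proxyllm"        # key registered in LLM_MODEL_CONFIG
os.environ["GEMINI_PROXY_API_KEY"] = "<your-key>"  # enables the gemini branch in Config
os.environ["GEMINI_MODEL_VERSION"] = "gemini-pro"  # optional; defaults to "gemini-pro"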
@@ -202,19 +202,65 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
 
 
-def _parse_model_messages(
+def parse_model_messages(
     messages: List[ModelMessage],
 ) -> Tuple[str, List[str], List[List[str, str]]]:
     """
-    Parameters:
-        messages: List of message from base chat.
+    Parse model messages to extract the user prompt, system messages, and a history of conversation.
+
+    This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
+    and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
+    the current user prompt. System messages are extracted separately, and the conversation history is compiled into
+    pairs of human and AI messages.
+
+    Args:
+        messages (List[ModelMessage]): List of messages from a chat conversation.
+
     Returns:
-        A tuple contains user prompt, system message list and history message list
-        str: user prompt
-        List[str]: system messages
-        List[List[str]]: history message of user and assistant
+        tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
+            The conversation history is a list of message pairs, each containing a user message and the
+            corresponding AI response.
+
+    Examples:
+        .. code-block:: python
+
+            # Example 1: Single round of conversation
+            messages = [
+                ModelMessage(role="human", content="Hello"),
+                ModelMessage(role="ai", content="Hi there!"),
+                ModelMessage(role="human", content="How are you?"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "How are you?"
+            # system_messages: []
+            # history: [["Hello", "Hi there!"]]
+
+            # Example 2: Conversation with system messages
+            messages = [
+                ModelMessage(role="system", content="System initializing..."),
+                ModelMessage(role="human", content="Is it sunny today?"),
+                ModelMessage(role="ai", content="Yes, it's sunny."),
+                ModelMessage(role="human", content="Great!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Great!"
+            # system_messages: ["System initializing..."]
+            # history: [["Is it sunny today?", "Yes, it's sunny."]]
+
+            # Example 3: Multiple rounds with system message
+            messages = [
+                ModelMessage(role="human", content="Hi"),
+                ModelMessage(role="ai", content="Hello!"),
+                ModelMessage(role="system", content="Error 404"),
+                ModelMessage(role="human", content="What's the error?"),
+                ModelMessage(role="ai", content="Just a joke."),
+                ModelMessage(role="human", content="Funny!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Funny!"
+            # system_messages: ["Error 404"]
+            # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
     """
     user_prompt = ""
     system_messages: List[str] = []
     history_messages: List[List[str]] = [[]]
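The hunk stops before the body of the parsing loop. Purely as an illustration (not the repository implementation), a parser with the behaviour described in the docstring and exercised by the tests below could be sketched like this, using the plain role strings from the docstring examples:

# Hypothetical sketch only: pair human/ai turns, collect system messages,
# and treat the trailing human message as the current prompt.
from typing import List, Tuple


def parse_model_messages_sketch(messages) -> Tuple[str, List[str], List[List[str]]]:
    system_messages: List[str] = []
    history: List[List[str]] = []
    user_prompt = ""
    for msg in messages[:-1]:
        if msg.role == "system":
            system_messages.append(msg.content)
        elif msg.role == "human":
            history.append([msg.content, ""])  # open a new user/AI pair
        elif msg.role == "ai" and history:
            history[-1][1] = msg.content  # close the most recent pair
    if messages and messages[-1].role == "human":
        user_prompt = messages[-1].content  # the last message is the current prompt
    return user_prompt, system_messages, history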
@@ -324,6 +324,71 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
     assert isinstance(new_conversation.messages[1], AIMessage)
 
 
+def test_parse_model_messages_no_history_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello"
+    assert system_messages == []
+    assert history_messages == []
+
+
+def test_parse_model_messages_single_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hi there!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello again"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello again"
+    assert system_messages == []
+    assert history_messages == [["Hello", "Hi there!"]]
+
+
+def test_parse_model_messages_two_round_conversation_with_system_message():
+    messages = [
+        ModelMessage(
+            role=ModelMessageRoleType.SYSTEM, content="System initializing..."
+        ),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How's the weather?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="It's sunny!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Great to hear!"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Great to hear!"
+    assert system_messages == ["System initializing..."]
+    assert history_messages == [["How's the weather?", "It's sunny!"]]
+
+
+def test_parse_model_messages_three_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What's up?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Not much, you?"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Same here."),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Same here."
+    assert system_messages == []
+    assert history_messages == [["Hi", "Hello!"], ["What's up?", "Not much, you?"]]
+
+
+def test_parse_model_messages_multiple_system_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System start"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hey"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System check"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "How are you?"
+    assert system_messages == ["System start", "System check"]
+    assert history_messages == [["Hey", "Hello!"]]
+
+
 def test_to_openai_messages(
     human_model_message, ai_model_message, system_model_message
 ):
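Since all of the new cases share the parse_model_messages prefix, they can be run in isolation; a minimal invocation, assuming a standard pytest setup:

# Runs only the tests added above by matching their common name prefix.
import pytest

pytest.main(["-v", "-k", "parse_model_messages"])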
@@ -8,6 +8,7 @@ from dbgpt.model.proxy.llms.claude import claude_generate_stream
 from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
 from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
 from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
+from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
 from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
 from dbgpt.model.proxy.llms.spark import spark_generate_stream
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -25,6 +26,7 @@ def proxyllm_generate_stream(
         "wenxin_proxyllm": wenxin_generate_stream,
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
+        "gemini_proxyllm": gemini_generate_stream,
         "bc_proxyllm": baichuan_generate_stream,
         "spark_proxyllm": spark_generate_stream,
     }
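The mapping above is a plain name-based dispatch: the configured proxy model name selects which generate_stream implementation the worker iterates. A self-contained sketch of the pattern, with stub generators standing in for the real proxy functions:

# Illustrative dispatch only; stubs replace the real DB-GPT proxy implementations.
def zhipu_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "partial text from zhipu"


def gemini_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "partial text from gemini"


generate_stream_map = {
    "zhipu_proxyllm": zhipu_generate_stream,
    "gemini_proxyllm": gemini_generate_stream,
}

model_name = "gemini_proxyllm"  # in DB-GPT this comes from the loaded model's name
for chunk in generate_stream_map[model_name](None, None, {}, "cpu"):
    print(chunk)  # prints "partial text from gemini"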
dbgpt/model/proxy/llms/gemini.py (new file, 109 lines)
@@ -0,0 +1,109 @@
from typing import List, Tuple, Dict, Any

from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, parse_model_messages

GEMINI_DEFAULT_MODEL = "gemini-pro"


def gemini_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
|
||||
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")
    global history

    # TODO proxy model use unified config?
    proxy_api_key = model_params.proxy_api_key
    proxyllm_backend = GEMINI_DEFAULT_MODEL or model_params.proxyllm_backend

    generation_config = {
        "temperature": 0.7,
        "top_p": 1,
        "top_k": 1,
        "max_output_tokens": 2048,
    }

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
    ]

    import google.generativeai as genai

    if model_params.proxy_api_base:
        from google.api_core import client_options

        client_opts = client_options.ClientOptions(
            api_endpoint=model_params.proxy_api_base
        )
        genai.configure(
            api_key=proxy_api_key, transport="rest", client_options=client_opts
        )
    else:
        genai.configure(api_key=proxy_api_key)
    model = genai.GenerativeModel(
        model_name=proxyllm_backend,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    messages: List[ModelMessage] = params["messages"]
    user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
    chat = model.start_chat(history=gemini_hist)
    response = chat.send_message(user_prompt, stream=True)
    text = ""
    for chunk in response:
        text += chunk.text
        print(text)
        yield text


def _transform_to_gemini_messages(
    messages: List[ModelMessage],
) -> Tuple[str, List[Dict[str, Any]]]:
    """Transform messages to gemini format

    See https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb

    Args:
        messages (List[ModelMessage]): messages

    Returns:
        Tuple[str, List[Dict[str, Any]]]: user_prompt, gemini_hist

    Examples:
        .. code-block:: python

            messages = [
                ModelMessage(role="human", content="Hello"),
                ModelMessage(role="ai", content="Hi there!"),
                ModelMessage(role="human", content="How are you?"),
            ]
            user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
            assert user_prompt == "How are you?"
            assert gemini_hist == [
                {"role": "user", "parts": {"text": "Hello"}},
                {"role": "model", "parts": {"text": "Hi there!"}}
            ]
    """
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    if system_messages:
        user_prompt = "".join(system_messages) + "\n" + user_prompt
    gemini_hist = []
    if history_messages:
        for user_message, model_message in history_messages:
            gemini_hist.append({"role": "user", "parts": {"text": user_message}})
            gemini_hist.append({"role": "model", "parts": {"text": model_message}})
    return user_prompt, gemini_hist
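For reference, the google-generativeai calls that gemini_generate_stream builds on can be exercised on their own; a minimal standalone sketch (requires pip install google-generativeai and a valid API key):

import google.generativeai as genai

genai.configure(api_key="<GEMINI_PROXY_API_KEY>")
model = genai.GenerativeModel(model_name="gemini-pro")

# Prior turns are passed to start_chat(); plain strings are also accepted as parts.
chat = model.start_chat(
    history=[
        {"role": "user", "parts": ["Hello"]},
        {"role": "model", "parts": ["Hi there!"]},
    ]
)
response = chat.send_message("How are you?", stream=True)
for chunk in response:
    print(chunk.text, end="", flush=True)  # partial text streams in as it arrives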
@@ -6,7 +6,7 @@ from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
 
-def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+def __convert_2_zhipu_messages(messages: List[ModelMessage]):
     chat_round = 0
     wenxin_messages = []
 
@@ -57,38 +57,7 @@ def zhipu_generate_stream(
     zhipuai.api_key = proxy_api_key
 
     messages: List[ModelMessage] = params["messages"]
-    # Add history conversation
-    # system = ""
-    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-    #     role_define = messages.pop(0)
-    #     system = role_define.content
-    # else:
-    #     message = messages.pop(0)
-    #     if message.role == ModelMessageRoleType.HUMAN:
-    #         history.append({"role": "user", "content": message.content})
-    # for message in messages:
-    #     if message.role == ModelMessageRoleType.SYSTEM:
-    #         history.append({"role": "user", "content": message.content})
-    #     # elif message.role == ModelMessageRoleType.HUMAN:
-    #     #     history.append({"role": "user", "content": message.content})
-    #     elif message.role == ModelMessageRoleType.AI:
-    #         history.append({"role": "assistant", "content": message.content})
-    #     else:
-    #         pass
-    #
-    # # temp_his = history[::-1]
-    # temp_his = history
-    # last_user_input = None
-    # for m in temp_his:
-    #     if m["role"] == "user":
-    #         last_user_input = m
-    #         break
-    #
-    # if last_user_input:
-    #     history.remove(last_user_input)
-    #     history.append(last_user_input)
-
-    history, systems = __convert_2_wenxin_messages(messages)
+    history, systems = __convert_2_zhipu_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
         prompt=history,