feat: add gemini support (#953)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
yihong
2023-12-23 11:10:42 +08:00
committed by GitHub
parent e1ace141f7
commit 12234ae258
8 changed files with 243 additions and 42 deletions

View File

@@ -123,6 +123,7 @@ DB-GPT is an open-source large model framework for the database domain. Its goal is to build…
- [x] [Zhipu·ChatGLM](http://open.bigmodel.cn/)
- [x] [iFLYTEK·Spark](https://xinghuo.xfyun.cn/)
- [x] [Google·Bard](https://bard.google.com/)
- [x] [Google·Gemini](https://makersuite.google.com/app/apikey)
- **Privacy and Security**

View File

@@ -61,7 +61,7 @@ class Config(metaclass=Singleton):
        if self.zhipu_proxy_api_key:
            os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
            os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
-                "ZHIPU_MODEL_VERSION", "chatglm_pro"
+                "ZHIPU_MODEL_VERSION"
            )
        # wenxin
@@ -95,6 +95,14 @@ class Config(metaclass=Singleton):
            os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
            os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version

        # gemini proxy
        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
        if self.gemini_proxy_api_key:
            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
                "GEMINI_MODEL_VERSION", "gemini-pro"
            )

        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
        self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
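With this wiring, enabling the new proxy comes down to exporting the variables the config reads. A minimal sketch, assuming DB-GPT's usual `LLM_MODEL` selector is what picks the `gemini_proxyllm` entry registered below:

import os

# Variable names are taken from the diff above; the values are placeholders.
os.environ["GEMINI_PROXY_API_KEY"] = "<your-gemini-api-key>"  # required to enable the proxy
os.environ["GEMINI_MODEL_VERSION"] = "gemini-pro"  # optional; "gemini-pro" is the default
os.environ["LLM_MODEL"] = "gemini_proxyllm"  # assumed selector variable, not shown in this diff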

View File

@@ -60,6 +60,7 @@ LLM_MODEL_CONFIG = {
"wenxin_proxyllm": "wenxin_proxyllm",
"tongyi_proxyllm": "tongyi_proxyllm",
"zhipu_proxyllm": "zhipu_proxyllm",
"gemini_proxyllm": "gemini_proxyllm",
"bc_proxyllm": "bc_proxyllm",
"spark_proxyllm": "spark_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),

View File

@@ -202,19 +202,65 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
    return [_message_from_dict(m) for m in messages]


-def _parse_model_messages(
+def parse_model_messages(
    messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str]]]:
    """
-   Parameters:
-       messages: List of message from base chat.
    Parse model messages to extract the user prompt, system messages, and a history of conversation.

    This function analyzes a list of ModelMessage objects, identifying the role of each message
    (e.g., human, system, ai) and categorizing them accordingly. The last message is expected to be
    from the user (human) and is treated as the current user prompt. System messages are extracted
    separately, and the conversation history is compiled into pairs of human and AI messages.

    Args:
        messages (List[ModelMessage]): List of messages from a chat conversation.

    Returns:
-       A tuple contains user prompt, system message list and history message list
-       str: user prompt
-       List[str]: system messages
-       List[List[str]]: history message of user and assistant
        tuple: A tuple containing the user prompt, the list of system messages, and the
        conversation history. The conversation history is a list of message pairs, each
        containing a user message and the corresponding AI response.

    Examples:
        .. code-block:: python

            # Example 1: Single round of conversation
            messages = [
                ModelMessage(role="human", content="Hello"),
                ModelMessage(role="ai", content="Hi there!"),
                ModelMessage(role="human", content="How are you?"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "How are you?"
            # system_messages: []
            # history: [["Hello", "Hi there!"]]

            # Example 2: Conversation with system messages
            messages = [
                ModelMessage(role="system", content="System initializing..."),
                ModelMessage(role="human", content="Is it sunny today?"),
                ModelMessage(role="ai", content="Yes, it's sunny."),
                ModelMessage(role="human", content="Great!"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "Great!"
            # system_messages: ["System initializing..."]
            # history: [["Is it sunny today?", "Yes, it's sunny."]]

            # Example 3: Multiple rounds with a system message
            messages = [
                ModelMessage(role="human", content="Hi"),
                ModelMessage(role="ai", content="Hello!"),
                ModelMessage(role="system", content="Error 404"),
                ModelMessage(role="human", content="What's the error?"),
                ModelMessage(role="ai", content="Just a joke."),
                ModelMessage(role="human", content="Funny!"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "Funny!"
            # system_messages: ["Error 404"]
            # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
    """
    user_prompt = ""
    system_messages: List[str] = []
    history_messages: List[List[str]] = [[]]
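    # The hunk is truncated here. A minimal sketch of how the rest of the body
    # could proceed, consistent with the docstring above and the tests below
    # (a hedged reconstruction, not necessarily the committed code):
    for message in messages[:-1]:
        if message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.HUMAN:
            history_messages[-1].append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            history_messages[-1].append(message.content)
            history_messages.append([])  # pair complete; start the next round
    if not messages or messages[-1].role != ModelMessageRoleType.HUMAN:
        raise ValueError("The last message must be a human message")
    user_prompt = messages[-1].content
    if history_messages and not history_messages[-1]:
        history_messages.pop()  # drop the trailing empty pair left by the loop
    return user_prompt, system_messages, history_messages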

View File

@@ -324,6 +324,71 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
    assert isinstance(new_conversation.messages[1], AIMessage)


def test_parse_model_messages_no_history_messages():
    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
    ]
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    assert user_prompt == "Hello"
    assert system_messages == []
    assert history_messages == []


def test_parse_model_messages_single_round_conversation():
    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hi there!"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello again"),
    ]
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    assert user_prompt == "Hello again"
    assert system_messages == []
    assert history_messages == [["Hello", "Hi there!"]]


def test_parse_model_messages_two_round_conversation_with_system_message():
    messages = [
        ModelMessage(
            role=ModelMessageRoleType.SYSTEM, content="System initializing..."
        ),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How's the weather?"),
        ModelMessage(role=ModelMessageRoleType.AI, content="It's sunny!"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Great to hear!"),
    ]
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    assert user_prompt == "Great to hear!"
    assert system_messages == ["System initializing..."]
    assert history_messages == [["How's the weather?", "It's sunny!"]]


def test_parse_model_messages_three_round_conversation():
    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What's up?"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Not much, you?"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Same here."),
    ]
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    assert user_prompt == "Same here."
    assert system_messages == []
    assert history_messages == [["Hi", "Hello!"], ["What's up?", "Not much, you?"]]


def test_parse_model_messages_multiple_system_messages():
    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System start"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hey"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System check"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
    ]
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    assert user_prompt == "How are you?"
    assert system_messages == ["System start", "System check"]
    assert history_messages == [["Hey", "Hello!"]]


def test_to_openai_messages(
    human_model_message, ai_model_message, system_model_message
):

View File

@@ -8,6 +8,7 @@ from dbgpt.model.proxy.llms.claude import claude_generate_stream
from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
from dbgpt.model.proxy.llms.spark import spark_generate_stream
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -25,6 +26,7 @@ def proxyllm_generate_stream(
"wenxin_proxyllm": wenxin_generate_stream,
"tongyi_proxyllm": tongyi_generate_stream,
"zhipu_proxyllm": zhipu_generate_stream,
"gemini_proxyllm": gemini_generate_stream,
"bc_proxyllm": baichuan_generate_stream,
"spark_proxyllm": spark_generate_stream,
}
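This map is what lets `proxyllm_generate_stream` route a request to the backend-specific generator. A minimal sketch of the dispatch, assuming the map is bound to a local name such as `generate_stream_map` (the lookup itself falls outside this hunk):

# Hypothetical wiring; only the map contents come from the diff above.
generate_stream_function = generate_stream_map["gemini_proxyllm"]
yield from generate_stream_function(model, tokenizer, params, device, context_len)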

View File

@@ -0,0 +1,109 @@
from typing import List, Tuple, Dict, Any

from dbgpt.model.proxy.llms.proxy_model import ProxyModel
from dbgpt.core.interface.message import ModelMessage, parse_model_messages

GEMINI_DEFAULT_MODEL = "gemini-pro"


def gemini_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Google Gemini, see: https://makersuite.google.com/app/apikey"""
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    # TODO: should proxy models use a unified config?
    proxy_api_key = model_params.proxy_api_key
    # Fall back to the default model only when no backend is configured.
    proxyllm_backend = model_params.proxyllm_backend or GEMINI_DEFAULT_MODEL

    generation_config = {
        "temperature": 0.7,
        "top_p": 1,
        "top_k": 1,
        "max_output_tokens": 2048,
    }
    safety_settings = [
        {
            "category": "HARM_CATEGORY_HARASSMENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
        },
    ]

    import google.generativeai as genai

    if model_params.proxy_api_base:
        from google.api_core import client_options

        # Route REST traffic through the configured endpoint instead of the default.
        client_opts = client_options.ClientOptions(
            api_endpoint=model_params.proxy_api_base
        )
        genai.configure(
            api_key=proxy_api_key, transport="rest", client_options=client_opts
        )
    else:
        genai.configure(api_key=proxy_api_key)

    model = genai.GenerativeModel(
        model_name=proxyllm_backend,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    messages: List[ModelMessage] = params["messages"]
    user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
    chat = model.start_chat(history=gemini_hist)
    response = chat.send_message(user_prompt, stream=True)
    text = ""
    for chunk in response:
        text += chunk.text
        print(text)
        # Yield the accumulated text, matching the other proxy backends' streaming convention.
        yield text
def _transform_to_gemini_messages(
    messages: List[ModelMessage],
) -> Tuple[str, List[Dict[str, Any]]]:
    """Transform messages to the Gemini format.

    See https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb

    Args:
        messages (List[ModelMessage]): messages

    Returns:
        Tuple[str, List[Dict[str, Any]]]: user_prompt, gemini_hist

    Examples:
        .. code-block:: python

            messages = [
                ModelMessage(role="human", content="Hello"),
                ModelMessage(role="ai", content="Hi there!"),
                ModelMessage(role="human", content="How are you?"),
            ]
            user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
            assert user_prompt == "How are you?"
            assert gemini_hist == [
                {"role": "user", "parts": {"text": "Hello"}},
                {"role": "model", "parts": {"text": "Hi there!"}},
            ]
    """
    user_prompt, system_messages, history_messages = parse_model_messages(messages)
    if system_messages:
        # Gemini has no system role; fold system text into the user prompt.
        user_prompt = "".join(system_messages) + "\n" + user_prompt
    gemini_hist = []
    if history_messages:
        for user_message, model_message in history_messages:
            gemini_hist.append({"role": "user", "parts": {"text": user_message}})
            gemini_hist.append({"role": "model", "parts": {"text": model_message}})
    return user_prompt, gemini_hist
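For illustration, a hedged sketch of the transform on a conversation that includes a system message, using `ModelMessageRoleType` for the roles as the tests above do:

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are terse."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
]
user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
# user_prompt: "You are terse.\nHow are you?"  (system text folded into the prompt)
# gemini_hist: [{"role": "user", "parts": {"text": "Hi"}},
#               {"role": "model", "parts": {"text": "Hello!"}}]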

View File

@@ -6,7 +6,7 @@ from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
CHATGLM_DEFAULT_MODEL = "chatglm_pro"


-def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+def __convert_2_zhipu_messages(messages: List[ModelMessage]):
    chat_round = 0
    wenxin_messages = []

@@ -57,38 +57,7 @@ def zhipu_generate_stream(
    zhipuai.api_key = proxy_api_key
    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
-    # system = ""
-    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-    #     role_define = messages.pop(0)
-    #     system = role_define.content
-    # else:
-    #     message = messages.pop(0)
-    #     if message.role == ModelMessageRoleType.HUMAN:
-    #         history.append({"role": "user", "content": message.content})
-    # for message in messages:
-    #     if message.role == ModelMessageRoleType.SYSTEM:
-    #         history.append({"role": "user", "content": message.content})
-    #     # elif message.role == ModelMessageRoleType.HUMAN:
-    #     #     history.append({"role": "user", "content": message.content})
-    #     elif message.role == ModelMessageRoleType.AI:
-    #         history.append({"role": "assistant", "content": message.content})
-    #     else:
-    #         pass
-    #
-    # # temp_his = history[::-1]
-    # temp_his = history
-    # last_user_input = None
-    # for m in temp_his:
-    #     if m["role"] == "user":
-    #         last_user_input = m
-    #         break
-    #
-    # if last_user_input:
-    #     history.remove(last_user_input)
-    #     history.append(last_user_input)
-    history, systems = __convert_2_wenxin_messages(messages)
+    history, systems = __convert_2_zhipu_messages(messages)
    res = zhipuai.model_api.sse_invoke(
        model=proxyllm_backend,
        prompt=history,