mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
@@ -13,6 +13,8 @@ def bard_generate_stream(
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
proxy_server_url = model_params.proxy_server_url
|
||||
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
|
||||
history = []
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
for message in messages:
|
||||
@@ -25,14 +27,15 @@ def bard_generate_stream(
|
||||
else:
|
||||
pass
|
||||
|
||||
last_user_input_index = None
|
||||
for i in range(len(history) - 1, -1, -1):
|
||||
if history[i]["role"] == "user":
|
||||
last_user_input_index = i
|
||||
break
|
||||
if last_user_input_index:
|
||||
last_user_input = history.pop(last_user_input_index)
|
||||
history.append(last_user_input)
|
||||
if convert_to_compatible_format:
|
||||
last_user_input_index = None
|
||||
for i in range(len(history) - 1, -1, -1):
|
||||
if history[i]["role"] == "user":
|
||||
last_user_input_index = i
|
||||
break
|
||||
if last_user_input_index:
|
||||
last_user_input = history.pop(last_user_input_index)
|
||||
history.append(last_user_input)
|
||||
|
||||
msgs = []
|
||||
for msg in history:
|
||||
|
@@ -128,7 +128,10 @@ def _build_request(model: ProxyModel, params):
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
|
||||
# history = __convert_2_gpt_messages(messages)
|
||||
history = ModelMessage.to_openai_messages(messages)
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
history = ModelMessage.to_openai_messages(
|
||||
messages, convert_to_compatible_format=convert_to_compatible_format
|
||||
)
|
||||
payloads = {
|
||||
"temperature": params.get("temperature"),
|
||||
"max_tokens": params.get("max_new_tokens"),
|
||||
|
@@ -12,7 +12,6 @@ def gemini_generate_stream(
|
||||
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
|
||||
model_params = model.get_params()
|
||||
print(f"Model: {model}, model_params: {model_params}")
|
||||
global history
|
||||
|
||||
# TODO proxy model use unified config?
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
|
@@ -56,6 +56,9 @@ def spark_generate_stream(
|
||||
del messages[index]
|
||||
break
|
||||
|
||||
# TODO: Support convert_to_compatible_format config
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
|
||||
history = []
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
|
@@ -53,8 +53,12 @@ def tongyi_generate_stream(
|
||||
proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
|
||||
history = __convert_2_tongyi_messages(messages)
|
||||
if convert_to_compatible_format:
|
||||
history = __convert_2_tongyi_messages(messages)
|
||||
else:
|
||||
history = ModelMessage.to_openai_messages(messages)
|
||||
gen = Generation()
|
||||
res = gen.call(
|
||||
proxyllm_backend,
|
||||
|
@@ -25,8 +25,29 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
|
||||
return res.json().get("access_token")
|
||||
|
||||
|
||||
def _to_wenxin_messages(messages: List[ModelMessage]):
|
||||
"""Convert messages to wenxin compatible format
|
||||
|
||||
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
|
||||
"""
|
||||
wenxin_messages = []
|
||||
system_messages = []
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
wenxin_messages.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
system_messages.append(message.content)
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
wenxin_messages.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
if len(system_messages) > 1:
|
||||
raise ValueError("Wenxin only support one system message")
|
||||
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
|
||||
return wenxin_messages, str_system_message
|
||||
|
||||
|
||||
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
||||
chat_round = 0
|
||||
wenxin_messages = []
|
||||
|
||||
last_usr_message = ""
|
||||
@@ -57,7 +78,8 @@ def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
||||
last_message = messages[-1]
|
||||
end_message = last_message.content
|
||||
wenxin_messages.append({"role": "user", "content": end_message})
|
||||
return wenxin_messages, system_messages
|
||||
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
|
||||
return wenxin_messages, str_system_message
|
||||
|
||||
|
||||
def wenxin_generate_stream(
|
||||
@@ -87,13 +109,14 @@ def wenxin_generate_stream(
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
|
||||
history, systems = __convert_2_wenxin_messages(messages)
|
||||
system = ""
|
||||
if systems and len(systems) > 0:
|
||||
system = systems[0]
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
if convert_to_compatible_format:
|
||||
history, system_message = __convert_2_wenxin_messages(messages)
|
||||
else:
|
||||
history, system_message = _to_wenxin_messages(messages)
|
||||
payload = {
|
||||
"messages": history,
|
||||
"system": system,
|
||||
"system": system_message,
|
||||
"temperature": params.get("temperature"),
|
||||
"stream": True,
|
||||
}
|
||||
|
@@ -57,6 +57,10 @@ def zhipu_generate_stream(
|
||||
zhipuai.api_key = proxy_api_key
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
|
||||
# TODO: Support convert_to_compatible_format config, zhipu not support system message
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||
|
||||
history, systems = __convert_2_zhipu_messages(messages)
|
||||
res = zhipuai.model_api.sse_invoke(
|
||||
model=proxyllm_backend,
|
||||
|
Reference in New Issue
Block a user