feat(awel): New MessageConverter and more AWEL operators (#1039)

Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

View File

@@ -13,6 +13,8 @@ def bard_generate_stream(
     proxy_api_key = model_params.proxy_api_key
     proxy_server_url = model_params.proxy_server_url
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+
     history = []
     messages: List[ModelMessage] = params["messages"]
     for message in messages:
@@ -25,14 +27,15 @@ def bard_generate_stream(
         else:
            pass
-    last_user_input_index = None
-    for i in range(len(history) - 1, -1, -1):
-        if history[i]["role"] == "user":
-            last_user_input_index = i
-            break
-    if last_user_input_index:
-        last_user_input = history.pop(last_user_input_index)
-        history.append(last_user_input)
+    if convert_to_compatible_format:
+        last_user_input_index = None
+        for i in range(len(history) - 1, -1, -1):
+            if history[i]["role"] == "user":
+                last_user_input_index = i
+                break
+        if last_user_input_index:
+            last_user_input = history.pop(last_user_input_index)
+            history.append(last_user_input)
     msgs = []
     for msg in history:
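
The compatibility branch above moves the most recent user turn to the end of the history. A minimal standalone sketch of that reordering, with illustrative names that are not part of the commit:

    from typing import Dict, List, Optional

    def move_last_user_message_to_end(history: List[Dict[str, str]]) -> None:
        """Reorder ``history`` in place so the newest "user" entry comes last."""
        last_user_index: Optional[int] = None
        for i in range(len(history) - 1, -1, -1):
            if history[i]["role"] == "user":
                last_user_index = i
                break
        # Compare against None so a user message at index 0 is also moved.
        if last_user_index is not None:
            history.append(history.pop(last_user_index))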

View File

@@ -128,7 +128,10 @@ def _build_request(model: ProxyModel, params):
     messages: List[ModelMessage] = params["messages"]
     # history = __convert_2_gpt_messages(messages)
-    history = ModelMessage.to_openai_messages(messages)
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+    history = ModelMessage.to_openai_messages(
+        messages, convert_to_compatible_format=convert_to_compatible_format
+    )
     payloads = {
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),

View File

@@ -12,7 +12,6 @@ def gemini_generate_stream(
     """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
     model_params = model.get_params()
     print(f"Model: {model}, model_params: {model_params}")
-    global history
     # TODO proxy model use unified config?
     proxy_api_key = model_params.proxy_api_key
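
Dropping ``global history`` removes module-level mutable state from the stream function. A sketch of the request-scoped pattern it moves toward; the helper name is illustrative:

    def build_history(params: dict) -> list:
        # Request-scoped list: concurrent calls can no longer share
        # (and corrupt) a module-level ``history``.
        history = []
        for message in params["messages"]:
            history.append({"role": message.role, "content": message.content})
        return history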

View File

@@ -56,6 +56,9 @@ def spark_generate_stream(
             del messages[index]
             break
+    # TODO: Support convert_to_compatible_format config
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+
     history = []
     # Add history conversation
     for message in messages:
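
The flag is read from ``params`` with a ``False`` default, so existing callers keep the old behavior without changes. An illustrative call site, assuming the usual (model, tokenizer, params, device) signature of these stream functions:

    params = {
        "messages": messages,
        "temperature": 0.7,
        # Opt in to the legacy message layout; omit for the new default.
        "convert_to_compatible_format": True,
    }
    for chunk in spark_generate_stream(model, tokenizer, params, device):
        print(chunk)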

View File

@@ -53,8 +53,12 @@ def tongyi_generate_stream(
         proxyllm_backend = Generation.Models.qwen_turbo  # By Default qwen_turbo
     messages: List[ModelMessage] = params["messages"]
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
-    history = __convert_2_tongyi_messages(messages)
+    if convert_to_compatible_format:
+        history = __convert_2_tongyi_messages(messages)
+    else:
+        history = ModelMessage.to_openai_messages(messages)
     gen = Generation()
     res = gen.call(
         proxyllm_backend,
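
The same pick-a-converter pattern, condensed into a single helper for clarity; the helper itself is illustrative, while both converters appear in the diff:

    def build_tongyi_history(
        messages: List[ModelMessage], convert_to_compatible_format: bool
    ):
        if convert_to_compatible_format:
            # Legacy, Tongyi-specific layout.
            return __convert_2_tongyi_messages(messages)
        # New default: plain OpenAI-style role/content mapping.
        return ModelMessage.to_openai_messages(messages)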

View File

@@ -25,8 +25,29 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
     return res.json().get("access_token")


+def _to_wenxin_messages(messages: List[ModelMessage]):
+    """Convert messages to wenxin compatible format
+
+    See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
+    """
+    wenxin_messages = []
+    system_messages = []
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            wenxin_messages.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            system_messages.append(message.content)
+        elif message.role == ModelMessageRoleType.AI:
+            wenxin_messages.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+    if len(system_messages) > 1:
+        raise ValueError("Wenxin only support one system message")
+    str_system_message = system_messages[0] if len(system_messages) > 0 else ""
+    return wenxin_messages, str_system_message
+
+
 def __convert_2_wenxin_messages(messages: List[ModelMessage]):
     chat_round = 0
     wenxin_messages = []
     last_usr_message = ""
@@ -57,7 +78,8 @@ def __convert_2_wenxin_messages(messages: List[ModelMessage]):
         last_message = messages[-1]
         end_message = last_message.content
     wenxin_messages.append({"role": "user", "content": end_message})
-    return wenxin_messages, system_messages
+    str_system_message = system_messages[0] if len(system_messages) > 0 else ""
+    return wenxin_messages, str_system_message
def wenxin_generate_stream(
@@ -87,13 +109,14 @@ def wenxin_generate_stream(
     messages: List[ModelMessage] = params["messages"]
-    history, systems = __convert_2_wenxin_messages(messages)
-    system = ""
-    if systems and len(systems) > 0:
-        system = systems[0]
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+    if convert_to_compatible_format:
+        history, system_message = __convert_2_wenxin_messages(messages)
+    else:
+        history, system_message = _to_wenxin_messages(messages)
     payload = {
         "messages": history,
-        "system": system,
+        "system": system_message,
         "temperature": params.get("temperature"),
         "stream": True,
     }
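
Net effect for Wenxin: the new ``_to_wenxin_messages`` maps turns directly and returns at most one system message separately for the payload's ``system`` field. An illustrative round trip based on the function shown above:

    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be concise."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hi there."),
    ]
    history, system_message = _to_wenxin_messages(messages)
    # history == [{"role": "user", "content": "Hello"},
    #             {"role": "assistant", "content": "Hi there."}]
    # system_message == "Be concise."
    # A second system message would raise ValueError.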

View File

@@ -57,6 +57,10 @@ def zhipu_generate_stream(
     zhipuai.api_key = proxy_api_key
     messages: List[ModelMessage] = params["messages"]
+
+    # TODO: Support convert_to_compatible_format config, zhipu not support system message
+    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
+
     history, systems = __convert_2_zhipu_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
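
The TODO records that Zhipu has no separate system role. A common workaround, sketched here rather than taken from the commit, folds system text into the first user turn:

    def fold_system_into_first_user_turn(messages: List[ModelMessage]):
        history, system_parts = [], []
        for m in messages:
            if m.role == ModelMessageRoleType.SYSTEM:
                system_parts.append(m.content)
            elif m.role == ModelMessageRoleType.HUMAN:
                history.append({"role": "user", "content": m.content})
            elif m.role == ModelMessageRoleType.AI:
                history.append({"role": "assistant", "content": m.content})
        # Prepend collected system text to the first turn, if any exists.
        if system_parts and history:
            history[0]["content"] = "\n".join(system_parts) + "\n" + history[0]["content"]
        return history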