feat: +zhipu and spark support

csunny
2023-10-11 23:32:59 +08:00
parent 0126e9ef2f
commit 990c2aa939
2 changed files with 38 additions and 11 deletions

pilot/model/proxy/llms/spark.py

@@ -4,7 +4,6 @@ import base64
 import hmac
 import hashlib
 import websockets
-import asyncio
 from datetime import datetime
 from typing import List
 from time import mktime
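The imports this hunk keeps (hmac, hashlib, base64, datetime, mktime) are the building blocks of iFlytek's signed-URL handshake: Spark's websocket endpoint expects an HMAC-SHA256 signature over the host, an RFC 1123 date, and the request line, base64-encoded into the query string. A minimal sketch of how a request_url is typically derived under that scheme; the helper name and its parameters are assumptions, not code from this commit:

import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time


def build_spark_url(host: str, path: str, api_key: str, api_secret: str) -> str:
    # RFC 1123 timestamp, required by the signature scheme.
    date = format_date_time(mktime(datetime.now().timetuple()))
    # Sign "host + date + request line" with the account's API secret.
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature = base64.b64encode(
        hmac.new(
            api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
    ).decode("utf-8")
    authorization = base64.b64encode(
        (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        ).encode("utf-8")
    ).decode("utf-8")
    # The signed fields travel as query parameters on the wss:// URL.
    return f"wss://{host}{path}?" + urlencode(
        {"authorization": authorization, "date": date, "host": host}
    )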
@@ -73,13 +72,12 @@ def spark_generate_stream(
         },
         "payload": {
             "message": {
-                "text": last_user_input.get("")
+                "text": last_user_input.get("content")
             }
         }
     }
-    # TODO
     async_call(request_url, data)


 async def async_call(request_url, data):
     async with websockets.connect(request_url) as ws:
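The hunk is cut off just inside async_call, so the send/receive loop itself is not visible here. A sketch of how that body plausibly continues, assuming Spark's documented frame shape (chunks arrive under payload.choices.text and header.status == 2 marks the final frame); this is a guess at the shape, not the committed code:

import asyncio
import json

import websockets


async def async_call(request_url, data):
    async with websockets.connect(request_url) as ws:
        # Ship the request payload built in spark_generate_stream.
        await ws.send(json.dumps(data, ensure_ascii=False))
        answer = ""
        while True:
            frame = json.loads(await ws.recv())
            for item in frame.get("payload", {}).get("choices", {}).get("text", []):
                answer += item.get("content", "")
            if frame.get("header", {}).get("status") == 2:
                return answer  # status 2 is the stream's terminal frame

Worth noting: spark_generate_stream invokes async_call(request_url, data) without awaiting it, which only creates a coroutine object; with import asyncio apparently dropped in the first hunk, nothing drives the connection, so a caller would still need something like asyncio.run(async_call(request_url, data)).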

pilot/model/proxy/llms/zhipu.py

@@ -1,5 +1,11 @@
-from pilot.model.proxy.llms.proxy_model import ProxyModel
+import os
+import json
+from typing import List
+from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
 def zhipu_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
@@ -9,10 +15,33 @@ def zhipu_generate_stream(
     print(f"Model: {model}, model_params: {model_params}")
     proxy_api_key = model_params.proxy_api_key
     proxy_server_url = model_params.proxy_server_url
-    proxyllm_backend = model_params.proxyllm_backend
+    proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = "chatglm_pro"
-    # TODO
-    yield "Zhipu LLM was not supported!"
+
+    import zhipuai
+
+    zhipuai.api_key = proxy_api_key
+    history = []
+    messages: List[ModelMessage] = params["messages"]
+    # Add history conversation
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    res = zhipuai.model_api.sse_invoke(
+        model=proxyllm_backend,
+        prompt=history,
+        temperature=params.get("temperature"),
+        top_p=params.get("top_p"),
+        incremental=False,
+    )
+    for r in res.events():
+        if r.event == "add":
+            yield r.data
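Two details stand out in the new zhipu path. First, because CHATGLM_DEFAULT_MODEL is a non-empty string, CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend always evaluates to "chatglm_pro", so a configured proxyllm_backend can never take effect; the usual fallback idiom reads the other way around:

# Presumably intended: prefer the configured backend, fall back to the default.
proxyllm_backend = model_params.proxyllm_backend or CHATGLM_DEFAULT_MODEL

Second, only "add" events are consumed, so failures end the stream silently. A hedged sketch of fuller handling for the same loop; the "error" and "interrupted" event names are assumed from the zhipuai 1.x SDK, not shown in this diff:

for r in res.events():
    if r.event == "add":
        yield r.data
    elif r.event in ("error", "interrupted"):
        # Surface failures instead of dropping them.
        yield f"zhipu request failed: {r.data}"
        break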