feat: +zhipu and spark support

This commit is contained in:
csunny
2023-10-11 23:32:59 +08:00
parent 0126e9ef2f
commit 990c2aa939
2 changed files with 38 additions and 11 deletions

View File

@@ -4,7 +4,6 @@ import base64
import hmac import hmac
import hashlib import hashlib
import websockets import websockets
import asyncio
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from time import mktime from time import mktime
@@ -73,13 +72,12 @@ def spark_generate_stream(
}, },
"payload": { "payload": {
"message": { "message": {
"text": last_user_input.get("") "text": last_user_input.get("content")
} }
} }
} }
# TODO async_call(request_url, data)
async def async_call(request_url, data): async def async_call(request_url, data):
async with websockets.connect(request_url) as ws: async with websockets.connect(request_url) as ws:

View File

@@ -1,5 +1,11 @@
from pilot.model.proxy.llms.proxy_model import ProxyModel import os
import json
from typing import List
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
def zhipu_generate_stream( def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048 model: ProxyModel, tokenizer, params, device, context_len=2048
@@ -9,10 +15,33 @@ def zhipu_generate_stream(
print(f"Model: {model}, model_params: {model_params}") print(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
proxyllm_backend = model_params.proxyllm_backend
if not proxyllm_backend: import zhipuai
proxyllm_backend = "chatglm_pro" zhipuai.api_key = proxy_api_key
# TODO
yield "Zhipu LLM was not supported!" history = []
messages: List[ModelMessage] = params["messages"]
# Add history conversation
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
res = zhipuai.model_api.sse_invoke(
model=proxyllm_backend,
prompt=history,
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
)
for r in res.events():
if r.event == "add":
yield r.data