mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-04 10:34:30 +00:00)
feat: +zhipu and spark support
@@ -4,7 +4,6 @@ import base64
 import hmac
 import hashlib
 import websockets
-import asyncio
 from datetime import datetime
 from typing import List
 from time import mktime
@@ -73,13 +72,12 @@ def spark_generate_stream(
         },
         "payload": {
             "message": {
-                "text": last_user_input.get("")
+                "text": last_user_input.get("content")
             }
         }
     }
 
-    # TODO
+    async_call(request_url, data)
 
-
 async def async_call(request_url, data):
     async with websockets.connect(request_url) as ws:
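
The hunk above only wires spark_generate_stream up to async_call; the websocket handling itself stays inside async_call. As context for the diff, here is a minimal sketch of what the body of that coroutine could look like. It is not part of the commit, and the response field names (header.code, header.status, payload.choices.text[].content) are assumptions based on the public iFlytek Spark websocket API.

import json
import websockets

async def async_call(request_url, data):
    # Connect to the Spark websocket endpoint and send the request payload.
    async with websockets.connect(request_url) as ws:
        await ws.send(json.dumps(data, ensure_ascii=False))
        result = ""
        async for message in ws:
            frame = json.loads(message)
            # A non-zero header.code signals an API error (assumed field name).
            if frame["header"]["code"] != 0:
                raise RuntimeError(f"Spark API error: {frame['header']}")
            # Each frame carries a list of text chunks (assumed layout).
            for choice in frame["payload"]["choices"]["text"]:
                result += choice["content"]
            # header.status == 2 marks the final frame of the stream (assumed).
            if frame["header"]["status"] == 2:
                break
        return result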
|
@@ -1,5 +1,11 @@
+import os
+import json
+from typing import List
+
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
-
+CHATGLM_DEFAULT_MODEL = "chatglm_pro"
+
 def zhipu_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
@@ -9,10 +15,33 @@ def zhipu_generate_stream(
     print(f"Model: {model}, model_params: {model_params}")
 
     proxy_api_key = model_params.proxy_api_key
-    proxy_server_url = model_params.proxy_server_url
-    proxyllm_backend = model_params.proxyllm_backend
+    proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
 
-    if not proxyllm_backend:
-        proxyllm_backend = "chatglm_pro"
-    # TODO
-    yield "Zhipu LLM was not supported!"
+    import zhipuai
+
+    zhipuai.api_key = proxy_api_key
+
+    history = []
+
+    messages: List[ModelMessage] = params["messages"]
+    # Add history conversation
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    res = zhipuai.model_api.sse_invoke(
+        model=proxyllm_backend,
+        prompt=history,
+        temperature=params.get("temperature"),
+        top_p=params.get("top_p"),
+        incremental=False,
+    )
+    for r in res.events():
+        if r.event == "add":
+            yield r.data
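
The sse_invoke call introduced above can also be exercised outside DB-GPT, which is a quick way to sanity-check the proxy key and backend name. Below is a minimal sketch using the same zhipuai SDK surface as the diff (model_api.sse_invoke and res.events()); the ZHIPU_PROXY_API_KEY environment variable and the hard-coded sampling values are illustrative assumptions, not part of the commit.

import os
import zhipuai

# Same default backend as CHATGLM_DEFAULT_MODEL in the diff above.
zhipuai.api_key = os.environ["ZHIPU_PROXY_API_KEY"]  # assumed variable name

history = [
    {"role": "user", "content": "Say hello in one short sentence."},
]

res = zhipuai.model_api.sse_invoke(
    model="chatglm_pro",
    prompt=history,
    temperature=0.7,
    top_p=0.9,
    incremental=False,
)
# "add" events carry generated text, mirroring the loop in zhipu_generate_stream.
for r in res.events():
    if r.event == "add":
        print(r.data, end="")

The history entries use the same role names (user/system/assistant) that zhipu_generate_stream maps ModelMessageRoleType values onto.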