mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 22:15:35 +00:00
[summary]adapt for xunfei Spark LLM (#846)
Co-authored-by: Edward Shine <ahxiao@iflytek.com>
This commit is contained in:
parent
0b02451fb3
commit
962bd9a48c
@ -423,6 +423,13 @@ class ProxyModelParameters(BaseModelParameters):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
proxy_api_secret: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The app secret for current proxy LLM(Just for spark proxy LLM now)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
proxy_api_type: Optional[str] = field(
|
proxy_api_type: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
import hmac
|
import hmac
|
||||||
import hashlib
|
import hashlib
|
||||||
import websockets
|
from websockets.sync.client import connect
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List
|
||||||
from time import mktime
|
from time import mktime
|
||||||
@ -13,7 +12,22 @@ from wsgiref.handlers import format_date_time
|
|||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||||
|
|
||||||
SPARK_DEFAULT_API_VERSION = "v2"
|
SPARK_DEFAULT_API_VERSION = "v3"
|
||||||
|
|
||||||
|
|
||||||
|
def getlength(text):
|
||||||
|
length = 0
|
||||||
|
for content in text:
|
||||||
|
temp = content["content"]
|
||||||
|
leng = len(temp)
|
||||||
|
length += leng
|
||||||
|
return length
|
||||||
|
|
||||||
|
|
||||||
|
def checklen(text):
|
||||||
|
while getlength(text) > 8192:
|
||||||
|
del text[0]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def spark_generate_stream(
|
def spark_generate_stream(
|
||||||
@ -23,41 +37,41 @@ def spark_generate_stream(
|
|||||||
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
|
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
|
||||||
proxy_api_key = model_params.proxy_api_key
|
proxy_api_key = model_params.proxy_api_key
|
||||||
proxy_api_secret = model_params.proxy_api_secret
|
proxy_api_secret = model_params.proxy_api_secret
|
||||||
proxy_app_id = model_params.proxy_app_id
|
proxy_app_id = model_params.proxy_api_app_id
|
||||||
|
|
||||||
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
|
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
|
||||||
|
url = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||||
|
domain = "generalv3"
|
||||||
|
else:
|
||||||
url = "ws://spark-api.xf-yun.com/v2.1/chat"
|
url = "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||||
domain = "generalv2"
|
domain = "generalv2"
|
||||||
else:
|
|
||||||
domain = "general"
|
|
||||||
url = "ws://spark-api.xf-yun.com/v1.1/chat"
|
|
||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
|
||||||
|
last_user_input = None
|
||||||
|
for index in range(len(messages) - 1, -1, -1):
|
||||||
|
print(f"index: {index}")
|
||||||
|
if messages[index].role == ModelMessageRoleType.HUMAN:
|
||||||
|
last_user_input = {"role": "user", "content": messages[index].content}
|
||||||
|
del messages[index]
|
||||||
|
break
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
# Add history conversation
|
# Add history conversation
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == ModelMessageRoleType.HUMAN:
|
# There is no role for system in spark LLM
|
||||||
|
if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM:
|
||||||
history.append({"role": "user", "content": message.content})
|
history.append({"role": "user", "content": message.content})
|
||||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
|
||||||
history.append({"role": "system", "content": message.content})
|
|
||||||
elif message.role == ModelMessageRoleType.AI:
|
elif message.role == ModelMessageRoleType.AI:
|
||||||
history.append({"role": "assistant", "content": message.content})
|
history.append({"role": "assistant", "content": message.content})
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
|
question = checklen(history + [last_user_input])
|
||||||
request_url = spark_api.gen_url()
|
|
||||||
|
|
||||||
temp_his = history[::-1]
|
|
||||||
last_user_input = None
|
|
||||||
for m in temp_his:
|
|
||||||
if m["role"] == "user":
|
|
||||||
last_user_input = m
|
|
||||||
break
|
|
||||||
|
|
||||||
|
print('last_user_input.get("content")', last_user_input.get("content"))
|
||||||
data = {
|
data = {
|
||||||
"header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
|
"header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
|
||||||
"parameter": {
|
"parameter": {
|
||||||
"chat": {
|
"chat": {
|
||||||
"domain": domain,
|
"domain": domain,
|
||||||
@ -67,23 +81,31 @@ def spark_generate_stream(
|
|||||||
"temperature": params.get("temperature"),
|
"temperature": params.get("temperature"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"payload": {"message": {"text": last_user_input.get("content")}},
|
"payload": {"message": {"text": question}},
|
||||||
}
|
}
|
||||||
|
|
||||||
async_call(request_url, data)
|
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
|
||||||
|
request_url = spark_api.gen_url()
|
||||||
|
return get_response(request_url, data)
|
||||||
|
|
||||||
|
|
||||||
async def async_call(request_url, data):
|
def get_response(request_url, data):
|
||||||
async with websockets.connect(request_url) as ws:
|
with connect(request_url) as ws:
|
||||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
ws.send(json.dumps(data, ensure_ascii=False))
|
||||||
finish = False
|
result = ""
|
||||||
while not finish:
|
while True:
|
||||||
|
try:
|
||||||
chunk = ws.recv()
|
chunk = ws.recv()
|
||||||
response = json.loads(chunk)
|
response = json.loads(chunk)
|
||||||
if response.get("header", {}).get("status") == 2:
|
print("look out the response: ", response)
|
||||||
finish = True
|
choices = response.get("payload", {}).get("choices", {})
|
||||||
if text := response.get("payload", {}).get("choices", {}).get("text"):
|
if text := choices.get("text"):
|
||||||
yield text[0]["content"]
|
result += text[0]["content"]
|
||||||
|
if choices.get("status") == 2:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
class SparkAPI:
|
class SparkAPI:
|
||||||
@ -99,29 +121,33 @@ class SparkAPI:
|
|||||||
self.spark_url = spark_url
|
self.spark_url = spark_url
|
||||||
|
|
||||||
def gen_url(self):
|
def gen_url(self):
|
||||||
|
# 生成RFC1123格式的时间戳
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
date = format_date_time(mktime(now.timetuple()))
|
date = format_date_time(mktime(now.timetuple()))
|
||||||
|
|
||||||
_signature = "host: " + self.host + "\n"
|
# 拼接字符串
|
||||||
_signature += "data: " + date + "\n"
|
signature_origin = "host: " + self.host + "\n"
|
||||||
_signature += "GET " + self.path + " HTTP/1.1"
|
signature_origin += "date: " + date + "\n"
|
||||||
|
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||||
|
|
||||||
_signature_sha = hmac.new(
|
# 进行hmac-sha256进行加密
|
||||||
|
signature_sha = hmac.new(
|
||||||
self.api_secret.encode("utf-8"),
|
self.api_secret.encode("utf-8"),
|
||||||
_signature.encode("utf-8"),
|
signature_origin.encode("utf-8"),
|
||||||
digestmod=hashlib.sha256,
|
digestmod=hashlib.sha256,
|
||||||
).digest()
|
).digest()
|
||||||
|
|
||||||
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(
|
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||||
encoding="utf-8"
|
|
||||||
)
|
|
||||||
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
|
|
||||||
|
|
||||||
authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
|
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
||||||
|
|
||||||
|
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
|
||||||
encoding="utf-8"
|
encoding="utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 将请求的鉴权参数组合为字典
|
||||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||||
|
# 拼接鉴权参数,生成url
|
||||||
url = self.spark_url + "?" + urlencode(v)
|
url = self.spark_url + "?" + urlencode(v)
|
||||||
|
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||||
return url
|
return url
|
||||||
|
Loading…
Reference in New Issue
Block a user