[summary]adapt for xunfei Spark LLM (#846)

Co-authored-by: Edward Shine <ahxiao@iflytek.com>
This commit is contained in:
edward 2023-11-28 14:56:13 +08:00 committed by GitHub
parent 0b02451fb3
commit 962bd9a48c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 44 deletions

View File

@ -423,6 +423,13 @@ class ProxyModelParameters(BaseModelParameters):
},
)
proxy_api_secret: Optional[str] = field(
default=None,
metadata={
"help": "The app secret for current proxy LLM(Just for spark proxy LLM now)."
},
)
proxy_api_type: Optional[str] = field(
default=None,
metadata={

View File

@ -1,9 +1,8 @@
import os
import json
import base64
import hmac
import hashlib
import websockets
from websockets.sync.client import connect
from datetime import datetime
from typing import List
from time import mktime
@ -13,7 +12,22 @@ from wsgiref.handlers import format_date_time
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v2"
SPARK_DEFAULT_API_VERSION = "v3"
def getlength(text):
    """Return the total number of characters across all message contents.

    Args:
        text: A list of message dicts, each carrying a "content" string.

    Returns:
        int: The summed length of every message's "content" value.
    """
    # Idiomatic replacement for the manual accumulator loop.
    return sum(len(message["content"]) for message in text)
def checklen(text, max_length=8192):
    """Trim the oldest messages until the total content fits the limit.

    The Spark API rejects prompts beyond its context limit, so messages
    are dropped from the front (oldest first) until the combined length
    of all "content" fields is within ``max_length`` characters.

    Note: this mutates ``text`` in place and also returns it.

    Args:
        text: List of message dicts with a "content" string value.
        max_length: Maximum allowed total content length. Defaults to
            8192, the limit previously hard-coded for Spark.

    Returns:
        list: The same (possibly shortened) message list.
    """
    while getlength(text) > max_length:
        del text[0]
    return text
def spark_generate_stream(
@ -23,41 +37,41 @@ def spark_generate_stream(
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
proxy_app_id = model_params.proxy_app_id
proxy_app_id = model_params.proxy_api_app_id
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
url = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
url = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
else:
domain = "general"
url = "ws://spark-api.xf-yun.com/v1.1/chat"
messages: List[ModelMessage] = params["messages"]
last_user_input = None
for index in range(len(messages) - 1, -1, -1):
print(f"index: {index}")
if messages[index].role == ModelMessageRoleType.HUMAN:
last_user_input = {"role": "user", "content": messages[index].content}
del messages[index]
break
history = []
# Add history conversation
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
# There is no role for system in spark LLM
if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM:
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
question = checklen(history + [last_user_input])
print('last_user_input.get("content")', last_user_input.get("content"))
data = {
"header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
"header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
"parameter": {
"chat": {
"domain": domain,
@ -67,23 +81,31 @@ def spark_generate_stream(
"temperature": params.get("temperature"),
}
},
"payload": {"message": {"text": last_user_input.get("content")}},
"payload": {"message": {"text": question}},
}
async_call(request_url, data)
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
return get_response(request_url, data)
async def async_call(request_url, data):
    """Stream chat completion text fragments from the Spark websocket API.

    Opens a websocket connection to the signed ``request_url``, sends the
    request payload, and yields each text fragment as it arrives until
    the server marks the final frame (``header.status == 2``).

    Args:
        request_url: Fully signed websocket URL (see SparkAPI.gen_url).
        data: JSON-serializable request body for the chat endpoint.

    Yields:
        str: Each streamed content fragment, in arrival order.
    """
    async with websockets.connect(request_url) as ws:
        await ws.send(json.dumps(data, ensure_ascii=False))
        finish = False
        while not finish:
            # Bug fix: on the asyncio websockets client, recv() is a
            # coroutine and must be awaited; without await, json.loads
            # received a coroutine object and raised TypeError.
            chunk = await ws.recv()
            response = json.loads(chunk)
            if response.get("header", {}).get("status") == 2:
                finish = True
            if text := response.get("payload", {}).get("choices", {}).get("text"):
                yield text[0]["content"]
def get_response(request_url, data):
    """Collect the full streamed Spark response and yield it as one string.

    Uses the synchronous websocket client: sends the request payload,
    accumulates every streamed text fragment, and stops when the server
    marks the final frame (``choices.status == 2``) or on any error.

    Args:
        request_url: Fully signed websocket URL (see SparkAPI.gen_url).
        data: JSON-serializable request body for the chat endpoint.

    Yields:
        str: The concatenated response text (a single yield).
    """
    with connect(request_url) as ws:
        ws.send(json.dumps(data, ensure_ascii=False))
        # Accumulate fragments in a list and join once, instead of
        # repeated string concatenation; also drops the leftover debug
        # print of every raw frame.
        parts = []
        while True:
            try:
                chunk = ws.recv()
                response = json.loads(chunk)
                choices = response.get("payload", {}).get("choices", {})
                if text := choices.get("text"):
                    parts.append(text[0]["content"])
                # status == 2 marks the final frame of the stream.
                if choices.get("status") == 2:
                    break
            except Exception:
                # Deliberate best-effort: any decode/connection error ends
                # the stream with whatever content was received so far.
                break
        yield "".join(parts)
class SparkAPI:
@ -99,29 +121,33 @@ class SparkAPI:
self.spark_url = spark_url
def gen_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
_signature = "host: " + self.host + "\n"
_signature += "data: " + date + "\n"
_signature += "GET " + self.path + " HTTP/1.1"
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
_signature_sha = hmac.new(
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
_signature.encode("utf-8"),
signature_origin.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(
encoding="utf-8"
)
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8"
)
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url