[summary]adapt for xunfei Spark LLM (#846)

Co-authored-by: Edward Shine <ahxiao@iflytek.com>
This commit is contained in:
edward 2023-11-28 14:56:13 +08:00 committed by GitHub
parent 0b02451fb3
commit 962bd9a48c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 44 deletions

View File

@ -423,6 +423,13 @@ class ProxyModelParameters(BaseModelParameters):
}, },
) )
proxy_api_secret: Optional[str] = field(
default=None,
metadata={
"help": "The app secret for current proxy LLM(Just for spark proxy LLM now)."
},
)
proxy_api_type: Optional[str] = field( proxy_api_type: Optional[str] = field(
default=None, default=None,
metadata={ metadata={

View File

@ -1,9 +1,8 @@
import os
import json import json
import base64 import base64
import hmac import hmac
import hashlib import hashlib
import websockets from websockets.sync.client import connect
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from time import mktime from time import mktime
@ -13,7 +12,22 @@ from wsgiref.handlers import format_date_time
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.proxy.llms.proxy_model import ProxyModel from pilot.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v2" SPARK_DEFAULT_API_VERSION = "v3"
def getlength(text):
    """Return the combined length of every message's content.

    Args:
        text: Iterable of message dicts; each dict must have a ``"content"``
            entry that supports ``len()``.

    Returns:
        The sum of ``len(message["content"])`` over all messages, or ``0``
        when ``text`` is empty.

    Raises:
        KeyError: If a message has no ``"content"`` key.
    """
    return sum(len(message["content"]) for message in text)
def checklen(text, max_length=8192):
    """Drop the oldest messages until the conversation fits ``max_length``.

    The list is trimmed in place from the front (oldest first) and the same
    list object is returned, so callers holding a reference to ``text`` see
    the trimmed result.

    The final message is never removed, even when it alone exceeds
    ``max_length``: it is the most recent entry, and dropping it would
    leave an empty list to send.

    Args:
        text: List of message dicts, each with a ``"content"`` entry that
            supports ``len()``. Mutated in place.
        max_length: Maximum combined content length to keep. Defaults to
            8192, the value previously hard-coded here.

    Returns:
        ``text`` itself, after trimming.
    """
    # Measure once, then subtract each dropped message's length instead of
    # re-scanning the whole list after every removal.
    total = getlength(text)
    drop = 0
    while total > max_length and drop < len(text) - 1:
        total -= len(text[drop]["content"])
        drop += 1
    # One slice deletion instead of repeated O(n) ``del text[0]`` shifts.
    del text[:drop]
    return text
def spark_generate_stream( def spark_generate_stream(
@ -23,41 +37,41 @@ def spark_generate_stream(
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
proxy_api_key = model_params.proxy_api_key proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret proxy_api_secret = model_params.proxy_api_secret
proxy_app_id = model_params.proxy_app_id proxy_app_id = model_params.proxy_api_app_id
if proxy_api_version == SPARK_DEFAULT_API_VERSION: if proxy_api_version == SPARK_DEFAULT_API_VERSION:
url = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
url = "ws://spark-api.xf-yun.com/v2.1/chat" url = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2" domain = "generalv2"
else:
domain = "general"
url = "ws://spark-api.xf-yun.com/v1.1/chat"
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
last_user_input = None
for index in range(len(messages) - 1, -1, -1):
print(f"index: {index}")
if messages[index].role == ModelMessageRoleType.HUMAN:
last_user_input = {"role": "user", "content": messages[index].content}
del messages[index]
break
history = [] history = []
# Add history conversation # Add history conversation
for message in messages: for message in messages:
if message.role == ModelMessageRoleType.HUMAN: # There is no role for system in spark LLM
if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM:
history.append({"role": "user", "content": message.content}) history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI: elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content}) history.append({"role": "assistant", "content": message.content})
else: else:
pass pass
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url) question = checklen(history + [last_user_input])
request_url = spark_api.gen_url()
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
print('last_user_input.get("content")', last_user_input.get("content"))
data = { data = {
"header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)}, "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
"parameter": { "parameter": {
"chat": { "chat": {
"domain": domain, "domain": domain,
@ -67,23 +81,31 @@ def spark_generate_stream(
"temperature": params.get("temperature"), "temperature": params.get("temperature"),
} }
}, },
"payload": {"message": {"text": last_user_input.get("content")}}, "payload": {"message": {"text": question}},
} }
async_call(request_url, data) spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
return get_response(request_url, data)
async def async_call(request_url, data): def get_response(request_url, data):
async with websockets.connect(request_url) as ws: with connect(request_url) as ws:
await ws.send(json.dumps(data, ensure_ascii=False)) ws.send(json.dumps(data, ensure_ascii=False))
finish = False result = ""
while not finish: while True:
try:
chunk = ws.recv() chunk = ws.recv()
response = json.loads(chunk) response = json.loads(chunk)
if response.get("header", {}).get("status") == 2: print("look out the response: ", response)
finish = True choices = response.get("payload", {}).get("choices", {})
if text := response.get("payload", {}).get("choices", {}).get("text"): if text := choices.get("text"):
yield text[0]["content"] result += text[0]["content"]
if choices.get("status") == 2:
break
except Exception:
break
yield result
class SparkAPI: class SparkAPI:
@ -99,29 +121,33 @@ class SparkAPI:
self.spark_url = spark_url self.spark_url = spark_url
def gen_url(self): def gen_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now() now = datetime.now()
date = format_date_time(mktime(now.timetuple())) date = format_date_time(mktime(now.timetuple()))
_signature = "host: " + self.host + "\n" # 拼接字符串
_signature += "data: " + date + "\n" signature_origin = "host: " + self.host + "\n"
_signature += "GET " + self.path + " HTTP/1.1" signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
_signature_sha = hmac.new( # 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), self.api_secret.encode("utf-8"),
_signature.encode("utf-8"), signature_origin.encode("utf-8"),
digestmod=hashlib.sha256, digestmod=hashlib.sha256,
).digest() ).digest()
_signature_sha_base64 = base64.b64encode(_signature_sha).decode( signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
encoding="utf-8"
)
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
authorization = base64.b64encode(_authorization.encode("utf-8")).decode( authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8" encoding="utf-8"
) )
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host} v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v) url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url return url