# Spark API generation selected when the model parameters do not name a backend.
SPARK_DEFAULT_API_VERSION = "v3"

# Default budget, in characters of message content, for one request's history.
_MAX_CONTEXT_LENGTH = 8192


def _content_length(message):
    """Return the character count of one message's ``content`` (0 if None)."""
    content = message["content"]
    return len(content) if content is not None else 0


def getlength(text):
    """Return the total number of characters across all message contents.

    Args:
        text: Sequence of message dicts, each carrying a ``"content"`` key.
            A ``None`` content counts as zero characters instead of raising.

    Returns:
        The summed length of every message's ``content``.
    """
    return sum(_content_length(message) for message in text)


def checklen(text, max_length=_MAX_CONTEXT_LENGTH):
    """Drop the oldest messages until the history fits in ``max_length``.

    The list is trimmed in place (oldest entries first) and the same list
    object is returned, so callers holding a reference see the trimmed result.

    Args:
        text: Mutable list of message dicts ordered oldest to newest.
        max_length: Maximum total content length to keep (default 8192).

    Returns:
        ``text`` itself, after trimming.
    """
    total = getlength(text)
    drop = 0
    # Walk forward once with a running total instead of re-summing the whole
    # list after every single deletion.
    while total > max_length and drop < len(text):
        total -= _content_length(text[drop])
        drop += 1
    if drop:
        del text[:drop]
    return text
params["messages"] + last_user_input = None + for index in range(len(messages) - 1, -1, -1): + print(f"index: {index}") + if messages[index].role == ModelMessageRoleType.HUMAN: + last_user_input = {"role": "user", "content": messages[index].content} + del messages[index] + break + history = [] # Add history conversation for message in messages: - if message.role == ModelMessageRoleType.HUMAN: + # There is no role for system in spark LLM + if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM: history.append({"role": "user", "content": message.content}) - elif message.role == ModelMessageRoleType.SYSTEM: - history.append({"role": "system", "content": message.content}) elif message.role == ModelMessageRoleType.AI: history.append({"role": "assistant", "content": message.content}) else: pass - spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url) - request_url = spark_api.gen_url() - - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m - break + question = checklen(history + [last_user_input]) + print('last_user_input.get("content")', last_user_input.get("content")) data = { - "header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)}, + "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))}, "parameter": { "chat": { "domain": domain, @@ -67,23 +81,31 @@ def spark_generate_stream( "temperature": params.get("temperature"), } }, - "payload": {"message": {"text": last_user_input.get("content")}}, + "payload": {"message": {"text": question}}, } - async_call(request_url, data) + spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url) + request_url = spark_api.gen_url() + return get_response(request_url, data) -async def async_call(request_url, data): - async with websockets.connect(request_url) as ws: - await ws.send(json.dumps(data, ensure_ascii=False)) - finish = False - while not finish: - chunk = ws.recv() - 
def get_response(request_url, data):
    """Send one chat request over a websocket and yield the full reply text.

    This is a generator that yields exactly once: the concatenation of every
    text fragment received before the stream ended.

    Args:
        request_url: Signed websocket URL to connect to.
        data: JSON-serialisable request body, sent as a single frame.

    Yields:
        The accumulated reply text. It may be partial or empty if the stream
        failed part-way; failures are logged rather than raised, so the
        caller always receives whatever arrived.
    """
    import logging

    logger = logging.getLogger(__name__)
    result = ""
    with connect(request_url) as ws:
        ws.send(json.dumps(data, ensure_ascii=False))
        while True:
            try:
                response = json.loads(ws.recv())
                header = response.get("header") or {}
                # NOTE(review): assumes the service reports failures through a
                # non-zero `header.code` — confirm against the Spark API docs.
                # Such frames carry no payload, so stop here instead of
                # waiting for the server to drop the socket.
                if header.get("code", 0) != 0:
                    logger.error("Spark API returned an error frame: %s", header)
                    break
                choices = (response.get("payload") or {}).get("choices") or {}
                if text := choices.get("text"):
                    result += text[0]["content"]
                # Status 2 marks the final frame of the reply.
                if choices.get("status") == 2:
                    break
            except Exception:
                # Best-effort: keep what was received, but leave a trace
                # rather than silently swallowing the failure.
                logger.exception("Spark websocket stream ended unexpectedly")
                break
    yield result
将请求的鉴权参数组合为字典 v = {"authorization": authorization, "date": date, "host": self.host} - + # 拼接鉴权参数,生成url url = self.spark_url + "?" + urlencode(v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 return url