feat: add spark proxy api

This commit is contained in:
csunny
2023-10-11 18:27:26 +08:00
parent 14cfc34c72
commit 0126e9ef2f
5 changed files with 239 additions and 1 deletion

View File

@@ -60,6 +60,8 @@ LLM_MODEL_CONFIG = {
"wenxin_proxyllm": "wenxin_proxyllm",
"tongyi_proxyllm": "tongyi_proxyllm",
"zhipu_proxyllm": "zhipu_proxyllm",
"bc_proxyllm": "bc_proxyllm",
"spark_proxyllm": "spark_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),

View File

@@ -8,6 +8,8 @@ from pilot.model.proxy.llms.claude import claude_generate_stream
from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
from pilot.model.proxy.llms.zhipu import zhipu_generate_stream
from pilot.model.proxy.llms.baichuan import baichuan_generate_stream
from pilot.model.proxy.llms.spark import spark_generate_stream
from pilot.model.proxy.llms.proxy_model import ProxyModel
@@ -23,6 +25,8 @@ def proxyllm_generate_stream(
"wenxin_proxyllm": wenxin_generate_stream,
"tongyi_proxyllm": tongyi_generate_stream,
"zhipu_proxyllm": zhipu_generate_stream,
"bc_proxyllm": baichuan_generate_stream,
"spark_proxyllm": spark_generate_stream,
}
model_params = model.get_params()
model_name = model_params.model_name

View File

@@ -0,0 +1,86 @@
import os
import hashlib
import json
import time
import requests
from typing import List
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
def _calculate_md5(text: str) -> str:
"""Calculate md5 """
md5 = hashlib.md5()
md5.update(text.encode("utf-8"))
encrypted = md5.hexdigest()
return encrypted
def _sign(data: dict, secret_key: str, timestamp: str):
data_str = json.dumps(data)
signature = _calculate_md5(secret_key + data_str + timestamp)
return signature
def baichuan_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    """Stream chat completions from the Baichuan proxy API.

    Sends the conversation in ``params["messages"]`` to the Baichuan
    streaming endpoint and yields the accumulated response text after
    each received chunk. Non-SSE lines (typically error bodies) are
    yielded verbatim.

    Args:
        model: ProxyModel carrying the proxy configuration.
        tokenizer: unused; kept for interface compatibility.
        params: request parameters; must contain ``"messages"``, may
            contain ``"temperature"``, ``"top_k"`` and ``"request_id"``.
        device: unused; kept for interface compatibility.
        context_len: unused; kept for interface compatibility.

    Yields:
        str: the response text accumulated so far, growing chunk by chunk.
    """
    url = "https://api.baichuan-ai.com/v1/stream/chat"

    model_name = os.getenv("BAICHUN_MODEL_NAME") or BAICHUAN_DEFAULT_MODEL
    proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
    proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")

    history = []
    messages: List[ModelMessage] = params["messages"]
    # Map internal message roles onto the Baichuan chat roles.
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            # BUG FIX: the original appended the literal string
            # "message.content" instead of the message's actual text.
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    payload = {
        "model": model_name,
        "messages": history,
        "parameters": {
            "temperature": params.get("temperature"),
            "top_k": params.get("top_k", 10),
        },
    }

    # The signature covers the secret key, the JSON body and the timestamp.
    timestamp = int(time.time())
    _signature = _sign(payload, proxy_api_secret, str(timestamp))

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + proxy_api_key,
        "X-BC-Request-Id": params.get("request_id") or "dbgpt",
        "X-BC-Timestamp": str(timestamp),
        "X-BC-Signature": _signature,
        "X-BC-Sign-Algo": "MD5",
    }

    res = requests.post(url=url, json=payload, headers=headers, stream=True)
    print(f"Send request to {url} with real model {model_name}")

    text = ""
    for line in res.iter_lines():
        if line:
            if not line.startswith(b"data: "):
                # Not an SSE data frame: surface it as-is (usually an error).
                error_message = line.decode("utf-8")
                yield error_message
            else:
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[done]":
                    obj = json.loads(json_data)
                    if obj["data"]["messages"][0].get("content") is not None:
                        content = obj["data"]["messages"][0].get("content")
                        text += content
                        yield text

View File

@@ -0,0 +1,134 @@
import os
import json
import base64
import hmac
import hashlib
import websockets
import asyncio
from datetime import datetime
from typing import List
from time import mktime
from urllib.parse import urlencode
from urllib.parse import urlparse
from wsgiref.handlers import format_date_time
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v2"
def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Generate chat output via the iFlytek Spark websocket API.

    Builds an authenticated websocket URL, sends the latest user message,
    and yields each text chunk returned by the server.

    Args:
        model: ProxyModel carrying the proxy configuration.
        tokenizer: unused; kept for interface compatibility.
        params: request parameters; must contain ``"messages"``, may
            contain ``"temperature"`` and ``"request_id"``.
        device: unused; kept for interface compatibility.
        context_len: sent to the API as ``max_tokens``.

    Yields:
        str: each content chunk received from the Spark service.
    """
    proxy_api_version = (
        os.getenv("XUNFEI_SPARK_API_VERSION") or SPARK_DEFAULT_API_VERSION
    )
    proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
    proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
    proxy_app_id = os.getenv("XUNFEI_SPARK_APPID")

    # v2 uses the generalv2 domain/endpoint, everything else falls back to v1.
    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
        url = "ws://spark-api.xf-yun.com/v2.1/chat"
        domain = "generalv2"
    else:
        domain = "general"
        url = "ws://spark-api.xf-yun.com/v1.1/chat"

    messages: List[ModelMessage] = params["messages"]
    history = []
    # Map internal message roles onto the Spark chat roles.
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
    request_url = spark_api.gen_url()

    # The request is driven by the most recent user utterance.
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break

    data = {
        "header": {
            "app_id": proxy_app_id,
            "uid": params.get("request_id", 1),
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": context_len,
                "auditing": "default",
                "temperature": params.get("temperature"),
            }
        },
        "payload": {
            "message": {
                # BUG FIX: the original sent last_user_input.get(""), which
                # is always None. The Spark API expects a list of
                # {"role", "content"} messages in "text".
                "text": [last_user_input] if last_user_input else [],
            }
        },
    }

    async def _collect(request_url, data):
        """Drive the websocket and gather every text chunk sent back."""
        chunks = []
        async with websockets.connect(request_url) as ws:
            await ws.send(json.dumps(data, ensure_ascii=False))
            finish = False
            while not finish:
                # BUG FIX: ws.recv() is a coroutine and must be awaited.
                raw = await ws.recv()
                response = json.loads(raw)
                # status == 2 marks the final frame of the stream.
                if response.get("header", {}).get("status") == 2:
                    finish = True
                if text := response.get("payload", {}).get("choices", {}).get("text"):
                    chunks.append(text[0]["content"])
        return chunks

    # BUG FIX: the original defined async_call but never executed it, so the
    # generator produced nothing. Run the coroutine on a private event loop.
    # NOTE(review): chunks are yielded only after the stream completes; true
    # incremental streaming would need an async generator upstream — confirm.
    loop = asyncio.new_event_loop()
    try:
        for content in loop.run_until_complete(_collect(request_url, data)):
            yield content
    finally:
        loop.close()
class SparkAPI:
    """Builds authenticated websocket URLs for the iFlytek Spark API.

    Spark authenticates requests with an HMAC-SHA256 signature over the
    host, the RFC 1123 date and the request line, carried as base64-encoded
    query parameters on the websocket URL.
    """

    def __init__(self, appid: str, api_key: str, api_secret: str, spark_url: str) -> None:
        # Credentials issued by the iFlytek open platform.
        self.appid = appid
        self.api_key = api_key
        self.api_secret = api_secret
        # Host and path are signed separately, so pre-split the URL.
        self.host = urlparse(spark_url).netloc
        self.path = urlparse(spark_url).path
        self.spark_url = spark_url

    def gen_url(self):
        """Return the websocket URL with authorization query parameters."""
        # RFC 1123 date; must appear verbatim in both signature and URL.
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # BUG FIX: the signature origin must read "date:", not "data:" —
        # it mirrors the signed headers "host date request-line".
        _signature = "host: " + self.host + "\n"
        _signature += "date: " + date + "\n"
        _signature += "GET " + self.path + " HTTP/1.1"

        _signature_sha = hmac.new(
            self.api_secret.encode("utf-8"),
            _signature.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        _signature_sha_base64 = base64.b64encode(_signature_sha).decode(
            encoding="utf-8"
        )

        # BUG FIX: authorization fields must be double-quoted per the Spark
        # API spec; single-quoted fields are rejected by the server.
        _authorization = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{_signature_sha_base64}"'
        )
        authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
            encoding="utf-8"
        )

        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        return self.spark_url + "?" + urlencode(v)