mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-08 04:23:35 +00:00)
feat: add spark proxy api
@@ -3,6 +3,7 @@
 import dashscope
 import requests
+import hashlib
 from http import HTTPStatus
 from dashscope import Generation


@@ -44,6 +45,17 @@ def build_access_token(api_key: str, secret_key: str) -> str:
     if res.status_code == 200:
         return res.json().get("access_token")


+def _calculate_md5(text: str) -> str:
+    md5 = hashlib.md5()
+    md5.update(text.encode("utf-8"))
+    encrypted = md5.hexdigest()
+    return encrypted
+
+
+def baichuan_call():
+    url = "https://api.baichuan-ai.com/v1/stream/chat"
+
+
 if __name__ == '__main__':
     call_with_messages()
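Only the tail of build_access_token is visible in this hunk. For context, a hedged sketch of the full flow it implies; the Baidu OAuth endpoint below comes from Baidu's public docs (this is the wenxin token pattern the repo already supports), not from this diff:

    import requests

    def build_access_token_sketch(api_key: str, secret_key: str) -> str:
        # Baidu-style client-credentials exchange (endpoint assumed from Baidu docs).
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": api_key,
            "client_secret": secret_key,
        }
        res = requests.get(url, params=params)
        if res.status_code == 200:
            return res.json().get("access_token")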
@@ -60,6 +60,8 @@ LLM_MODEL_CONFIG = {
     "wenxin_proxyllm": "wenxin_proxyllm",
     "tongyi_proxyllm": "tongyi_proxyllm",
     "zhipu_proxyllm": "zhipu_proxyllm",
+    "bc_proxyllm": "bc_proxyllm",
+    "spark_proxyllm": "spark_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
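For reference, a sketch of the environment these two new entries expect. The variable names are taken from the new baichuan.py and spark.py modules below; the LLM_MODEL selector is an assumption about how DB-GPT picks the config entry:

    # Sketch only: environment assumed by the new proxy modules.
    import os

    os.environ.setdefault("LLM_MODEL", "spark_proxyllm")  # or "bc_proxyllm" (assumed selector)

    # Baichuan proxy (read in pilot/model/proxy/llms/baichuan.py)
    os.environ.setdefault("BAICHUAN_MODEL_NAME", "Baichuan2-53B")
    os.environ.setdefault("BAICHUAN_PROXY_API_KEY", "<your-api-key>")
    os.environ.setdefault("BAICHUAN_PROXY_API_SECRET", "<your-api-secret>")

    # iFLYTEK Spark proxy (read in pilot/model/proxy/llms/spark.py)
    os.environ.setdefault("XUNFEI_SPARK_API_VERSION", "v2")
    os.environ.setdefault("XUNFEI_SPARK_APPID", "<your-app-id>")
    os.environ.setdefault("XUNFEI_SPARK_API_KEY", "<your-api-key>")
    os.environ.setdefault("XUNFEI_SPARK_API_SECRET", "<your-api-secret>")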
@@ -8,6 +8,8 @@ from pilot.model.proxy.llms.claude import claude_generate_stream
 from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
 from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
 from pilot.model.proxy.llms.zhipu import zhipu_generate_stream
+from pilot.model.proxy.llms.baichuan import baichuan_generate_stream
+from pilot.model.proxy.llms.spark import spark_generate_stream
 from pilot.model.proxy.llms.proxy_model import ProxyModel


@@ -23,6 +25,8 @@ def proxyllm_generate_stream(
         "wenxin_proxyllm": wenxin_generate_stream,
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
+        "bc_proxyllm": baichuan_generate_stream,
+        "spark_proxyllm": spark_generate_stream,
     }
     model_params = model.get_params()
     model_name = model_params.model_name
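The body of proxyllm_generate_stream around this mapping is not part of the hunk; a minimal sketch of the lookup-and-delegate shape the context lines suggest (the error handling here is an assumption, not the repo's code):

    # Sketch only: resolve the generate function by model name and stream from it.
    def dispatch_sketch(model, tokenizer, params, device, context_len=4096):
        generate_stream_map = {
            "bc_proxyllm": baichuan_generate_stream,
            "spark_proxyllm": spark_generate_stream,
        }
        model_name = model.get_params().model_name
        fn = generate_stream_map.get(model_name)
        if fn is None:
            raise ValueError(f"Unsupported proxy model: {model_name}")
        yield from fn(model, tokenizer, params, device, context_len)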
pilot/model/proxy/llms/baichuan.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import os
+import hashlib
+import json
+import time
+import requests
+from typing import List
+from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
+
+
+def _calculate_md5(text: str) -> str:
+    """Calculate the MD5 hex digest of a string."""
+    md5 = hashlib.md5()
+    md5.update(text.encode("utf-8"))
+    encrypted = md5.hexdigest()
+    return encrypted
+
+
+def _sign(data: dict, secret_key: str, timestamp: str):
+    # The signature covers the exact JSON body that will be sent.
+    data_str = json.dumps(data)
+    signature = _calculate_md5(secret_key + data_str + timestamp)
+    return signature
+
+
+def baichuan_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=4096
+):
+    model_params = model.get_params()
+    url = "https://api.baichuan-ai.com/v1/stream/chat"
+
+    model_name = os.getenv("BAICHUAN_MODEL_NAME") or BAICHUAN_DEFAULT_MODEL
+    proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
+    proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
+
+    history = []
+    messages: List[ModelMessage] = params["messages"]
+    # Add history conversation
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    payload = {
+        "model": model_name,
+        "messages": history,
+        "parameters": {
+            "temperature": params.get("temperature"),
+            "top_k": params.get("top_k", 10),
+        },
+    }
+
+    timestamp = int(time.time())
+    _signature = _sign(payload, proxy_api_secret, str(timestamp))
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": "Bearer " + proxy_api_key,
+        "X-BC-Request-Id": params.get("request_id") or "dbgpt",
+        "X-BC-Timestamp": str(timestamp),
+        "X-BC-Signature": _signature,
+        "X-BC-Sign-Algo": "MD5",
+    }
+
+    res = requests.post(url=url, json=payload, headers=headers, stream=True)
+    print(f"Send request to {url} with real model {model_name}")
+
+    text = ""
+    for line in res.iter_lines():
+        if line:
+            if not line.startswith(b"data: "):
+                # Non-SSE lines are surfaced as error messages.
+                error_message = line.decode("utf-8")
+                yield error_message
+            else:
+                json_data = line.split(b": ", 1)[1]
+                decoded_line = json_data.decode("utf-8")
+                if decoded_line.lower() != "[done]":
+                    obj = json.loads(json_data)
+                    if obj["data"]["messages"][0].get("content") is not None:
+                        content = obj["data"]["messages"][0].get("content")
+                        text += content
+                        yield text
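As a standalone check of the X-BC header scheme above, a small sketch that reproduces _sign outside of DB-GPT (the secret and payload values are placeholders):

    import hashlib
    import json
    import time

    def sign(payload: dict, secret_key: str, timestamp: str) -> str:
        # X-BC-Signature = md5(secret_key + json_body + timestamp); the JSON
        # string must match the request body sent by requests byte-for-byte.
        body = json.dumps(payload)
        return hashlib.md5((secret_key + body + timestamp).encode("utf-8")).hexdigest()

    ts = str(int(time.time()))
    payload = {"model": "Baichuan2-53B", "messages": [{"role": "user", "content": "hi"}]}
    print(sign(payload, "my-secret", ts))  # value for the X-BC-Signature header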
pilot/model/proxy/llms/spark.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+import os
+import json
+import base64
+import hmac
+import hashlib
+import websockets
+import asyncio
+from datetime import datetime
+from typing import List
+from time import mktime
+from urllib.parse import urlencode
+from urllib.parse import urlparse
+from wsgiref.handlers import format_date_time
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+from pilot.model.proxy.llms.proxy_model import ProxyModel
+
+SPARK_DEFAULT_API_VERSION = "v2"
+
+
+def spark_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    model_params = model.get_params()
+    proxy_api_version = (
+        os.getenv("XUNFEI_SPARK_API_VERSION") or SPARK_DEFAULT_API_VERSION
+    )
+    proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
+    proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
+    proxy_app_id = os.getenv("XUNFEI_SPARK_APPID")
+
+    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
+        url = "ws://spark-api.xf-yun.com/v2.1/chat"
+        domain = "generalv2"
+    else:
+        domain = "general"
+        url = "ws://spark-api.xf-yun.com/v1.1/chat"
+
+    messages: List[ModelMessage] = params["messages"]
+
+    history = []
+    # Add history conversation
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
+    request_url = spark_api.gen_url()
+
+    # Walk the history backwards to find the most recent user message.
+    temp_his = history[::-1]
+    last_user_input = None
+    for m in temp_his:
+        if m["role"] == "user":
+            last_user_input = m
+            break
+
+    data = {
+        "header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
+        "parameter": {
+            "chat": {
+                "domain": domain,
+                "random_threshold": 0.5,
+                "max_tokens": context_len,
+                "auditing": "default",
+                "temperature": params.get("temperature"),
+            }
+        },
+        "payload": {
+            # The Spark API expects a list of {role, content} dicts here.
+            "message": {"text": [last_user_input]}
+        },
+    }
+
+    # TODO: consume async_call below and yield the streamed text.
+
+
+async def async_call(request_url, data):
+    async with websockets.connect(request_url) as ws:
+        await ws.send(json.dumps(data, ensure_ascii=False))
+        finish = False
+        while not finish:
+            chunk = await ws.recv()
+            response = json.loads(chunk)
+            # status == 2 marks the final frame of the stream.
+            if response.get("header", {}).get("status") == 2:
+                finish = True
+            if text := response.get("payload", {}).get("choices", {}).get("text"):
+                yield text[0]["content"]
+
+
+class SparkAPI:
+    def __init__(
+        self, appid: str, api_key: str, api_secret: str, spark_url: str
+    ) -> None:
+        self.appid = appid
+        self.api_key = api_key
+        self.api_secret = api_secret
+        self.host = urlparse(spark_url).netloc
+        self.path = urlparse(spark_url).path
+        self.spark_url = spark_url
+
+    def gen_url(self):
+        now = datetime.now()
+        date = format_date_time(mktime(now.timetuple()))
+
+        # Origin string signed with HMAC-SHA256: "host date request-line".
+        _signature = "host: " + self.host + "\n"
+        _signature += "date: " + date + "\n"
+        _signature += "GET " + self.path + " HTTP/1.1"
+
+        _signature_sha = hmac.new(
+            self.api_secret.encode("utf-8"),
+            _signature.encode("utf-8"),
+            digestmod=hashlib.sha256,
+        ).digest()
+        _signature_sha_base64 = base64.b64encode(_signature_sha).decode("utf-8")
+
+        _authorization = (
+            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
+            f'headers="host date request-line", signature="{_signature_sha_base64}"'
+        )
+        authorization = base64.b64encode(_authorization.encode("utf-8")).decode("utf-8")
+
+        v = {"authorization": authorization, "date": date, "host": self.host}
+        url = self.spark_url + "?" + urlencode(v)
+        return url
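The commit leaves a TODO where spark_generate_stream would consume the websocket; a minimal sketch of driving async_call directly (credentials and the asyncio wiring are assumptions, not part of the commit):

    import asyncio

    async def _demo():
        api = SparkAPI(
            "my-app-id", "my-api-key", "my-api-secret",
            "ws://spark-api.xf-yun.com/v2.1/chat",
        )
        data = {
            "header": {"app_id": "my-app-id", "uid": "demo"},
            "parameter": {"chat": {"domain": "generalv2", "max_tokens": 2048}},
            "payload": {"message": {"text": [{"role": "user", "content": "Hello"}]}},
        }
        # async_call yields each content chunk as the server streams it.
        async for content in async_call(api.gen_url(), data):
            print(content, end="", flush=True)

    # asyncio.run(_demo())  # requires valid iFLYTEK credentials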