From 0126e9ef2fe112ff31ccbd90a284d7be0ec0678c Mon Sep 17 00:00:00 2001 From: csunny Date: Wed, 11 Oct 2023 18:27:26 +0800 Subject: [PATCH] feat: add spark proxy api --- examples/{tongyi.py => proxy_example.py} | 14 ++- pilot/configs/model_config.py | 2 + pilot/model/llm_out/proxy_llm.py | 4 + pilot/model/proxy/llms/baichuan.py | 86 +++++++++++++++ pilot/model/proxy/llms/spark.py | 134 +++++++++++++++++++++++ 5 files changed, 239 insertions(+), 1 deletion(-) rename examples/{tongyi.py => proxy_example.py} (85%) create mode 100644 pilot/model/proxy/llms/baichuan.py create mode 100644 pilot/model/proxy/llms/spark.py diff --git a/examples/tongyi.py b/examples/proxy_example.py similarity index 85% rename from examples/tongyi.py rename to examples/proxy_example.py index 9240b430a..5d2f8e5db 100644 --- a/examples/tongyi.py +++ b/examples/proxy_example.py @@ -3,6 +3,7 @@ import dashscope import requests +import hashlib from http import HTTPStatus from dashscope import Generation @@ -44,6 +45,17 @@ def build_access_token(api_key: str, secret_key: str) -> str: if res.status_code == 200: return res.json().get("access_token") - + +def _calculate_md5(text: str) -> str: + + md5 = hashlib.md5() + md5.update(text.encode("utf-8")) + encrypted = md5.hexdigest() + return encrypted + +def baichuan_call(): + url = "https://api.baichuan-ai.com/v1/stream/chat" + + if __name__ == '__main__': call_with_messages() \ No newline at end of file diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 927272cb1..dba647ccd 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -60,6 +60,8 @@ LLM_MODEL_CONFIG = { "wenxin_proxyllm": "wenxin_proxyllm", "tongyi_proxyllm": "tongyi_proxyllm", "zhipu_proxyllm": "zhipu_proxyllm", + "bc_proxyllm": "bc_proxyllm", + "spark_proxyllm": "spark_proxyllm", "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"), "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"), "llama-2-70b": 
import os
import hashlib
import json
import time
import requests
from typing import List
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"


def _calculate_md5(text: str) -> str:
    """Return the hex MD5 digest of *text* (UTF-8 encoded)."""
    md5 = hashlib.md5()
    md5.update(text.encode("utf-8"))
    return md5.hexdigest()


def _sign(data: dict, secret_key: str, timestamp: str) -> str:
    """Baichuan request signature: md5(secret_key + request_body + timestamp)."""
    data_str = json.dumps(data)
    return _calculate_md5(secret_key + data_str + timestamp)


def baichuan_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=4096
):
    """Stream chat completions from the Baichuan proxy API.

    Reads credentials from the BAICHUAN_* environment variables, sends the
    conversation in ``params["messages"]`` to the streaming endpoint, and
    yields the accumulated response text after each received SSE chunk.
    """
    url = "https://api.baichuan-ai.com/v1/stream/chat"

    # FIX: env var name was misspelled "BAICHUN_MODEL_NAME" — now consistent
    # with BAICHUAN_PROXY_API_KEY / BAICHUAN_PROXY_API_SECRET below.
    model_name = os.getenv("BAICHUAN_MODEL_NAME") or BAICHUAN_DEFAULT_MODEL
    proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
    proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")

    # Map the internal message roles onto the Baichuan wire format.
    history = []
    messages: List[ModelMessage] = params["messages"]
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            # FIX: was the literal string "message.content", which sent the
            # placeholder text instead of the assistant's actual reply.
            history.append({"role": "assistant", "content": message.content})

    payload = {
        "model": model_name,
        "messages": history,
        "parameters": {
            "temperature": params.get("temperature"),
            "top_k": params.get("top_k", 10),
        },
    }

    timestamp = int(time.time())
    signature = _sign(payload, proxy_api_secret, str(timestamp))

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + proxy_api_key,
        "X-BC-Request-Id": params.get("request_id") or "dbgpt",
        "X-BC-Timestamp": str(timestamp),
        "X-BC-Signature": signature,
        "X-BC-Sign-Algo": "MD5",
    }

    res = requests.post(url=url, json=payload, headers=headers, stream=True)
    print(f"Send request to {url} with real model {model_name}")

    text = ""
    for line in res.iter_lines():
        if not line:
            continue
        if not line.startswith(b"data: "):
            # Non-SSE line from the server: surface it as an error message.
            yield line.decode("utf-8")
            continue
        json_data = line.split(b": ", 1)[1]
        decoded_line = json_data.decode("utf-8")
        if decoded_line.lower() == "[done]":
            # End-of-stream sentinel, nothing to parse.
            continue
        obj = json.loads(json_data)
        content = obj["data"]["messages"][0].get("content")
        if content is not None:
            text += content
            yield text
import os
import json
import base64
import hmac
import hashlib
import websockets
import asyncio
from datetime import datetime
from typing import List
from time import mktime
from urllib.parse import urlencode
from urllib.parse import urlparse
from wsgiref.handlers import format_date_time
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.proxy.llms.proxy_model import ProxyModel

SPARK_DEFAULT_API_VERSION = "v2"


def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Stream chat completions from the iFlytek Spark websocket API.

    Reads credentials from the XUNFEI_SPARK_* environment variables, sends
    the latest user message over an HMAC-signed websocket URL, and yields
    the accumulated response text after each received frame.
    """
    model_params = model.get_params()
    proxy_api_version = (
        os.getenv("XUNFEI_SPARK_API_VERSION") or SPARK_DEFAULT_API_VERSION
    )
    proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
    proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
    proxy_app_id = os.getenv("XUNFEI_SPARK_APPID")

    # v2 and v1 of the API live on different endpoints and "domain" values.
    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
        url = "ws://spark-api.xf-yun.com/v2.1/chat"
        domain = "generalv2"
    else:
        url = "ws://spark-api.xf-yun.com/v1.1/chat"
        domain = "general"

    messages: List[ModelMessage] = params["messages"]
    history = []
    # Map the internal message roles onto the Spark wire format.
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})

    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
    request_url = spark_api.gen_url()

    # Only the most recent user turn is sent; find it from the end.
    last_user_input = None
    for m in reversed(history):
        if m["role"] == "user":
            last_user_input = m
            break

    data = {
        "header": {
            "app_id": proxy_app_id,
            "uid": params.get("request_id", 1),
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": context_len,
                "auditing": "default",
                "temperature": params.get("temperature"),
            }
        },
        "payload": {
            "message": {
                # FIX: was last_user_input.get("") — an empty key lookup that
                # always returned None. Spark expects a list of
                # {role, content} dicts here.
                "text": [last_user_input] if last_user_input else [],
            }
        },
    }

    # FIX: the original ended in "# TODO" and never yielded anything.
    # Drive the async websocket exchange from this synchronous generator.
    loop = asyncio.new_event_loop()
    try:
        agen = async_call(request_url, data)
        text = ""
        while True:
            try:
                chunk = loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
            text += chunk
            yield text
    finally:
        loop.close()


async def async_call(request_url, data):
    """Send *data* over the Spark websocket and yield content chunks."""
    async with websockets.connect(request_url) as ws:
        await ws.send(json.dumps(data, ensure_ascii=False))
        finish = False
        while not finish:
            # FIX: recv() is a coroutine and must be awaited; the original
            # passed the coroutine object straight into json.loads.
            chunk = await ws.recv()
            response = json.loads(chunk)
            # header.status == 2 marks the final frame of the answer.
            if response.get("header", {}).get("status") == 2:
                finish = True
            if text := response.get("payload", {}).get("choices", {}).get("text"):
                yield text[0]["content"]


class SparkAPI:
    """Builds the HMAC-SHA256-signed websocket URL required by Spark."""

    def __init__(
        self, appid: str, api_key: str, api_secret: str, spark_url: str
    ) -> None:
        self.appid = appid
        self.api_key = api_key
        self.api_secret = api_secret
        self.host = urlparse(spark_url).netloc
        self.path = urlparse(spark_url).path
        self.spark_url = spark_url

    def gen_url(self):
        """Return the websocket URL with authorization query parameters."""
        # RFC 1123 date string, required by the signature scheme.
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # FIX: the second line of the signature origin must be "date: ",
        # not "data: " — a wrong origin makes the server reject the
        # handshake with an authentication error.
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        signature_sha = hmac.new(
            self.api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature_b64 = base64.b64encode(signature_sha).decode(encoding="utf-8")

        # FIX: the authorization fields must be wrapped in double quotes
        # (api_key="...", ...); the single-quoted form is rejected.
        authorization_origin = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_b64}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode("utf-8")
        ).decode(encoding="utf-8")

        query = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        return self.spark_url + "?" + urlencode(query)