feat: call xunfei spark with stream, and fix the temperature bug (#2121)

Co-authored-by: aries_ckt <916701291@qq.com>
HIYIZI
2024-11-19 23:30:02 +08:00
committed by GitHub
parent 4efe643db8
commit 3ccfa94219
9 changed files with 183 additions and 250 deletions


@@ -1,21 +1,13 @@
import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from typing import AsyncIterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v3"
def getlength(text):
length = 0
@@ -49,7 +41,7 @@ def spark_generate_stream(
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
for r in client.generate_stream(request):
yield r
@@ -74,120 +66,57 @@ def get_response(request_url, data):
yield result
class SparkAPI:
def __init__(
self, appid: str, api_key: str, api_secret: str, spark_url: str
) -> None:
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
def extract_content(line: str):
if not line.strip():
return line
if line.startswith("data: "):
json_str = line[len("data: ") :]
else:
raise ValueError("Error line content ")
self.spark_url = spark_url
try:
data = json.loads(json_str)
if data == "[DONE]":
return ""
def gen_url(self):
from wsgiref.handlers import format_date_time
# Generate an RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# Build the string to be signed
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# Sign it with HMAC-SHA256
signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
signature_origin.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8"
)
# Assemble the request authentication parameters into a dict
v = {"authorization": authorization, "date": date, "host": self.host}
# Append the authentication parameters to build the final URL
url = self.spark_url + "?" + urlencode(v)
# The URL used to establish the connection is built here; when following this demo, uncomment the print above to compare the URL generated with the same parameters against the one produced by your own code
return url
choices = data.get("choices", [])
if choices and isinstance(choices, list):
delta = choices[0].get("delta", {})
content = delta.get("content", "")
return content
else:
raise ValueError("Error line content ")
except json.JSONDecodeError:
return ""
class SparkLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
"""
Tips: The Spark LLM API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K and 4.0 Ultra; each version meters tokens independently.
The transport protocol is ws(s); wss is strongly recommended for better security.
For Spark 4.0 Ultra the domain parameter of the request URL is 4.0Ultra
wss://spark-api.xf-yun.com/v4.0/chat
The Spark LLM API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K and 4.0 Ultra
For Spark 4.0 Ultra the domain parameter of the request URL is 4.0Ultra
For Spark Max-32K the domain parameter of the request URL is max-32k
wss://spark-api.xf-yun.com/chat/max-32k
For Spark Max the domain parameter of the request URL is generalv3.5
wss://spark-api.xf-yun.com/v3.5/chat
For Spark Pro-128K the domain parameter of the request URL is pro-128k
wss://spark-api.xf-yun.com/chat/pro-128k
For Spark Pro the domain parameter of the request URL is generalv3
wss://spark-api.xf-yun.com/v3.1/chat
For Spark Lite the domain parameter of the request URL is lite
wss://spark-api.xf-yun.com/v1.1/chat
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
elif model_version == "v4.0":
api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
domain = "4.0Ultra"
elif model_version == "v3.5":
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
domain = "generalv3.5"
else:
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "lite"
if not api_domain:
api_domain = domain
self._model = model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
self._api_base = os.getenv("PROXY_SERVER_URL")
self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
if not self._model:
raise ValueError("model can't be empty")
if not self._api_base:
raise ValueError("api_base can't be empty")
if not self._api_password:
raise ValueError("api_password can't be empty")
super().__init__(
model_names=[model, model_alias],
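
Not part of the diff: a minimal sketch of how the reworked constructor might be configured. The environment variable names (XUNFEI_SPARK_API_MODEL, PROXY_SERVER_URL, XUNFEI_SPARK_API_PASSWORD) are the ones read by the new code above; the endpoint URL and model name below are assumptions and must match your own Xunfei Spark console settings.

import os

# Illustrative values only (assumptions, not from the commit).
os.environ["XUNFEI_SPARK_API_MODEL"] = "generalv3.5"
os.environ["PROXY_SERVER_URL"] = "https://spark-api-open.xf-yun.com/v1/chat/completions"
os.environ["XUNFEI_SPARK_API_PASSWORD"] = "<your-api-password>"

# With the environment prepared, the client only needs the model name.
client = SparkLLMClient(model=os.environ["XUNFEI_SPARK_API_MODEL"])
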
@@ -203,10 +132,6 @@ class SparkLLMClient(ProxyLLMClient):
) -> "SparkLLMClient":
return cls(
model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
@@ -216,35 +141,45 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str:
return self._model
def sync_generate_stream(
def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
}
spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
try:
for text in get_response(request_url, data):
yield ModelOutput(text=text, error_code=0)
import requests
except ImportError as e:
raise ValueError(
"Could not import python package: requests "
"Please install requests by command `pip install requests"
) from e
data = {
"model": self._model, # 指定请求的模型
"messages": messages,
"temperature": request.temperature,
"stream": True,
}
header = {
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
}
response = requests.post(self._api_base, headers=header, json=data, stream=True)
# Example of parsing the streaming response
response.encoding = "utf-8"
try:
content = ""
# data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
# data: [DONE]
for line in response.iter_lines(decode_unicode=True):
print("llm out:", line)
content = content + extract_content(line)
yield ModelOutput(text=content, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",