feat: call xunfei spark with stream, and fix the temperature bug (#2121)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
HIYIZI 2024-11-19 23:30:02 +08:00 committed by GitHub
parent 4efe643db8
commit 3ccfa94219
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 183 additions and 250 deletions

View File

@@ -241,10 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
#BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}
# Xunfei Spark
#XUNFEI_SPARK_API_VERSION={version}
#XUNFEI_SPARK_APPID={your_app_id}
#XUNFEI_SPARK_API_KEY={your_api_key}
#XUNFEI_SPARK_API_SECRET={your_api_secret}
#XUNFEI_SPARK_API_PASSWORD={your_api_password}
#XUNFEI_SPARK_API_MODEL={version}
## Yi Proxyllm, https://platform.lingyiwanwu.com/docs
#YI_MODEL_VERSION=yi-34b-chat-0205

View File

@@ -1,41 +1,41 @@
# Please run command `pre-commit install` to install pre-commit hook
repos:
- repo: local
hooks:
- id: python-fmt
name: Python Format
entry: make fmt-check
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test
name: Python Unit Test
entry: make test
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test-doc
name: Python Doc Test
entry: make test-doc
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-lint-mypy
name: Python Lint mypy
entry: make mypy
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
# Please run command `pre-commit install` to install pre-commit hook
repos:
- repo: local
hooks:
- id: python-fmt
name: Python Format
entry: make fmt-check
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test
name: Python Unit Test
entry: make test
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test-doc
name: Python Doc Test
entry: make test-doc
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-lint-mypy
name: Python Lint mypy
entry: make mypy
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []

View File

@@ -78,17 +78,13 @@ class Config(metaclass=Singleton):
)
# xunfei spark
self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
if self.spark_proxy_api_key and self.spark_proxy_api_secret:
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or ""
os.environ["spark_proxyllm_proxy_api_app_id"] = (
self.spark_proxy_api_appid or ""
)
self.spark_proxy_api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
self.spark_proxy_api_model = os.getenv("XUNFEI_SPARK_API_MODEL")
if self.spark_proxy_api_model and self.spark_proxy_api_password:
os.environ[
"spark_proxyllm_proxy_api_password"
] = self.spark_proxy_api_password
os.environ["spark_proxyllm_proxy_api_model"] = self.spark_proxy_api_model
# baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")

View File

@@ -291,6 +291,7 @@ class AWELAgentOperator(
prompt_template = None
if self.awel_agent.agent_prompt:
from dbgpt.serve.prompt.api.endpoints import get_service
prompt_service = get_service()
prompt_template = prompt_service.get_template(
self.awel_agent.agent_prompt.code

View File

@@ -12,6 +12,7 @@ from dbgpt.core.awel.flow import (
ResourceCategory,
register_resource,
)
from ....resource.base import AgentResource, ResourceType
from ....resource.manage import get_resource_manager
from ....util.llm.llm import LLMConfig, LLMStrategyType
@@ -20,6 +21,7 @@ from ...agent_manage import get_agent_manager
def _agent_resource_prompt_values() -> List[OptionValue]:
from dbgpt.serve.prompt.api.endpoints import get_service
prompt_service = get_service()
prompts = prompt_service.get_target_prompt()
return [

View File

@@ -2,10 +2,10 @@ import logging
import os
import shutil
import tempfile
from typing import List
from pathlib import Path
from typing import List
from fastapi import APIRouter, Depends, File, Form, UploadFile, HTTPException
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import (
@@ -333,70 +333,72 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload")
async def document_upload(
space_name: str,
doc_name: str = Form(...),
doc_type: str = Form(...),
doc_file: UploadFile = File(...),
space_name: str,
doc_name: str = Form(...),
doc_type: str = Form(...),
doc_file: UploadFile = File(...),
):
print(f"/document/upload params: {space_name}")
try:
if doc_file:
# Sanitize inputs to prevent path traversal
safe_space_name = os.path.basename(space_name)
safe_filename = os.path.basename(doc_file.filename)
print(f"/document/upload params: {space_name}")
try:
if doc_file:
# Sanitize inputs to prevent path traversal
safe_space_name = os.path.basename(space_name)
safe_filename = os.path.basename(doc_file.filename)
# Create absolute paths and verify they are within allowed directory
upload_dir = os.path.abspath(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name))
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
# Create absolute paths and verify they are within allowed directory
upload_dir = os.path.abspath(
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name)
)
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
raise HTTPException(status_code=400, detail="Invalid path detected")
if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
raise HTTPException(status_code=400, detail="Invalid path detected")
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
# Create temp file
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
try:
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(tmp_path, target_path)
# Create temp file
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
request = KnowledgeDocumentRequest()
request.doc_name = doc_name
request.doc_type = doc_type
request.content = target_path
try:
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
space_res = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest(name=safe_space_name)
)
if len(space_res) == 0:
# create default space
if "default" != safe_space_name:
raise Exception(f"you have not create your knowledge space.")
knowledge_space_service.create_knowledge_space(
KnowledgeSpaceRequest(
name=safe_space_name,
desc="first db-gpt rag application",
owner="dbgpt",
)
)
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=safe_space_name, request=request
)
)
except Exception as e:
# Clean up temp file if anything goes wrong
if os.path.exists(tmp_path):
os.unlink(tmp_path)
raise e
shutil.move(tmp_path, target_path)
return Result.failed(code="E000X", msg=f"doc_file is None")
except Exception as e:
return Result.failed(code="E000X", msg=f"document add error {e}")
request = KnowledgeDocumentRequest()
request.doc_name = doc_name
request.doc_type = doc_type
request.content = target_path
space_res = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest(name=safe_space_name)
)
if len(space_res) == 0:
# create default space
if "default" != safe_space_name:
raise Exception(f"you have not create your knowledge space.")
knowledge_space_service.create_knowledge_space(
KnowledgeSpaceRequest(
name=safe_space_name,
desc="first db-gpt rag application",
owner="dbgpt",
)
)
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=safe_space_name, request=request
)
)
except Exception as e:
# Clean up temp file if anything goes wrong
if os.path.exists(tmp_path):
os.unlink(tmp_path)
raise e
return Result.failed(code="E000X", msg=f"doc_file is None")
except Exception as e:
return Result.failed(code="E000X", msg=f"document add error {e}")
@router.post("/knowledge/{space_name}/document/sync")

View File

@@ -232,7 +232,8 @@ class BaseChat(ABC):
)
node = AppChatComposerOperator(
model=self.llm_model,
temperature=float(self.prompt_template.temperature),
temperature=self._chat_param.get("temperature")
or float(self.prompt_template.temperature),
max_new_tokens=int(self.prompt_template.max_new_tokens),
prompt=self.prompt_template.prompt,
message_version=self._message_version,

View File

@@ -1,21 +1,13 @@
import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from typing import AsyncIterator, Optional
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v3"
def getlength(text):
length = 0
@@ -49,7 +41,7 @@ def spark_generate_stream(
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
for r in client.generate_stream(request):
yield r
@@ -74,120 +66,57 @@ def get_response(request_url, data):
yield result
class SparkAPI:
def __init__(
self, appid: str, api_key: str, api_secret: str, spark_url: str
) -> None:
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
def extract_content(line: str):
if not line.strip():
return line
if line.startswith("data: "):
json_str = line[len("data: ") :]
else:
raise ValueError("Error line content ")
self.spark_url = spark_url
try:
data = json.loads(json_str)
if data == "[DONE]":
return ""
def gen_url(self):
from wsgiref.handlers import format_date_time
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
signature_origin.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8"
)
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
choices = data.get("choices", [])
if choices and isinstance(choices, list):
delta = choices[0].get("delta", {})
content = delta.get("content", "")
return content
else:
raise ValueError("Error line content ")
except json.JSONDecodeError:
return ""
class SparkLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
"""
Tips: 星火大模型API当前有LiteProPro-128KMaxMax-32K和4.0 Ultra六个版本各版本独立计量tokens
传输协议 ws(s),为提高安全性强烈推荐wss
Spark4.0 Ultra 请求地址对应的domain参数为4.0Ultra
wss://spark-api.xf-yun.com/v4.0/chat
星火大模型API当前有LiteProPro-128KMaxMax-32K和4.0 Ultra六个版本
Spark4.0 Ultra 请求地址对应的domain参数为4.0Ultra
Spark Max-32K请求地址对应的domain参数为max-32k
wss://spark-api.xf-yun.com/chat/max-32k
Spark Max请求地址对应的domain参数为generalv3.5
wss://spark-api.xf-yun.com/v3.5/chat
Spark Pro-128K请求地址对应的domain参数为pro-128k
wss://spark-api.xf-yun.com/chat/pro-128k
Spark Pro请求地址对应的domain参数为generalv3
wss://spark-api.xf-yun.com/v3.1/chat
Spark Lite请求地址对应的domain参数为lite
wss://spark-api.xf-yun.com/v1.1/chat
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
elif model_version == "v4.0":
api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
domain = "4.0Ultra"
elif model_version == "v3.5":
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
domain = "generalv3.5"
else:
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "lite"
if not api_domain:
api_domain = domain
self._model = model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
self._api_base = os.getenv("PROXY_SERVER_URL")
self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
if not self._model:
raise ValueError("model can't be empty")
if not self._api_base:
raise ValueError("api_base can't be empty")
if not self._api_password:
raise ValueError("api_password can't be empty")
super().__init__(
model_names=[model, model_alias],
@@ -203,10 +132,6 @@ class SparkLLMClient(ProxyLLMClient):
) -> "SparkLLMClient":
return cls(
model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
@@ -216,35 +141,45 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str:
    """Return the Spark model name this client was configured with.

    NOTE(review): ``self._model`` appears to come from the ``model`` argument
    or the ``XUNFEI_SPARK_API_MODEL`` environment variable (per the ``__init__``
    hunk in this diff) — confirm against the full class definition.
    """
    return self._model
def sync_generate_stream(
def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
}
spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
try:
for text in get_response(request_url, data):
yield ModelOutput(text=text, error_code=0)
import requests
except ImportError as e:
raise ValueError(
"Could not import python package: requests "
"Please install requests by command `pip install requests"
) from e
data = {
"model": self._model, # 指定请求的模型
"messages": messages,
"temperature": request.temperature,
"stream": True,
}
header = {
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
}
response = requests.post(self._api_base, headers=header, json=data, stream=True)
# 流式响应解析示例
response.encoding = "utf-8"
try:
content = ""
# data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
# data: [DONE]
for line in response.iter_lines(decode_unicode=True):
print("llm out:", line)
content = content + extract_content(line)
yield ModelOutput(text=content, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",

View File

@@ -457,9 +457,7 @@ class MilvusStore(VectorStoreBase):
self.vector_field = x.name
# convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search(
query=text, k=topk, expr=milvus_filter_expr
)
_, docs_and_scores = self._search(query=text, k=topk, expr=milvus_filter_expr)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
logger.warning(
"similarity score need between" f" 0 and 1, got {docs_and_scores}"