feat: call xunfei spark with stream, and fix the temperature bug (#2121)

Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
HIYIZI 2024-11-19 23:30:02 +08:00 committed by GitHub
parent 4efe643db8
commit 3ccfa94219
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 183 additions and 250 deletions

View File

@ -241,10 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
#BAICHUAN_PROXY_API_SECRET={your-baichuan-sct} #BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}
# Xunfei Spark # Xunfei Spark
#XUNFEI_SPARK_API_VERSION={version} #XUNFEI_SPARK_API_PASSWORD={your_api_password}
#XUNFEI_SPARK_APPID={your_app_id} #XUNFEI_SPARK_API_MODEL={version}
#XUNFEI_SPARK_API_KEY={your_api_key}
#XUNFEI_SPARK_API_SECRET={your_api_secret}
## Yi Proxyllm, https://platform.lingyiwanwu.com/docs ## Yi Proxyllm, https://platform.lingyiwanwu.com/docs
#YI_MODEL_VERSION=yi-34b-chat-0205 #YI_MODEL_VERSION=yi-34b-chat-0205

View File

@ -1,41 +1,41 @@
# Please run command `pre-commit install` to install pre-commit hook # Please run command `pre-commit install` to install pre-commit hook
repos: repos:
- repo: local - repo: local
hooks: hooks:
- id: python-fmt - id: python-fmt
name: Python Format name: Python Format
entry: make fmt-check entry: make fmt-check
language: system language: system
exclude: '^dbgpt/app/static/|^web/' exclude: '^dbgpt/app/static/|^web/'
types: [python] types: [python]
stages: [commit] stages: [commit]
pass_filenames: false pass_filenames: false
args: [] args: []
- id: python-test - id: python-test
name: Python Unit Test name: Python Unit Test
entry: make test entry: make test
language: system language: system
exclude: '^dbgpt/app/static/|^web/' exclude: '^dbgpt/app/static/|^web/'
types: [python] types: [python]
stages: [commit] stages: [commit]
pass_filenames: false pass_filenames: false
args: [] args: []
- id: python-test-doc - id: python-test-doc
name: Python Doc Test name: Python Doc Test
entry: make test-doc entry: make test-doc
language: system language: system
exclude: '^dbgpt/app/static/|^web/' exclude: '^dbgpt/app/static/|^web/'
types: [python] types: [python]
stages: [commit] stages: [commit]
pass_filenames: false pass_filenames: false
args: [] args: []
- id: python-lint-mypy - id: python-lint-mypy
name: Python Lint mypy name: Python Lint mypy
entry: make mypy entry: make mypy
language: system language: system
exclude: '^dbgpt/app/static/|^web/' exclude: '^dbgpt/app/static/|^web/'
types: [python] types: [python]
stages: [commit] stages: [commit]
pass_filenames: false pass_filenames: false
args: [] args: []

View File

@ -78,17 +78,13 @@ class Config(metaclass=Singleton):
) )
# xunfei spark # xunfei spark
self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION") self.spark_proxy_api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY") self.spark_proxy_api_model = os.getenv("XUNFEI_SPARK_API_MODEL")
self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET") if self.spark_proxy_api_model and self.spark_proxy_api_password:
self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID") os.environ[
if self.spark_proxy_api_key and self.spark_proxy_api_secret: "spark_proxyllm_proxy_api_password"
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key ] = self.spark_proxy_api_password
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret os.environ["spark_proxyllm_proxy_api_model"] = self.spark_proxy_api_model
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or ""
os.environ["spark_proxyllm_proxy_api_app_id"] = (
self.spark_proxy_api_appid or ""
)
# baichuan proxy # baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY") self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")

View File

@ -291,6 +291,7 @@ class AWELAgentOperator(
prompt_template = None prompt_template = None
if self.awel_agent.agent_prompt: if self.awel_agent.agent_prompt:
from dbgpt.serve.prompt.api.endpoints import get_service from dbgpt.serve.prompt.api.endpoints import get_service
prompt_service = get_service() prompt_service = get_service()
prompt_template = prompt_service.get_template( prompt_template = prompt_service.get_template(
self.awel_agent.agent_prompt.code self.awel_agent.agent_prompt.code

View File

@ -12,6 +12,7 @@ from dbgpt.core.awel.flow import (
ResourceCategory, ResourceCategory,
register_resource, register_resource,
) )
from ....resource.base import AgentResource, ResourceType from ....resource.base import AgentResource, ResourceType
from ....resource.manage import get_resource_manager from ....resource.manage import get_resource_manager
from ....util.llm.llm import LLMConfig, LLMStrategyType from ....util.llm.llm import LLMConfig, LLMStrategyType
@ -20,6 +21,7 @@ from ...agent_manage import get_agent_manager
def _agent_resource_prompt_values() -> List[OptionValue]: def _agent_resource_prompt_values() -> List[OptionValue]:
from dbgpt.serve.prompt.api.endpoints import get_service from dbgpt.serve.prompt.api.endpoints import get_service
prompt_service = get_service() prompt_service = get_service()
prompts = prompt_service.get_target_prompt() prompts = prompt_service.get_target_prompt()
return [ return [

View File

@ -2,10 +2,10 @@ import logging
import os import os
import shutil import shutil
import tempfile import tempfile
from typing import List
from pathlib import Path from pathlib import Path
from typing import List
from fastapi import APIRouter, Depends, File, Form, UploadFile, HTTPException from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from dbgpt._private.config import Config from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import ( from dbgpt.app.knowledge.request.request import (
@ -333,70 +333,72 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
@router.post("/knowledge/{space_name}/document/upload") @router.post("/knowledge/{space_name}/document/upload")
async def document_upload( async def document_upload(
space_name: str, space_name: str,
doc_name: str = Form(...), doc_name: str = Form(...),
doc_type: str = Form(...), doc_type: str = Form(...),
doc_file: UploadFile = File(...), doc_file: UploadFile = File(...),
): ):
print(f"/document/upload params: {space_name}") print(f"/document/upload params: {space_name}")
try: try:
if doc_file: if doc_file:
# Sanitize inputs to prevent path traversal # Sanitize inputs to prevent path traversal
safe_space_name = os.path.basename(space_name) safe_space_name = os.path.basename(space_name)
safe_filename = os.path.basename(doc_file.filename) safe_filename = os.path.basename(doc_file.filename)
# Create absolute paths and verify they are within allowed directory # Create absolute paths and verify they are within allowed directory
upload_dir = os.path.abspath(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name)) upload_dir = os.path.abspath(
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename)) os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name)
)
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path: if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
raise HTTPException(status_code=400, detail="Invalid path detected") raise HTTPException(status_code=400, detail="Invalid path detected")
if not os.path.exists(upload_dir): if not os.path.exists(upload_dir):
os.makedirs(upload_dir) os.makedirs(upload_dir)
# Create temp file # Create temp file
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir) tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
try:
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(tmp_path, target_path)
request = KnowledgeDocumentRequest() try:
request.doc_name = doc_name with os.fdopen(tmp_fd, "wb") as tmp:
request.doc_type = doc_type tmp.write(await doc_file.read())
request.content = target_path
space_res = knowledge_space_service.get_knowledge_space( shutil.move(tmp_path, target_path)
KnowledgeSpaceRequest(name=safe_space_name)
)
if len(space_res) == 0:
# create default space
if "default" != safe_space_name:
raise Exception(f"you have not create your knowledge space.")
knowledge_space_service.create_knowledge_space(
KnowledgeSpaceRequest(
name=safe_space_name,
desc="first db-gpt rag application",
owner="dbgpt",
)
)
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=safe_space_name, request=request
)
)
except Exception as e:
# Clean up temp file if anything goes wrong
if os.path.exists(tmp_path):
os.unlink(tmp_path)
raise e
return Result.failed(code="E000X", msg=f"doc_file is None") request = KnowledgeDocumentRequest()
except Exception as e: request.doc_name = doc_name
return Result.failed(code="E000X", msg=f"document add error {e}") request.doc_type = doc_type
request.content = target_path
space_res = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest(name=safe_space_name)
)
if len(space_res) == 0:
# create default space
if "default" != safe_space_name:
raise Exception(f"you have not create your knowledge space.")
knowledge_space_service.create_knowledge_space(
KnowledgeSpaceRequest(
name=safe_space_name,
desc="first db-gpt rag application",
owner="dbgpt",
)
)
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=safe_space_name, request=request
)
)
except Exception as e:
# Clean up temp file if anything goes wrong
if os.path.exists(tmp_path):
os.unlink(tmp_path)
raise e
return Result.failed(code="E000X", msg=f"doc_file is None")
except Exception as e:
return Result.failed(code="E000X", msg=f"document add error {e}")
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")

View File

@ -232,7 +232,8 @@ class BaseChat(ABC):
) )
node = AppChatComposerOperator( node = AppChatComposerOperator(
model=self.llm_model, model=self.llm_model,
temperature=float(self.prompt_template.temperature), temperature=self._chat_param.get("temperature")
or float(self.prompt_template.temperature),
max_new_tokens=int(self.prompt_template.max_new_tokens), max_new_tokens=int(self.prompt_template.max_new_tokens),
prompt=self.prompt_template.prompt, prompt=self.prompt_template.prompt,
message_version=self._message_version, message_version=self._message_version,

View File

@ -1,21 +1,13 @@
import base64
import hashlib
import hmac
import json import json
import os import os
from concurrent.futures import Executor from concurrent.futures import Executor
from datetime import datetime from typing import AsyncIterator, Optional
from time import mktime
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel from dbgpt.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v3"
def getlength(text): def getlength(text):
length = 0 length = 0
@ -49,7 +41,7 @@ def spark_generate_stream(
max_new_tokens=params.get("max_new_tokens"), max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"), stop=params.get("stop"),
) )
for r in client.sync_generate_stream(request): for r in client.generate_stream(request):
yield r yield r
@ -74,120 +66,57 @@ def get_response(request_url, data):
yield result yield result
class SparkAPI: def extract_content(line: str):
def __init__( if not line.strip():
self, appid: str, api_key: str, api_secret: str, spark_url: str return line
) -> None: if line.startswith("data: "):
self.appid = appid json_str = line[len("data: ") :]
self.api_key = api_key else:
self.api_secret = api_secret raise ValueError("Error line content ")
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
self.spark_url = spark_url try:
data = json.loads(json_str)
if data == "[DONE]":
return ""
def gen_url(self): choices = data.get("choices", [])
from wsgiref.handlers import format_date_time if choices and isinstance(choices, list):
delta = choices[0].get("delta", {})
# 生成RFC1123格式的时间戳 content = delta.get("content", "")
now = datetime.now() return content
date = format_date_time(mktime(now.timetuple())) else:
raise ValueError("Error line content ")
# 拼接字符串 except json.JSONDecodeError:
signature_origin = "host: " + self.host + "\n" return ""
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
signature_origin.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8"
)
# 将请求的鉴权参数组合为字典
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
class SparkLLMClient(ProxyLLMClient): class SparkLLMClient(ProxyLLMClient):
def __init__( def __init__(
self, self,
model: Optional[str] = None, model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm", model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096, context_length: Optional[int] = 4096,
executor: Optional[Executor] = None, executor: Optional[Executor] = None,
): ):
""" """
Tips: 星火大模型API当前有LiteProPro-128KMaxMax-32K和4.0 Ultra六个版本各版本独立计量tokens 星火大模型API当前有LiteProPro-128KMaxMax-32K和4.0 Ultra六个版本
传输协议 ws(s),为提高安全性强烈推荐wss Spark4.0 Ultra 请求地址对应的domain参数为4.0Ultra
Spark4.0 Ultra 请求地址对应的domain参数为4.0Ultra
wss://spark-api.xf-yun.com/v4.0/chat
Spark Max-32K请求地址对应的domain参数为max-32k Spark Max-32K请求地址对应的domain参数为max-32k
wss://spark-api.xf-yun.com/chat/max-32k
Spark Max请求地址对应的domain参数为generalv3.5 Spark Max请求地址对应的domain参数为generalv3.5
wss://spark-api.xf-yun.com/v3.5/chat
Spark Pro-128K请求地址对应的domain参数为pro-128k Spark Pro-128K请求地址对应的domain参数为pro-128k
wss://spark-api.xf-yun.com/chat/pro-128k
Spark Pro请求地址对应的domain参数为generalv3 Spark Pro请求地址对应的domain参数为generalv3
wss://spark-api.xf-yun.com/v3.1/chat
Spark Lite请求地址对应的domain参数为lite Spark Lite请求地址对应的domain参数为lite
wss://spark-api.xf-yun.com/v1.1/chat https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
""" """
if not model_version: self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION") self._api_base = os.getenv("PROXY_SERVER_URL")
if not api_base: self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
if model_version == SPARK_DEFAULT_API_VERSION: if not self._model:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat" raise ValueError("model can't be empty")
domain = "generalv3" if not self._api_base:
elif model_version == "v4.0": raise ValueError("api_base can't be empty")
api_base = "ws://spark-api.xf-yun.com/v4.0/chat" if not self._api_password:
domain = "4.0Ultra" raise ValueError("api_password can't be empty")
elif model_version == "v3.5":
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
domain = "generalv3.5"
else:
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "lite"
if not api_domain:
api_domain = domain
self._model = model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
super().__init__( super().__init__(
model_names=[model, model_alias], model_names=[model, model_alias],
@ -203,10 +132,6 @@ class SparkLLMClient(ProxyLLMClient):
) -> "SparkLLMClient": ) -> "SparkLLMClient":
return cls( return cls(
model=model_params.proxyllm_backend, model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name, model_alias=model_params.model_name,
context_length=model_params.max_context_size, context_length=model_params.max_context_size,
executor=default_executor, executor=default_executor,
@ -216,35 +141,45 @@ class SparkLLMClient(ProxyLLMClient):
def default_model(self) -> str: def default_model(self) -> str:
return self._model return self._model
def sync_generate_stream( def generate_stream(
self, self,
request: ModelRequest, request: ModelRequest,
message_converter: Optional[MessageConverter] = None, message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]: ) -> AsyncIterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
request = self.local_covert_message(request, message_converter) request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False) messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
}
spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
try: try:
for text in get_response(request_url, data): import requests
yield ModelOutput(text=text, error_code=0) except ImportError as e:
raise ValueError(
"Could not import python package: requests "
"Please install requests by command `pip install requests"
) from e
data = {
"model": self._model, # 指定请求的模型
"messages": messages,
"temperature": request.temperature,
"stream": True,
}
header = {
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
}
response = requests.post(self._api_base, headers=header, json=data, stream=True)
# 流式响应解析示例
response.encoding = "utf-8"
try:
content = ""
# data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
# data: [DONE]
for line in response.iter_lines(decode_unicode=True):
print("llm out:", line)
content = content + extract_content(line)
yield ModelOutput(text=content, error_code=0)
except Exception as e: except Exception as e:
return ModelOutput( return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",

View File

@ -457,9 +457,7 @@ class MilvusStore(VectorStoreBase):
self.vector_field = x.name self.vector_field = x.name
# convert to milvus expr filter. # convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search( _, docs_and_scores = self._search(query=text, k=topk, expr=milvus_filter_expr)
query=text, k=topk, expr=milvus_filter_expr
)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores): if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
logger.warning( logger.warning(
"similarity score need between" f" 0 and 1, got {docs_and_scores}" "similarity score need between" f" 0 and 1, got {docs_and_scores}"