Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-07-21 19:31:43 +00:00
feat: call xunfei spark with stream, and fix the temperature bug (#2121)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent: 4efe643db8
commit: 3ccfa94219
@@ -241,10 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
 #BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}
 
 # Xunfei Spark
-#XUNFEI_SPARK_API_VERSION={version}
-#XUNFEI_SPARK_APPID={your_app_id}
-#XUNFEI_SPARK_API_KEY={your_api_key}
-#XUNFEI_SPARK_API_SECRET={your_api_secret}
+#XUNFEI_SPARK_API_PASSWORD={your_api_password}
+#XUNFEI_SPARK_API_MODEL={version}
 
 ## Yi Proxyllm, https://platform.lingyiwanwu.com/docs
 #YI_MODEL_VERSION=yi-34b-chat-0205
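The four WebSocket-era settings (version, app id, key, secret) are replaced by two HTTP-mode settings: an API password and a model name. Note that the new client also reads PROXY_SERVER_URL for the HTTP endpoint (see the SparkLLMClient changes below). A quick sanity-check sketch; all values are placeholders, "generalv3.5" is only an assumed example of a Spark model name, and the endpoint URL is an assumption to be checked against Xunfei's documentation:

    import os

    # Placeholder values for the new HTTP-mode configuration.
    os.environ["XUNFEI_SPARK_API_PASSWORD"] = "{your_api_password}"
    os.environ["XUNFEI_SPARK_API_MODEL"] = "generalv3.5"  # assumed example value
    os.environ["PROXY_SERVER_URL"] = "https://spark-api-open.xf-yun.com/v1/chat/completions"  # assumed endpoint

    for name in ("XUNFEI_SPARK_API_PASSWORD", "XUNFEI_SPARK_API_MODEL", "PROXY_SERVER_URL"):
        assert os.getenv(name), f"{name} must be set for the Spark proxy LLM"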
@@ -1,41 +1,41 @@
 # Please run command `pre-commit install` to install pre-commit hook
 repos:
   - repo: local
     hooks:
       - id: python-fmt
         name: Python Format
         entry: make fmt-check
         language: system
         exclude: '^dbgpt/app/static/|^web/'
         types: [python]
         stages: [commit]
         pass_filenames: false
         args: []
       - id: python-test
         name: Python Unit Test
         entry: make test
         language: system
         exclude: '^dbgpt/app/static/|^web/'
         types: [python]
         stages: [commit]
         pass_filenames: false
         args: []
       - id: python-test-doc
         name: Python Doc Test
         entry: make test-doc
         language: system
         exclude: '^dbgpt/app/static/|^web/'
         types: [python]
         stages: [commit]
         pass_filenames: false
         args: []
       - id: python-lint-mypy
         name: Python Lint mypy
         entry: make mypy
         language: system
         exclude: '^dbgpt/app/static/|^web/'
         types: [python]
         stages: [commit]
         pass_filenames: false
         args: []
 
@@ -78,17 +78,13 @@ class Config(metaclass=Singleton):
         )
 
         # xunfei spark
-        self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
-        self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
-        self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
-        self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
-        if self.spark_proxy_api_key and self.spark_proxy_api_secret:
-            os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
-            os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
-            os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or ""
-            os.environ["spark_proxyllm_proxy_api_app_id"] = (
-                self.spark_proxy_api_appid or ""
-            )
+        self.spark_proxy_api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
+        self.spark_proxy_api_model = os.getenv("XUNFEI_SPARK_API_MODEL")
+        if self.spark_proxy_api_model and self.spark_proxy_api_password:
+            os.environ[
+                "spark_proxyllm_proxy_api_password"
+            ] = self.spark_proxy_api_password
+            os.environ["spark_proxyllm_proxy_api_model"] = self.spark_proxy_api_model
 
         # baichuan proxy
         self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
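The new code publishes the password and model under the `spark_proxyllm_` prefix, mirroring how the old key/secret pair was exposed. A minimal sketch of how such a `{model_name}_{param}` environment bridge can be read back; this only illustrates the naming convention, DB-GPT's actual parameter loader is not shown here:

    import os
    from typing import Dict, Iterable

    def resolve_env_params(model_name: str, params: Iterable[str]) -> Dict[str, str]:
        """Collect '{model_name}_{param}' environment variables back into a dict."""
        found = {}
        for param in params:
            value = os.environ.get(f"{model_name}_{param}")
            if value is not None:
                found[param] = value
        return found

    os.environ["spark_proxyllm_proxy_api_password"] = "secret"
    os.environ["spark_proxyllm_proxy_api_model"] = "generalv3.5"  # example value
    print(resolve_env_params("spark_proxyllm", ["proxy_api_password", "proxy_api_model"]))
    # -> {'proxy_api_password': 'secret', 'proxy_api_model': 'generalv3.5'}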
@@ -291,6 +291,7 @@ class AWELAgentOperator(
         prompt_template = None
         if self.awel_agent.agent_prompt:
             from dbgpt.serve.prompt.api.endpoints import get_service
 
             prompt_service = get_service()
             prompt_template = prompt_service.get_template(
                 self.awel_agent.agent_prompt.code
@@ -12,6 +12,7 @@ from dbgpt.core.awel.flow import (
     ResourceCategory,
     register_resource,
 )
 
 from ....resource.base import AgentResource, ResourceType
 from ....resource.manage import get_resource_manager
 from ....util.llm.llm import LLMConfig, LLMStrategyType
@@ -20,6 +21,7 @@ from ...agent_manage import get_agent_manager
 
 def _agent_resource_prompt_values() -> List[OptionValue]:
     from dbgpt.serve.prompt.api.endpoints import get_service
 
     prompt_service = get_service()
     prompts = prompt_service.get_target_prompt()
     return [
@@ -2,10 +2,10 @@ import logging
 import os
 import shutil
 import tempfile
-from typing import List
 from pathlib import Path
+from typing import List
 
-from fastapi import APIRouter, Depends, File, Form, UploadFile, HTTPException
+from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
 
 from dbgpt._private.config import Config
 from dbgpt.app.knowledge.request.request import (
@@ -333,70 +333,72 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
 
 @router.post("/knowledge/{space_name}/document/upload")
 async def document_upload(
     space_name: str,
     doc_name: str = Form(...),
     doc_type: str = Form(...),
     doc_file: UploadFile = File(...),
 ):
     print(f"/document/upload params: {space_name}")
     try:
         if doc_file:
             # Sanitize inputs to prevent path traversal
             safe_space_name = os.path.basename(space_name)
             safe_filename = os.path.basename(doc_file.filename)
 
             # Create absolute paths and verify they are within allowed directory
-            upload_dir = os.path.abspath(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name))
+            upload_dir = os.path.abspath(
+                os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name)
+            )
             target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
 
             if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
                 raise HTTPException(status_code=400, detail="Invalid path detected")
 
             if not os.path.exists(upload_dir):
                 os.makedirs(upload_dir)
 
             # Create temp file
             tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
 
             try:
                 with os.fdopen(tmp_fd, "wb") as tmp:
                     tmp.write(await doc_file.read())
 
                 shutil.move(tmp_path, target_path)
 
                 request = KnowledgeDocumentRequest()
                 request.doc_name = doc_name
                 request.doc_type = doc_type
                 request.content = target_path
 
                 space_res = knowledge_space_service.get_knowledge_space(
                     KnowledgeSpaceRequest(name=safe_space_name)
                 )
                 if len(space_res) == 0:
                     # create default space
                     if "default" != safe_space_name:
                         raise Exception(f"you have not create your knowledge space.")
                     knowledge_space_service.create_knowledge_space(
                         KnowledgeSpaceRequest(
                             name=safe_space_name,
                             desc="first db-gpt rag application",
                             owner="dbgpt",
                         )
                     )
                 return Result.succ(
                     knowledge_space_service.create_knowledge_document(
                         space=safe_space_name, request=request
                     )
                 )
             except Exception as e:
                 # Clean up temp file if anything goes wrong
                 if os.path.exists(tmp_path):
                     os.unlink(tmp_path)
                 raise e
 
         return Result.failed(code="E000X", msg=f"doc_file is None")
     except Exception as e:
         return Result.failed(code="E000X", msg=f"document add error {e}")
 
 
 @router.post("/knowledge/{space_name}/document/sync")
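Aside from the reflowed `upload_dir` expression, the handler's write path is unchanged: write into a `mkstemp` file created in the destination directory, then move it into place, so the final step is a same-filesystem rename. A condensed, standalone sketch of that pattern; the `safe_write` helper is hypothetical, not part of the codebase:

    import os
    import shutil
    import tempfile

    def safe_write(upload_dir: str, filename: str, payload: bytes) -> str:
        # mkstemp creates the temp file in the destination directory so
        # the final move is an atomic rename on the same filesystem.
        target_path = os.path.join(upload_dir, filename)
        tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
        try:
            with os.fdopen(tmp_fd, "wb") as tmp:
                tmp.write(payload)
            shutil.move(tmp_path, target_path)
        except Exception:
            # Clean up the temp file if anything goes wrong.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
            raise
        return target_path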
@@ -232,7 +232,8 @@ class BaseChat(ABC):
         )
         node = AppChatComposerOperator(
             model=self.llm_model,
-            temperature=float(self.prompt_template.temperature),
+            temperature=self._chat_param.get("temperature")
+            or float(self.prompt_template.temperature),
             max_new_tokens=int(self.prompt_template.max_new_tokens),
             prompt=self.prompt_template.prompt,
             message_version=self._message_version,
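This is the temperature fix from the commit title: `dict.get` returns `None` when no per-request temperature was supplied, so the `or` chain falls back to the prompt template's default. A small sketch of the semantics; the helper name is illustrative only, and note that an explicit temperature of `0` is falsy and would also fall back:

    def resolve_temperature(chat_param: dict, template_temperature: str) -> float:
        # Mirrors the expression in the diff: prefer the per-request value,
        # otherwise fall back to the template default.
        return chat_param.get("temperature") or float(template_temperature)

    print(resolve_temperature({"temperature": 0.8}, "0.6"))  # 0.8 -> request wins
    print(resolve_temperature({}, "0.6"))                    # 0.6 -> template fallback
    print(resolve_temperature({"temperature": 0}, "0.6"))    # 0.6 -> 0 is falsy, falls back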
@@ -1,21 +1,13 @@
-import base64
-import hashlib
-import hmac
 import json
 import os
 from concurrent.futures import Executor
-from datetime import datetime
-from time import mktime
-from typing import Iterator, Optional
-from urllib.parse import urlencode, urlparse
+from typing import AsyncIterator, Optional
 
 from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
 from dbgpt.model.parameter import ProxyModelParameters
 from dbgpt.model.proxy.base import ProxyLLMClient
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
 
-SPARK_DEFAULT_API_VERSION = "v3"
-
 
 def getlength(text):
     length = 0
@@ -49,7 +41,7 @@ def spark_generate_stream(
         max_new_tokens=params.get("max_new_tokens"),
         stop=params.get("stop"),
     )
-    for r in client.sync_generate_stream(request):
+    for r in client.generate_stream(request):
         yield r
 
 
@@ -74,120 +66,57 @@ def get_response(request_url, data):
     yield result
 
 
-class SparkAPI:
-    def __init__(
-        self, appid: str, api_key: str, api_secret: str, spark_url: str
-    ) -> None:
-        self.appid = appid
-        self.api_key = api_key
-        self.api_secret = api_secret
-        self.host = urlparse(spark_url).netloc
-        self.path = urlparse(spark_url).path
-
-        self.spark_url = spark_url
-
-    def gen_url(self):
-        from wsgiref.handlers import format_date_time
-
-        # Generate an RFC 1123 formatted timestamp
-        now = datetime.now()
-        date = format_date_time(mktime(now.timetuple()))
-
-        # Concatenate the signature source string
-        signature_origin = "host: " + self.host + "\n"
-        signature_origin += "date: " + date + "\n"
-        signature_origin += "GET " + self.path + " HTTP/1.1"
-
-        # Sign it with HMAC-SHA256
-        signature_sha = hmac.new(
-            self.api_secret.encode("utf-8"),
-            signature_origin.encode("utf-8"),
-            digestmod=hashlib.sha256,
-        ).digest()
-
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
-
-        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
-            encoding="utf-8"
-        )
-
-        # Combine the request auth parameters into a dict
-        v = {"authorization": authorization, "date": date, "host": self.host}
-        # Append the auth parameters to build the URL
-        url = self.spark_url + "?" + urlencode(v)
-        # Print the connection URL here when following this demo, to check that the URL
-        # generated for the same parameters matches the one your own code produces
-        return url
+def extract_content(line: str):
+    if not line.strip():
+        return line
+    if line.startswith("data: "):
+        json_str = line[len("data: ") :]
+    else:
+        raise ValueError("Error line content ")
+
+    try:
+        data = json.loads(json_str)
+        if data == "[DONE]":
+            return ""
+        choices = data.get("choices", [])
+        if choices and isinstance(choices, list):
+            delta = choices[0].get("delta", {})
+            content = delta.get("content", "")
+            return content
+        else:
+            raise ValueError("Error line content ")
+    except json.JSONDecodeError:
+        return ""
 
 
 class SparkLLMClient(ProxyLLMClient):
     def __init__(
         self,
         model: Optional[str] = None,
-        app_id: Optional[str] = None,
-        api_key: Optional[str] = None,
-        api_secret: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_domain: Optional[str] = None,
-        model_version: Optional[str] = None,
         model_alias: Optional[str] = "spark_proxyllm",
         context_length: Optional[int] = 4096,
         executor: Optional[Executor] = None,
     ):
         """
-        Tips: The Xunfei Spark API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K, and 4.0 Ultra; tokens are metered separately per version.
-        Transport protocol: ws(s); wss is strongly recommended for better security.
+        The Xunfei Spark API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K, and 4.0 Ultra
+        Spark 4.0 Ultra endpoint, with domain parameter 4.0Ultra
 
-        Spark 4.0 Ultra endpoint, with domain parameter 4.0Ultra:
-        wss://spark-api.xf-yun.com/v4.0/chat
-
         Spark Max-32K endpoint, with domain parameter max-32k
-        wss://spark-api.xf-yun.com/chat/max-32k
 
         Spark Max endpoint, with domain parameter generalv3.5
-        wss://spark-api.xf-yun.com/v3.5/chat
 
         Spark Pro-128K endpoint, with domain parameter pro-128k:
-        wss://spark-api.xf-yun.com/chat/pro-128k
 
         Spark Pro endpoint, with domain parameter generalv3:
-        wss://spark-api.xf-yun.com/v3.1/chat
 
         Spark Lite endpoint, with domain parameter lite:
-        wss://spark-api.xf-yun.com/v1.1/chat
+        https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
         """
-        if not model_version:
-            model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
-        if not api_base:
-            if model_version == SPARK_DEFAULT_API_VERSION:
-                api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
-                domain = "generalv3"
-            elif model_version == "v4.0":
-                api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
-                domain = "4.0Ultra"
-            elif model_version == "v3.5":
-                api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
-                domain = "generalv3.5"
-            else:
-                api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
-                domain = "lite"
-            if not api_domain:
-                api_domain = domain
-        self._model = model
-        self._model_version = model_version
-        self._api_base = api_base
-        self._domain = api_domain
-        self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
-        self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
-        self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
-
-        if not self._app_id:
-            raise ValueError("app_id can't be empty")
-        if not self._api_key:
-            raise ValueError("api_key can't be empty")
-        if not self._api_secret:
-            raise ValueError("api_secret can't be empty")
+        self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
+        self._api_base = os.getenv("PROXY_SERVER_URL")
+        self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
+        if not self._model:
+            raise ValueError("model can't be empty")
+        if not self._api_base:
+            raise ValueError("api_base can't be empty")
+        if not self._api_password:
+            raise ValueError("api_password can't be empty")
 
         super().__init__(
             model_names=[model, model_alias],
@@ -203,10 +132,6 @@ class SparkLLMClient(ProxyLLMClient):
     ) -> "SparkLLMClient":
         return cls(
             model=model_params.proxyllm_backend,
-            app_id=model_params.proxy_api_app_id,
-            api_key=model_params.proxy_api_key,
-            api_secret=model_params.proxy_api_secret,
-            api_base=model_params.proxy_api_base,
             model_alias=model_params.model_name,
             context_length=model_params.max_context_size,
             executor=default_executor,
@@ -216,35 +141,45 @@ class SparkLLMClient(ProxyLLMClient):
     def default_model(self) -> str:
         return self._model
 
-    def sync_generate_stream(
+    def generate_stream(
         self,
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
-    ) -> Iterator[ModelOutput]:
+    ) -> AsyncIterator[ModelOutput]:
+        """
+        reference:
+        https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
+        """
         request = self.local_covert_message(request, message_converter)
         messages = request.to_common_messages(support_system_role=False)
-        request_id = request.context.request_id or "1"
-        data = {
-            "header": {"app_id": self._app_id, "uid": request_id},
-            "parameter": {
-                "chat": {
-                    "domain": self._domain,
-                    "random_threshold": 0.5,
-                    "max_tokens": request.max_new_tokens,
-                    "auditing": "default",
-                    "temperature": request.temperature,
-                }
-            },
-            "payload": {"message": {"text": messages}},
-        }
-
-        spark_api = SparkAPI(
-            self._app_id, self._api_key, self._api_secret, self._api_base
-        )
-        request_url = spark_api.gen_url()
         try:
-            for text in get_response(request_url, data):
-                yield ModelOutput(text=text, error_code=0)
+            import requests
+        except ImportError as e:
+            raise ValueError(
+                "Could not import python package: requests "
+                "Please install requests by command `pip install requests"
+            ) from e
+
+        data = {
+            "model": self._model,  # specify the model for the request
+            "messages": messages,
+            "temperature": request.temperature,
+            "stream": True,
+        }
+        header = {
+            "Authorization": f"Bearer {self._api_password}"  # replace with your own APIPassword
+        }
+        response = requests.post(self._api_base, headers=header, json=data, stream=True)
+
+        # Example of parsing the streaming response
+        response.encoding = "utf-8"
+        try:
+            content = ""
+            # data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
+            # data: [DONE]
+            for line in response.iter_lines(decode_unicode=True):
+                print("llm out:", line)
+                content = content + extract_content(line)
+                yield ModelOutput(text=content, error_code=0)
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
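To see the new parsing path offline, the following sketch feeds `extract_content` the kinds of lines shown in the inline comments above; the function body is copied from this diff, and the sample payload is trimmed:

    import json

    def extract_content(line: str):
        # Copied from the diff above: pull the delta content out of one SSE line.
        if not line.strip():
            return line
        if line.startswith("data: "):
            json_str = line[len("data: ") :]
        else:
            raise ValueError("Error line content ")
        try:
            data = json.loads(json_str)
            if data == "[DONE]":
                return ""
            choices = data.get("choices", [])
            if choices and isinstance(choices, list):
                delta = choices[0].get("delta", {})
                return delta.get("content", "")
            else:
                raise ValueError("Error line content ")
        except json.JSONDecodeError:
            # "data: [DONE]" is not valid JSON, so it also ends up here.
            return ""

    lines = [
        'data: {"code":0,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}',
        "",
        "data: [DONE]",
    ]
    content = ""
    for line in lines:
        content += extract_content(line)
    print(content)  # -> 你好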
@@ -457,9 +457,7 @@ class MilvusStore(VectorStoreBase):
             self.vector_field = x.name
         # convert to milvus expr filter.
         milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
-        _, docs_and_scores = self._search(
-            query=text, k=topk, expr=milvus_filter_expr
-        )
+        _, docs_and_scores = self._search(query=text, k=topk, expr=milvus_filter_expr)
         if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
             logger.warning(
                 "similarity score need between" f" 0 and 1, got {docs_and_scores}"