mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-21 11:29:15 +00:00
feat: call xunfei spark with stream, and fix the temperature bug (#2121)
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
parent
4efe643db8
commit
3ccfa94219
@ -241,10 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
|
||||
#BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}
|
||||
|
||||
# Xunfei Spark
|
||||
#XUNFEI_SPARK_API_VERSION={version}
|
||||
#XUNFEI_SPARK_APPID={your_app_id}
|
||||
#XUNFEI_SPARK_API_KEY={your_api_key}
|
||||
#XUNFEI_SPARK_API_SECRET={your_api_secret}
|
||||
#XUNFEI_SPARK_API_PASSWORD={your_api_password}
|
||||
#XUNFEI_SPARK_API_MODEL={version}
|
||||
|
||||
## Yi Proxyllm, https://platform.lingyiwanwu.com/docs
|
||||
#YI_MODEL_VERSION=yi-34b-chat-0205
|
||||
|
@ -1,41 +1,41 @@
|
||||
# Please run command `pre-commit install` to install pre-commit hook
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: python-fmt
|
||||
name: Python Format
|
||||
entry: make fmt-check
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-test
|
||||
name: Python Unit Test
|
||||
entry: make test
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-test-doc
|
||||
name: Python Doc Test
|
||||
entry: make test-doc
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-lint-mypy
|
||||
name: Python Lint mypy
|
||||
entry: make mypy
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
|
||||
# Please run command `pre-commit install` to install pre-commit hook
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: python-fmt
|
||||
name: Python Format
|
||||
entry: make fmt-check
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-test
|
||||
name: Python Unit Test
|
||||
entry: make test
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-test-doc
|
||||
name: Python Doc Test
|
||||
entry: make test-doc
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
- id: python-lint-mypy
|
||||
name: Python Lint mypy
|
||||
entry: make mypy
|
||||
language: system
|
||||
exclude: '^dbgpt/app/static/|^web/'
|
||||
types: [python]
|
||||
stages: [commit]
|
||||
pass_filenames: false
|
||||
args: []
|
||||
|
||||
|
@ -78,17 +78,13 @@ class Config(metaclass=Singleton):
|
||||
)
|
||||
|
||||
# xunfei spark
|
||||
self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
|
||||
self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
|
||||
self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
|
||||
self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
|
||||
if self.spark_proxy_api_key and self.spark_proxy_api_secret:
|
||||
os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
|
||||
os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
|
||||
os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or ""
|
||||
os.environ["spark_proxyllm_proxy_api_app_id"] = (
|
||||
self.spark_proxy_api_appid or ""
|
||||
)
|
||||
self.spark_proxy_api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
|
||||
self.spark_proxy_api_model = os.getenv("XUNFEI_SPARK_API_MODEL")
|
||||
if self.spark_proxy_api_model and self.spark_proxy_api_password:
|
||||
os.environ[
|
||||
"spark_proxyllm_proxy_api_password"
|
||||
] = self.spark_proxy_api_password
|
||||
os.environ["spark_proxyllm_proxy_api_model"] = self.spark_proxy_api_model
|
||||
|
||||
# baichuan proxy
|
||||
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
|
||||
|
@ -291,6 +291,7 @@ class AWELAgentOperator(
|
||||
prompt_template = None
|
||||
if self.awel_agent.agent_prompt:
|
||||
from dbgpt.serve.prompt.api.endpoints import get_service
|
||||
|
||||
prompt_service = get_service()
|
||||
prompt_template = prompt_service.get_template(
|
||||
self.awel_agent.agent_prompt.code
|
||||
|
@ -12,6 +12,7 @@ from dbgpt.core.awel.flow import (
|
||||
ResourceCategory,
|
||||
register_resource,
|
||||
)
|
||||
|
||||
from ....resource.base import AgentResource, ResourceType
|
||||
from ....resource.manage import get_resource_manager
|
||||
from ....util.llm.llm import LLMConfig, LLMStrategyType
|
||||
@ -20,6 +21,7 @@ from ...agent_manage import get_agent_manager
|
||||
|
||||
def _agent_resource_prompt_values() -> List[OptionValue]:
|
||||
from dbgpt.serve.prompt.api.endpoints import get_service
|
||||
|
||||
prompt_service = get_service()
|
||||
prompts = prompt_service.get_target_prompt()
|
||||
return [
|
||||
|
@ -2,10 +2,10 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, UploadFile, HTTPException
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
@ -333,70 +333,72 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/upload")
|
||||
async def document_upload(
|
||||
space_name: str,
|
||||
doc_name: str = Form(...),
|
||||
doc_type: str = Form(...),
|
||||
doc_file: UploadFile = File(...),
|
||||
space_name: str,
|
||||
doc_name: str = Form(...),
|
||||
doc_type: str = Form(...),
|
||||
doc_file: UploadFile = File(...),
|
||||
):
|
||||
print(f"/document/upload params: {space_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
# Sanitize inputs to prevent path traversal
|
||||
safe_space_name = os.path.basename(space_name)
|
||||
safe_filename = os.path.basename(doc_file.filename)
|
||||
print(f"/document/upload params: {space_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
# Sanitize inputs to prevent path traversal
|
||||
safe_space_name = os.path.basename(space_name)
|
||||
safe_filename = os.path.basename(doc_file.filename)
|
||||
|
||||
# Create absolute paths and verify they are within allowed directory
|
||||
upload_dir = os.path.abspath(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name))
|
||||
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
|
||||
# Create absolute paths and verify they are within allowed directory
|
||||
upload_dir = os.path.abspath(
|
||||
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name)
|
||||
)
|
||||
target_path = os.path.abspath(os.path.join(upload_dir, safe_filename))
|
||||
|
||||
if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
|
||||
raise HTTPException(status_code=400, detail="Invalid path detected")
|
||||
if not os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) in target_path:
|
||||
raise HTTPException(status_code=400, detail="Invalid path detected")
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
|
||||
# Create temp file
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
|
||||
|
||||
try:
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
|
||||
shutil.move(tmp_path, target_path)
|
||||
# Create temp file
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir)
|
||||
|
||||
request = KnowledgeDocumentRequest()
|
||||
request.doc_name = doc_name
|
||||
request.doc_type = doc_type
|
||||
request.content = target_path
|
||||
try:
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
|
||||
space_res = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest(name=safe_space_name)
|
||||
)
|
||||
if len(space_res) == 0:
|
||||
# create default space
|
||||
if "default" != safe_space_name:
|
||||
raise Exception(f"you have not create your knowledge space.")
|
||||
knowledge_space_service.create_knowledge_space(
|
||||
KnowledgeSpaceRequest(
|
||||
name=safe_space_name,
|
||||
desc="first db-gpt rag application",
|
||||
owner="dbgpt",
|
||||
)
|
||||
)
|
||||
return Result.succ(
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=safe_space_name, request=request
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Clean up temp file if anything goes wrong
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
raise e
|
||||
shutil.move(tmp_path, target_path)
|
||||
|
||||
return Result.failed(code="E000X", msg=f"doc_file is None")
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
request = KnowledgeDocumentRequest()
|
||||
request.doc_name = doc_name
|
||||
request.doc_type = doc_type
|
||||
request.content = target_path
|
||||
|
||||
space_res = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest(name=safe_space_name)
|
||||
)
|
||||
if len(space_res) == 0:
|
||||
# create default space
|
||||
if "default" != safe_space_name:
|
||||
raise Exception(f"you have not create your knowledge space.")
|
||||
knowledge_space_service.create_knowledge_space(
|
||||
KnowledgeSpaceRequest(
|
||||
name=safe_space_name,
|
||||
desc="first db-gpt rag application",
|
||||
owner="dbgpt",
|
||||
)
|
||||
)
|
||||
return Result.succ(
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=safe_space_name, request=request
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Clean up temp file if anything goes wrong
|
||||
if os.path.exists(tmp_path):
|
||||
os.unlink(tmp_path)
|
||||
raise e
|
||||
|
||||
return Result.failed(code="E000X", msg=f"doc_file is None")
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/sync")
|
||||
|
@ -232,7 +232,8 @@ class BaseChat(ABC):
|
||||
)
|
||||
node = AppChatComposerOperator(
|
||||
model=self.llm_model,
|
||||
temperature=float(self.prompt_template.temperature),
|
||||
temperature=self._chat_param.get("temperature")
|
||||
or float(self.prompt_template.temperature),
|
||||
max_new_tokens=int(self.prompt_template.max_new_tokens),
|
||||
prompt=self.prompt_template.prompt,
|
||||
message_version=self._message_version,
|
||||
|
@ -1,21 +1,13 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Iterator, Optional
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
|
||||
from dbgpt.model.parameter import ProxyModelParameters
|
||||
from dbgpt.model.proxy.base import ProxyLLMClient
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||
|
||||
SPARK_DEFAULT_API_VERSION = "v3"
|
||||
|
||||
|
||||
def getlength(text):
|
||||
length = 0
|
||||
@ -49,7 +41,7 @@ def spark_generate_stream(
|
||||
max_new_tokens=params.get("max_new_tokens"),
|
||||
stop=params.get("stop"),
|
||||
)
|
||||
for r in client.sync_generate_stream(request):
|
||||
for r in client.generate_stream(request):
|
||||
yield r
|
||||
|
||||
|
||||
@ -74,120 +66,57 @@ def get_response(request_url, data):
|
||||
yield result
|
||||
|
||||
|
||||
class SparkAPI:
|
||||
def __init__(
|
||||
self, appid: str, api_key: str, api_secret: str, spark_url: str
|
||||
) -> None:
|
||||
self.appid = appid
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.host = urlparse(spark_url).netloc
|
||||
self.path = urlparse(spark_url).path
|
||||
def extract_content(line: str):
|
||||
if not line.strip():
|
||||
return line
|
||||
if line.startswith("data: "):
|
||||
json_str = line[len("data: ") :]
|
||||
else:
|
||||
raise ValueError("Error line content ")
|
||||
|
||||
self.spark_url = spark_url
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if data == "[DONE]":
|
||||
return ""
|
||||
|
||||
def gen_url(self):
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(
|
||||
self.api_secret.encode("utf-8"),
|
||||
signature_origin.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
|
||||
encoding="utf-8"
|
||||
)
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + "?" + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
choices = data.get("choices", [])
|
||||
if choices and isinstance(choices, list):
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
return content
|
||||
else:
|
||||
raise ValueError("Error line content ")
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
|
||||
|
||||
class SparkLLMClient(ProxyLLMClient):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_secret: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_domain: Optional[str] = None,
|
||||
model_version: Optional[str] = None,
|
||||
model_alias: Optional[str] = "spark_proxyllm",
|
||||
context_length: Optional[int] = 4096,
|
||||
executor: Optional[Executor] = None,
|
||||
):
|
||||
"""
|
||||
Tips: 星火大模型API当前有Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra六个版本,各版本独立计量tokens。
|
||||
传输协议 :ws(s),为提高安全性,强烈推荐wss
|
||||
|
||||
Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra:
|
||||
wss://spark-api.xf-yun.com/v4.0/chat
|
||||
|
||||
星火大模型API当前有Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra六个版本
|
||||
Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra
|
||||
Spark Max-32K请求地址,对应的domain参数为max-32k
|
||||
wss://spark-api.xf-yun.com/chat/max-32k
|
||||
|
||||
Spark Max请求地址,对应的domain参数为generalv3.5
|
||||
wss://spark-api.xf-yun.com/v3.5/chat
|
||||
|
||||
Spark Pro-128K请求地址,对应的domain参数为pro-128k:
|
||||
wss://spark-api.xf-yun.com/chat/pro-128k
|
||||
|
||||
Spark Pro请求地址,对应的domain参数为generalv3:
|
||||
wss://spark-api.xf-yun.com/v3.1/chat
|
||||
|
||||
Spark Lite请求地址,对应的domain参数为lite:
|
||||
wss://spark-api.xf-yun.com/v1.1/chat
|
||||
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
|
||||
"""
|
||||
if not model_version:
|
||||
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
|
||||
if not api_base:
|
||||
if model_version == SPARK_DEFAULT_API_VERSION:
|
||||
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
domain = "generalv3"
|
||||
elif model_version == "v4.0":
|
||||
api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
|
||||
domain = "4.0Ultra"
|
||||
elif model_version == "v3.5":
|
||||
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
|
||||
domain = "generalv3.5"
|
||||
else:
|
||||
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
domain = "lite"
|
||||
if not api_domain:
|
||||
api_domain = domain
|
||||
self._model = model
|
||||
self._model_version = model_version
|
||||
self._api_base = api_base
|
||||
self._domain = api_domain
|
||||
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
|
||||
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
|
||||
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
|
||||
|
||||
if not self._app_id:
|
||||
raise ValueError("app_id can't be empty")
|
||||
if not self._api_key:
|
||||
raise ValueError("api_key can't be empty")
|
||||
if not self._api_secret:
|
||||
raise ValueError("api_secret can't be empty")
|
||||
self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
|
||||
self._api_base = os.getenv("PROXY_SERVER_URL")
|
||||
self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
|
||||
if not self._model:
|
||||
raise ValueError("model can't be empty")
|
||||
if not self._api_base:
|
||||
raise ValueError("api_base can't be empty")
|
||||
if not self._api_password:
|
||||
raise ValueError("api_password can't be empty")
|
||||
|
||||
super().__init__(
|
||||
model_names=[model, model_alias],
|
||||
@ -203,10 +132,6 @@ class SparkLLMClient(ProxyLLMClient):
|
||||
) -> "SparkLLMClient":
|
||||
return cls(
|
||||
model=model_params.proxyllm_backend,
|
||||
app_id=model_params.proxy_api_app_id,
|
||||
api_key=model_params.proxy_api_key,
|
||||
api_secret=model_params.proxy_api_secret,
|
||||
api_base=model_params.proxy_api_base,
|
||||
model_alias=model_params.model_name,
|
||||
context_length=model_params.max_context_size,
|
||||
executor=default_executor,
|
||||
@ -216,35 +141,45 @@ class SparkLLMClient(ProxyLLMClient):
|
||||
def default_model(self) -> str:
|
||||
return self._model
|
||||
|
||||
def sync_generate_stream(
|
||||
def generate_stream(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
message_converter: Optional[MessageConverter] = None,
|
||||
) -> Iterator[ModelOutput]:
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
"""
|
||||
reference:
|
||||
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
|
||||
"""
|
||||
request = self.local_covert_message(request, message_converter)
|
||||
messages = request.to_common_messages(support_system_role=False)
|
||||
request_id = request.context.request_id or "1"
|
||||
data = {
|
||||
"header": {"app_id": self._app_id, "uid": request_id},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": self._domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": request.max_new_tokens,
|
||||
"auditing": "default",
|
||||
"temperature": request.temperature,
|
||||
}
|
||||
},
|
||||
"payload": {"message": {"text": messages}},
|
||||
}
|
||||
|
||||
spark_api = SparkAPI(
|
||||
self._app_id, self._api_key, self._api_secret, self._api_base
|
||||
)
|
||||
request_url = spark_api.gen_url()
|
||||
try:
|
||||
for text in get_response(request_url, data):
|
||||
yield ModelOutput(text=text, error_code=0)
|
||||
import requests
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"Could not import python package: requests "
|
||||
"Please install requests by command `pip install requests"
|
||||
) from e
|
||||
|
||||
data = {
|
||||
"model": self._model, # 指定请求的模型
|
||||
"messages": messages,
|
||||
"temperature": request.temperature,
|
||||
"stream": True,
|
||||
}
|
||||
header = {
|
||||
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
|
||||
}
|
||||
response = requests.post(self._api_base, headers=header, json=data, stream=True)
|
||||
# 流式响应解析示例
|
||||
response.encoding = "utf-8"
|
||||
try:
|
||||
content = ""
|
||||
# data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
|
||||
# data: [DONE]
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
print("llm out:", line)
|
||||
content = content + extract_content(line)
|
||||
yield ModelOutput(text=content, error_code=0)
|
||||
except Exception as e:
|
||||
return ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
|
@ -457,9 +457,7 @@ class MilvusStore(VectorStoreBase):
|
||||
self.vector_field = x.name
|
||||
# convert to milvus expr filter.
|
||||
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
|
||||
_, docs_and_scores = self._search(
|
||||
query=text, k=topk, expr=milvus_filter_expr
|
||||
)
|
||||
_, docs_and_scores = self._search(query=text, k=topk, expr=milvus_filter_expr)
|
||||
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
||||
logger.warning(
|
||||
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
||||
|
Loading…
Reference in New Issue
Block a user