mirror of https://github.com/csunny/DB-GPT.git

refactor: Refactor proxy LLM (#1064)
@@ -2,15 +2,19 @@ import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import List
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

from websockets.sync.client import connect

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

SPARK_DEFAULT_API_VERSION = "v3"
@@ -34,63 +38,21 @@ def checklen(text):
def spark_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    model_params = model.get_params()
    proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
    proxy_api_key = model_params.proxy_api_key
    proxy_api_secret = model_params.proxy_api_secret
    proxy_app_id = model_params.proxy_api_app_id

    if proxy_api_version == SPARK_DEFAULT_API_VERSION:
        url = "ws://spark-api.xf-yun.com/v3.1/chat"
        domain = "generalv3"
    else:
        url = "ws://spark-api.xf-yun.com/v2.1/chat"
        domain = "generalv2"

    messages: List[ModelMessage] = params["messages"]

    last_user_input = None
    for index in range(len(messages) - 1, -1, -1):
        print(f"index: {index}")
        if messages[index].role == ModelMessageRoleType.HUMAN:
            last_user_input = {"role": "user", "content": messages[index].content}
            del messages[index]
            break

    # TODO: Support convert_to_compatible_format config
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)

    history = []
    # Add history conversation
    for message in messages:
        # There is no role for system in spark LLM
        if message.role == ModelMessageRoleType.HUMAN or ModelMessageRoleType.SYSTEM:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    question = checklen(history + [last_user_input])

    print('last_user_input.get("content")', last_user_input.get("content"))
    data = {
        "header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": context_len,
                "auditing": "default",
                "temperature": params.get("temperature"),
            }
        },
        "payload": {"message": {"text": question}},
    }

    spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
    request_url = spark_api.gen_url()
    return get_response(request_url, data)
    client: SparkLLMClient = model.proxy_llm_client
    context = ModelRequestContext(
        stream=True,
        user_name=params.get("user_name"),
        request_id=params.get("request_id"),
    )
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r

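The refactored spark_generate_stream above no longer talks to the Spark websocket API directly; it only repackages the legacy params dict into a ModelRequest and delegates to the proxy client. The sketch below is not part of this commit: the keys mirror the params.get(...) lookups in the function, while the concrete values and the commented-out call are illustrative placeholders.

# Illustrative only: keys come from spark_generate_stream's params lookups,
# values are made up for the example.
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

params = {
    "messages": [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello, Spark!"),
    ],
    "temperature": 0.5,
    "max_new_tokens": 512,
    "request_id": "req-1",   # forwarded into ModelRequestContext
    "user_name": "alice",    # optional, also forwarded into the context
}

# `model` would be a ProxyModel whose proxy_llm_client is a SparkLLMClient:
# for output in spark_generate_stream(model, tokenizer=None, params=params, device="cpu"):
#     print(output.text)
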
def get_response(request_url, data):
@@ -107,8 +69,8 @@ def get_response(request_url, data):
                result += text[0]["content"]
                if choices.get("status") == 2:
                    break
            except Exception:
                break
            except Exception as e:
                raise e
    yield result

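The hunk above only shows the tail of get_response. Below is a minimal sketch of the kind of synchronous websocket receive loop it belongs to; it is an assumption pieced together from the websockets.sync.client import and the payload/choices/status fields visible in the fragment, not the committed implementation.

# Hedged sketch, not the committed code: a blocking receive loop that
# accumulates streamed chunks until Spark signals completion (status == 2).
import json
from websockets.sync.client import connect

def get_response_sketch(request_url: str, data: dict):
    result = ""
    with connect(request_url) as ws:
        ws.send(json.dumps(data))
        while True:
            try:
                frame = json.loads(ws.recv())
                choices = frame.get("payload", {}).get("choices", {})
                text = choices.get("text", [])
                if text:
                    result += text[0]["content"]
                # status == 2 marks the final frame of a Spark response
                if choices.get("status") == 2:
                    break
            except Exception:
                break
    yield result
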
@@ -155,3 +117,103 @@ class SparkAPI:
        url = self.spark_url + "?" + urlencode(v)
        # Print the URL used to establish the connection here. When following this
        # demo, uncomment the print above to check whether the URL generated with
        # the same parameters matches the one produced by your own code.
        return url

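The gen_url fragment above only shows the final URL assembly. A hedged reconstruction of the HMAC-SHA256 request signing that SparkAPI presumably performs, inferred from the hashlib/hmac/base64/format_date_time/urlparse imports at the top of the file, is sketched below; the function and variable names are illustrative and not taken from this hunk.

# Hedged reconstruction of the signed-URL scheme; attribute names are assumptions.
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

def sign_spark_url(spark_url: str, api_key: str, api_secret: str) -> str:
    host = urlparse(spark_url).netloc
    path = urlparse(spark_url).path
    # RFC 1123 date, e.g. "Mon, 05 Feb 2024 08:00:00 GMT"
    date = format_date_time(mktime(datetime.now().timetuple()))
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    query = {"authorization": authorization, "date": date, "host": host}
    return spark_url + "?" + urlencode(query)
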
class SparkLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        app_id: Optional[str] = None,
        api_key: Optional[str] = None,
        api_secret: Optional[str] = None,
        api_base: Optional[str] = None,
        api_domain: Optional[str] = None,
        model_version: Optional[str] = None,
        model_alias: Optional[str] = "spark_proxyllm",
        context_length: Optional[int] = 4096,
        executor: Optional[Executor] = None,
    ):
        if not model_version:
            model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
        if not api_base:
            if model_version == SPARK_DEFAULT_API_VERSION:
                api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
                domain = "generalv3"
            else:
                api_base = "ws://spark-api.xf-yun.com/v2.1/chat"
                domain = "generalv2"
            if not api_domain:
                api_domain = domain
        self._model = model
        self.default_model = self._model
        self._model_version = model_version
        self._api_base = api_base
        self._domain = api_domain
        self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
        self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
        self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")

        if not self._app_id:
            raise ValueError("app_id can't be empty")
        if not self._api_key:
            raise ValueError("api_key can't be empty")
        if not self._api_secret:
            raise ValueError("api_secret can't be empty")

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "SparkLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            app_id=model_params.proxy_api_app_id,
            api_key=model_params.proxy_api_key,
            api_secret=model_params.proxy_api_secret,
            api_base=model_params.proxy_api_base,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        messages = request.to_common_messages(support_system_role=False)
        request_id = request.context.request_id or "1"
        data = {
            "header": {"app_id": self._app_id, "uid": request_id},
            "parameter": {
                "chat": {
                    "domain": self._domain,
                    "random_threshold": 0.5,
                    "max_tokens": request.max_new_tokens,
                    "auditing": "default",
                    "temperature": request.temperature,
                }
            },
            "payload": {"message": {"text": messages}},
        }

        spark_api = SparkAPI(
            self._app_id, self._api_key, self._api_secret, self._api_base
        )
        request_url = spark_api.gen_url()
        try:
            for text in get_response(request_url, data):
                yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
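A hedged usage sketch for the new client follows. The import path of the module, the "generalv3" model name, and the prompt are assumptions for illustration; the credentials are read from the XUNFEI_SPARK_* environment variables referenced in __init__, and the request-building calls mirror the ones in spark_generate_stream above.

# Assumes XUNFEI_SPARK_APPID, XUNFEI_SPARK_API_KEY and XUNFEI_SPARK_API_SECRET
# are exported; the module path and "generalv3" are illustrative guesses.
from dbgpt.core import ModelRequest, ModelRequestContext
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.spark import SparkLLMClient

client = SparkLLMClient(model="generalv3", model_version="v3")
request = ModelRequest.build_request(
    client.default_model,
    messages=[
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Briefly introduce DB-GPT.")
    ],
    temperature=0.5,
    context=ModelRequestContext(stream=True, request_id="req-1"),
    max_new_tokens=256,
)
for output in client.sync_generate_stream(request):
    print(output.text)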