refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -2,15 +2,19 @@ import base64
import hashlib
import hmac
import json
import os
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import List
from typing import Iterator, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
from websockets.sync.client import connect
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v3"
@@ -34,63 +38,21 @@ def checklen(text):
def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
model_params = model.get_params()
proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
proxy_app_id = model_params.proxy_api_app_id
if proxy_api_version == SPARK_DEFAULT_API_VERSION:
url = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
url = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
messages: List[ModelMessage] = params["messages"]
last_user_input = None
for index in range(len(messages) - 1, -1, -1):
print(f"index: {index}")
if messages[index].role == ModelMessageRoleType.HUMAN:
last_user_input = {"role": "user", "content": messages[index].content}
del messages[index]
break
# TODO: Support convert_to_compatible_format config
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
history = []
    # Add conversation history
    for message in messages:
        # Spark LLM has no system role, so system messages are sent as user turns
        if message.role in (ModelMessageRoleType.HUMAN, ModelMessageRoleType.SYSTEM):
            history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
question = checklen(history + [last_user_input])
print('last_user_input.get("content")', last_user_input.get("content"))
data = {
"header": {"app_id": proxy_app_id, "uid": str(params.get("request_id", 1))},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": context_len,
"auditing": "default",
"temperature": params.get("temperature"),
}
},
"payload": {"message": {"text": question}},
}
spark_api = SparkAPI(proxy_app_id, proxy_api_key, proxy_api_secret, url)
request_url = spark_api.gen_url()
return get_response(request_url, data)
client: SparkLLMClient = model.proxy_llm_client
context = ModelRequestContext(
stream=True,
user_name=params.get("user_name"),
request_id=params.get("request_id"),
)
request = ModelRequest.build_request(
client.default_model,
messages=params["messages"],
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
)
for r in client.sync_generate_stream(request):
yield r
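After the refactor, spark_generate_stream is a thin adapter over SparkLLMClient (defined below). A minimal sketch of how a worker might drive it; the params keys mirror those read above, while the ProxyModel value and the device argument are placeholders:

# Hypothetical call site; `model` is a ProxyModel whose proxy_llm_client
# is a SparkLLMClient, and the params keys mirror those read above.
params = {
    "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    "temperature": 0.7,
    "max_new_tokens": 512,
    "request_id": "req-1",
    "user_name": "alice",
}
for output in spark_generate_stream(model, None, params, device="cpu"):
    print(output.text)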
def get_response(request_url, data):
@@ -107,8 +69,8 @@ def get_response(request_url, data):
result += text[0]["content"]
if choices.get("status") == 2:
break
except Exception:
break
except Exception as e:
raise e
yield result
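For reference, the shape of a single websocket frame that get_response consumes; field names follow iFlytek's published Spark protocol, and the values here are illustrative:

# One streamed frame, abridged; get_response accumulates
# payload.choices.text[0].content and stops once status == 2.
frame = {
    "header": {"code": 0, "message": "Success", "status": 1},
    "payload": {
        "choices": {
            "status": 1,  # 2 marks the final frame of the answer
            "seq": 3,
            "text": [{"content": "Hello", "role": "assistant", "index": 0}],
        }
    },
}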
@@ -155,3 +117,103 @@ class SparkAPI:
url = self.spark_url + "?" + urlencode(v)
        # Print the URL generated when the connection is established; when
        # following this demo, uncomment the print above to check that the URL
        # generated with the same parameters matches the one your code produces.
return url
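The body of gen_url is elided by this hunk; for orientation, here is a minimal sketch of the HMAC-SHA256 URL signing it performs, reconstructed from the imports at the top of the file and iFlytek's documented websocket auth scheme (an approximation, not the class's verbatim code):

def build_auth_url(spark_url, api_key, api_secret):
    # Sign "host + date + request-line" with HMAC-SHA256 and pass the
    # base64-encoded result as query parameters on the websocket URL.
    parsed = urlparse(spark_url)
    date = format_date_time(mktime(datetime.now().timetuple()))
    origin = f"host: {parsed.netloc}\ndate: {date}\nGET {parsed.path} HTTP/1.1"
    signature = base64.b64encode(
        hmac.new(api_secret.encode(), origin.encode(), hashlib.sha256).digest()
    ).decode()
    authorization = base64.b64encode(
        (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        ).encode()
    ).decode()
    v = {"authorization": authorization, "date": date, "host": parsed.netloc}
    return spark_url + "?" + urlencode(v)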
class SparkLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
else:
api_base = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
if not api_domain:
api_domain = domain
self._model = model
self.default_model = self._model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
super().__init__(
model_names=[model, model_alias],
context_length=context_length,
executor=executor,
)
@classmethod
def new_client(
cls,
model_params: ProxyModelParameters,
default_executor: Optional[Executor] = None,
) -> "SparkLLMClient":
return cls(
model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
)
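    # Example (illustrative): construct the client directly; omit the keyword
    # arguments to fall back to the XUNFEI_SPARK_APPID / XUNFEI_SPARK_API_KEY /
    # XUNFEI_SPARK_API_SECRET environment variables read in __init__:
    #
    #     client = SparkLLMClient(
    #         model="v3",
    #         app_id="my-app-id",
    #         api_key="my-api-key",
    #         api_secret="my-api-secret",
    #     )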
def sync_generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
}
spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
try:
for text in get_response(request_url, data):
yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            # `return` inside a generator would silently end the stream;
            # yield the error payload so callers can surface it.
            yield ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
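End to end, a minimal streaming call against a client built as above; the message content, temperature, and request_id are illustrative:

context = ModelRequestContext(stream=True, request_id="req-1")
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    temperature=0.7,
    context=context,
)
for output in client.sync_generate_stream(request):
    print(output.text)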