refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -1,46 +1,14 @@
from typing import List
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

CHATGLM_DEFAULT_MODEL = "chatglm_pro"
def __convert_2_zhipu_messages(messages: List[ModelMessage]):
    chat_round = 0
    wenxin_messages = []
    last_usr_message = ""
    system_messages = []
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            last_usr_message = message.content
        elif message.role == ModelMessageRoleType.SYSTEM:
            system_messages.append(message.content)
        elif message.role == ModelMessageRoleType.AI:
            last_ai_message = message.content
            wenxin_messages.append({"role": "user", "content": last_usr_message})
            wenxin_messages.append({"role": "assistant", "content": last_ai_message})
    # build the last user message
    if len(system_messages) > 0:
        if len(system_messages) > 1:
            end_message = system_messages[-1]
        else:
            last_message = messages[-1]
            if last_message.role == ModelMessageRoleType.HUMAN:
                end_message = system_messages[-1] + "\n" + last_message.content
            else:
                end_message = system_messages[-1]
    else:
        last_message = messages[-1]
        end_message = last_message.content
    wenxin_messages.append({"role": "user", "content": end_message})
    return wenxin_messages, system_messages
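
For illustration, a minimal sketch of what this converter returns for a short exchange (an assumption-level example: ModelMessage is constructed directly with role and content, as dbgpt.core allows):

# Sketch: expected output of __convert_2_zhipu_messages for a short chat.
# Assumes ModelMessage(role=..., content=...) construction from dbgpt.core.
history = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
prompt, systems = __convert_2_zhipu_messages(history)
# prompt == [
#     {"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello!"},
#     {"role": "user", "content": "You are helpful.\nWhat is DB-GPT?"},
# ]
# systems == ["You are helpful."]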
def zhipu_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -48,27 +16,93 @@ def zhipu_generate_stream(
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")
    # TODO: should proxy models use a unified config?
    proxy_api_key = model_params.proxy_api_key
    proxyllm_backend = model_params.proxyllm_backend or CHATGLM_DEFAULT_MODEL

    import zhipuai

    zhipuai.api_key = proxy_api_key
    messages: List[ModelMessage] = params["messages"]
    # TODO: Support the convert_to_compatible_format config; zhipu does not support system messages
    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    history, systems = __convert_2_zhipu_messages(messages)
    res = zhipuai.model_api.sse_invoke(
        model=proxyllm_backend,
        prompt=history,
    )
    # convert_to_compatible_format = params.get("convert_to_compatible_format", False)
    # history, systems = __convert_2_zhipu_messages(messages)
    client: ZhipuLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        top_p=params.get("top_p"),
        incremental=False,
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in res.events():
        if r.event == "add":
            yield r.data
    for r in client.sync_generate_stream(request):
        yield r
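
The refactored body above no longer talks to the Zhipu SDK directly; it builds a ModelRequest and delegates streaming to the proxy client. A minimal driver sketch, assuming model is a ProxyModel whose proxy_llm_client was initialized by the model worker (the params keys mirror those read above; values are placeholders):

# Sketch: driving zhipu_generate_stream by hand (placeholders throughout).
params = {
    "messages": [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello, Zhipu!")
    ],
    "temperature": 0.7,
    "max_new_tokens": 256,
    "user_name": "demo-user",
}
for output in zhipu_generate_stream(model, None, params, device="cpu"):
    print(output)  # each item comes from the new client path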
class ZhipuLLMClient(ProxyLLMClient):
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        model_alias: Optional[str] = "zhipu_proxyllm",
        context_length: Optional[int] = 8192,
        executor: Optional[Executor] = None,
    ):
        try:
            import zhipuai
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: zhipuai. "
                "Please install it with `pip install zhipuai`."
            ) from exc
        if not model:
            model = CHATGLM_DEFAULT_MODEL
        if api_key:
            zhipuai.api_key = api_key
        self._model = model
        self.default_model = self._model

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )
    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ZhipuLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )
    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        import zhipuai

        request = self.local_covert_message(request, message_converter)

        messages = request.to_common_messages(support_system_role=False)
        model = request.model or self._model
        try:
            res = zhipuai.model_api.sse_invoke(
                model=model,
                prompt=messages,
                temperature=request.temperature,
                # top_p=request.top_p,
                incremental=False,
            )
            for r in res.events():
                if r.event == "add":
                    yield ModelOutput(text=r.data, error_code=0)
                elif r.event == "error":
                    yield ModelOutput(text=r.data, error_code=1)
        except Exception as e:
            # In a generator, `return` would swallow the error output; yield it instead.
            yield ModelOutput(
                text=f"**LLMServer Generate Error, please check the error info**: {e}",
                error_code=1,
            )
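
For reference, a minimal end-to-end sketch of using the new client outside the worker plumbing (the API key is a placeholder; assumes network access to the Zhipu API and that ModelRequest.build_request accepts a ModelMessage list, as in the worker path above):

# Sketch: direct use of ZhipuLLMClient (placeholder API key).
client = ZhipuLLMClient(model="chatglm_pro", api_key="YOUR_API_KEY")
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi there")],
    temperature=0.7,
)
for output in client.sync_generate_stream(request):
    print(output.text)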