refactor: Refactor proxy LLM (#1064)

Fangyin Cheng authored 2024-01-14 21:01:37 +08:00, committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -1,79 +1,109 @@
 import logging
-from typing import List
+from concurrent.futures import Executor
+from typing import Iterator, Optional
 
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
+from dbgpt.model.parameter import ProxyModelParameters
+from dbgpt.model.proxy.base import ProxyLLMClient
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
 
 logger = logging.getLogger(__name__)
 
 
-def __convert_2_tongyi_messages(messages: List[ModelMessage]):
-    chat_round = 0
-    tongyi_messages = []
-    last_usr_message = ""
-    system_messages = []
-    for message in messages:
-        if message.role == ModelMessageRoleType.HUMAN:
-            last_usr_message = message.content
-        elif message.role == ModelMessageRoleType.SYSTEM:
-            system_messages.append(message.content)
-        elif message.role == ModelMessageRoleType.AI:
-            last_ai_message = message.content
-            tongyi_messages.append({"role": "user", "content": last_usr_message})
-            tongyi_messages.append({"role": "assistant", "content": last_ai_message})
-    if len(system_messages) > 0:
-        if len(system_messages) < 2:
-            tongyi_messages.insert(0, {"role": "system", "content": system_messages[0]})
-            tongyi_messages.append({"role": "user", "content": last_usr_message})
-        else:
-            tongyi_messages.append({"role": "user", "content": system_messages[1]})
-    else:
-        last_message = messages[-1]
-        if last_message.role == ModelMessageRoleType.HUMAN:
-            tongyi_messages.append({"role": "user", "content": last_message.content})
-    return tongyi_messages
-
-
 def tongyi_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    import dashscope
-    from dashscope import Generation
-
-    model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
-
-    proxy_api_key = model_params.proxy_api_key
-    dashscope.api_key = proxy_api_key
-
-    proxyllm_backend = model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = Generation.Models.qwen_turbo  # By Default qwen_turbo
-
-    messages: List[ModelMessage] = params["messages"]
-
-    convert_to_compatible_format = params.get("convert_to_compatible_format", False)
-    if convert_to_compatible_format:
-        history = __convert_2_tongyi_messages(messages)
-    else:
-        history = ModelMessage.to_openai_messages(messages)
-
-    gen = Generation()
-    res = gen.call(
-        proxyllm_backend,
-        messages=history,
-        top_p=params.get("top_p", 0.8),
-        stream=True,
-        result_format="message",
-    )
-    for r in res:
-        if r:
-            if r["status_code"] == 200:
-                content = r["output"]["choices"][0]["message"].get("content")
-                yield content
-            else:
-                content = r["code"] + ":" + r["message"]
-                yield content
+    client: TongyiLLMClient = model.proxy_llm_client
+    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
+    request = ModelRequest.build_request(
+        client.default_model,
+        messages=params["messages"],
+        temperature=params.get("temperature"),
+        context=context,
+        max_new_tokens=params.get("max_new_tokens"),
+    )
+    for r in client.sync_generate_stream(request):
+        yield r
+
+
+class TongyiLLMClient(ProxyLLMClient):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_region: Optional[str] = None,
+        model_alias: Optional[str] = "tongyi_proxyllm",
+        context_length: Optional[int] = 4096,
+        executor: Optional[Executor] = None,
+    ):
+        try:
+            import dashscope
+            from dashscope import Generation
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import python package: dashscope. "
+                "Please install dashscope with `pip install dashscope`."
+            ) from exc
+        if not model:
+            model = Generation.Models.qwen_turbo
+        if api_key:
+            dashscope.api_key = api_key
+        if api_region:
+            dashscope.api_region = api_region
+        self._model = model
+        self.default_model = self._model
+
+        super().__init__(
+            model_names=[model, model_alias],
+            context_length=context_length,
+            executor=executor,
+        )
+
+    @classmethod
+    def new_client(
+        cls,
+        model_params: ProxyModelParameters,
+        default_executor: Optional[Executor] = None,
+    ) -> "TongyiLLMClient":
+        return cls(
+            model=model_params.proxyllm_backend,
+            api_key=model_params.proxy_api_key,
+            model_alias=model_params.model_name,
+            context_length=model_params.max_context_size,
+            executor=default_executor,
+        )
+
+    def sync_generate_stream(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> Iterator[ModelOutput]:
+        from dashscope import Generation
+
+        request = self.local_covert_message(request, message_converter)
+        messages = request.to_common_messages()
+        model = request.model or self._model
+        try:
+            gen = Generation()
+            res = gen.call(
+                model,
+                messages=messages,
+                top_p=0.8,
+                stream=True,
+                result_format="message",
+            )
+            for r in res:
+                if r:
+                    if r["status_code"] == 200:
+                        content = r["output"]["choices"][0]["message"].get("content")
+                        yield ModelOutput(text=content, error_code=0)
+                    else:
+                        content = r["code"] + ":" + r["message"]
+                        yield ModelOutput(text=content, error_code=-1)
+        except Exception as e:
+            # Yield (not return) so the consumer of the generator sees the error.
+            yield ModelOutput(
+                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+                error_code=1,
+            )
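
After this refactor the Tongyi proxy is driven through the generic ProxyLLMClient interface rather than ad-hoc message conversion in the worker function. Below is a minimal usage sketch, not part of the commit: it shows how the new TongyiLLMClient might be called directly. The import path for TongyiLLMClient, the DASHSCOPE_API_KEY environment variable, and the message-construction details are assumptions for illustration; only the constructor arguments, build_request call shape, and ModelOutput fields come from the diff above.

import os

from dbgpt.core import ModelRequest
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
# Assumed import path; the diff does not show the file name.
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient

# Build the client directly; model defaults to qwen_turbo when omitted.
client = TongyiLLMClient(api_key=os.getenv("DASHSCOPE_API_KEY"))

# Mirrors the build_request call made by tongyi_generate_stream above.
request = ModelRequest.build_request(
    client.default_model,
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    temperature=0.7,
)

# sync_generate_stream yields ModelOutput chunks (error_code 0 on success).
for output in client.sync_generate_stream(request):
    print(output.text)

Inside the model worker the client is not constructed by hand: new_client() builds it from ProxyModelParameters, and tongyi_generate_stream simply wraps the incoming params in a ModelRequest and forwards it to sync_generate_stream.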