refactor: Refactor proxy LLM (#1064)

Author: Fangyin Cheng (committed by GitHub)
Date:   2024-01-14 21:01:37 +08:00
Parent: a035433170
Commit: 22bfd01c4b

95 changed files with 2049 additions and 1294 deletions


@@ -3,34 +3,21 @@ from __future__ import annotations

import importlib.metadata as metadata
import logging
import os
from abc import ABC
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

from dbgpt._private.pydantic import model_to_json
from dbgpt.component import ComponentType
from dbgpt.core.awel import BaseOperator, TransformStreamAbsOperator
from dbgpt.core.interface.llm import (
    LLMClient,
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)
from dbgpt.core.awel import TransformStreamAbsOperator
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.operator import BaseLLM
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
    import httpx
@@ -101,14 +88,14 @@ def _initialize_openai_v1(init_params: OpenAIParameters):
    return openai_params, api_type, api_version


def _build_openai_client(init_params: OpenAIParameters):
def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]:
    import httpx

    openai_params, api_type, api_version = _initialize_openai_v1(init_params)
    if api_type == "azure":
        from openai import AsyncAzureOpenAI

        return AsyncAzureOpenAI(
        return api_type, AsyncAzureOpenAI(
            api_key=openai_params["api_key"],
            api_version=api_version,
            azure_endpoint=openai_params["base_url"],
@@ -117,149 +104,11 @@ def _build_openai_client(init_params: OpenAIParameters):
    else:
        from openai import AsyncOpenAI

        return AsyncOpenAI(
        return api_type, AsyncOpenAI(
            **openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies)
        )
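For orientation, here is a minimal sketch of how a call site might consume the new (api_type, client) tuple returned by _build_openai_client after this refactor. The OpenAIParameters construction below is an assumption based on the fields visible elsewhere in this diff, not code from the commit itself.

# Hypothetical call site: unpack the (api_type, client) pair introduced by this change.
params = OpenAIParameters(api_key="sk-...", api_base="https://api.openai.com/v1")
api_type, client = _build_openai_client(init_params=params)
# api_type distinguishes the Azure client from the standard AsyncOpenAI client.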
class OpenAILLMClient(LLMClient):
    """An implementation of LLMClient using OpenAI API.

    In order to have as few dependencies as possible, we directly use the http API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        model: Optional[str] = "gpt-3.5-turbo",
        proxies: Optional["ProxiesTypes"] = None,
        timeout: Optional[int] = 240,
        model_alias: Optional[str] = "chatgpt_proxyllm",
        context_length: Optional[int] = 8192,
        openai_client: Optional["ClientType"] = None,
        openai_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self._init_params = OpenAIParameters(
            api_type=api_type,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            proxies=proxies,
        )
        self._model = model
        self._proxies = proxies
        self._timeout = timeout
        self._model_alias = model_alias
        self._context_length = context_length
        self._client = openai_client
        self._openai_kwargs = openai_kwargs or {}
        self._tokenizer = ProxyTokenizerWrapper()

    @property
    def client(self) -> ClientType:
        if self._client is None:
            self._client = _build_openai_client(init_params=self._init_params)
        return self._client
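A small note on the lazy client property above: the SDK client is only built on first access, so a caller can also inject a pre-configured client and bypass _build_openai_client entirely. A hypothetical sketch (the API key value is a placeholder):

from openai import AsyncOpenAI

# Hypothetical: supply a ready-made client via the openai_client parameter.
pre_built = AsyncOpenAI(api_key="sk-...")
llm_client = OpenAILLMClient(openai_client=pre_built)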
    def _build_request(
        self, request: ModelRequest, stream: Optional[bool] = False
    ) -> Dict[str, Any]:
        payload = {"model": request.model or self._model, "stream": stream}
        # Apply openai kwargs
        for k, v in self._openai_kwargs.items():
            payload[k] = v
        if request.temperature:
            payload["temperature"] = request.temperature
        if request.max_new_tokens:
            payload["max_tokens"] = request.max_new_tokens
        return payload
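As a rough illustration of the mapping _build_request performs, under the assumption that ModelRequest accepts these keyword fields (a real request would also carry the chat messages):

# Illustrative only: how request fields map onto the OpenAI payload.
request = ModelRequest(model="gpt-3.5-turbo", temperature=0.7, max_new_tokens=512)
# _build_request(request, stream=True) then returns roughly:
# {"model": "gpt-3.5-turbo", "stream": True, "temperature": 0.7, "max_tokens": 512}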
    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        request = await self.covert_message(request, message_converter)
        messages = request.to_openai_messages()
        payload = self._build_request(request)
        logger.info(
            f"Send request to openai, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            chat_completion = await self.client.chat.completions.create(
                messages=messages, **payload
            )
            text = chat_completion.choices[0].message.content
            usage = chat_completion.usage.dict()
            return ModelOutput(text=text, error_code=0, usage=usage)
        except Exception as e:
            return ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
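A hypothetical non-streaming usage sketch (request construction elided; a running event loop and a valid API key are assumed):

llm_client = OpenAILLMClient(api_key="sk-...", model="gpt-3.5-turbo")
output = await llm_client.generate(request)  # request: ModelRequest
if output.error_code == 0:
    print(output.text)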
    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        request = await self.covert_message(request, message_converter)
        messages = request.to_openai_messages()
        payload = self._build_request(request, True)
        logger.info(
            f"Send request to openai, payload: {payload}\n\n messages:\n{messages}"
        )
        try:
            chat_completion = await self.client.chat.completions.create(
                messages=messages, **payload
            )
            text = ""
            async for r in chat_completion:
                if len(r.choices) == 0:
                    continue
                if r.choices[0].delta.content is not None:
                    content = r.choices[0].delta.content
                    text += content
                    yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            yield ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
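Note that each yielded ModelOutput carries the accumulated text so far rather than only the latest delta. Continuing the hypothetical sketch above:

async for output in llm_client.generate_stream(request):
    print(output.text)  # the full text generated so far, not an incremental chunk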
    async def models(self) -> List[ModelMetadata]:
        model_metadata = ModelMetadata(
            model=self._model_alias,
            context_length=await self.get_context_length(),
        )
        return [model_metadata]

    async def get_context_length(self) -> int:
        """Get the context length of the model.

        Returns:
            int: The context length.

        # TODO: This is a temporary solution. We should have a better way to get
        #  the context length, e.g. read the real context length from the OpenAI API.
        """
        return self._context_length

    async def count_token(self, model: str, prompt: str) -> int:
        """Count the number of tokens in a given prompt.

        Args:
            model (str): The model name.
            prompt (str): The prompt.

        Returns:
            int: The token count for the prompt.
        """
        return self._tokenizer.count_token(prompt, model)
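And a short hypothetical sketch of the metadata and token helpers (the model name and prompt are placeholders):

n_tokens = await llm_client.count_token("gpt-3.5-turbo", "Hello, world!")
models = await llm_client.models()  # a single ModelMetadata entry for the configured alias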
class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
"""Transform ModelOutput to openai stream format."""