mirror of https://github.com/csunny/DB-GPT.git
refactor: Refactor proxy LLM (#1064)
dbgpt/model/proxy/base.py (new file, 242 lines)
@@ -0,0 +1,242 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import cache
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional

from dbgpt.core import (
    LLMClient,
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.util.executor_utils import blocking_func_to_async

if TYPE_CHECKING:
    from tiktoken import Encoding

logger = logging.getLogger(__name__)

class ProxyTokenizer(ABC):
    @abstractmethod
    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        """Count the tokens of the given prompts.

        Args:
            model_name (str): model name
            prompts (List[str]): prompts to count tokens for

        Returns:
            List[int]: token counts, -1 for any prompt that could not be counted
        """

class TiktokenProxyTokenizer(ProxyTokenizer):
    def __init__(self):
        self._cache = {}

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        encoding_model = self._get_or_create_encoding_model(model_name)
        if not encoding_model:
            return [-1] * len(prompts)
        return [
            len(encoding_model.encode(prompt, disallowed_special=()))
            for prompt in prompts
        ]

    def _get_or_create_encoding_model(self, model_name: str) -> Optional[Encoding]:
        if model_name in self._cache:
            return self._cache[model_name]
        encoding_model = None
        try:
            import tiktoken

            logger.info(
                "tiktoken installed, using it to count tokens. tiktoken downloads "
                "the tokenizer from the network; alternatively, download it "
                "yourself and place it in the directory named by the "
                "TIKTOKEN_CACHE_DIR environment variable."
            )
        except ImportError:
            logger.warning("tiktoken not installed, cannot count tokens")
            return None
        try:
            if not model_name:
                model_name = "gpt-3.5-turbo"
            encoding_model = tiktoken.model.encoding_for_model(model_name)
        except KeyError:
            logger.warning(
                f"{model_name}'s tokenizer not found, using cl100k_base encoding."
            )
            # Fall back to cl100k_base so callers get a usable encoding.
            encoding_model = tiktoken.get_encoding("cl100k_base")
        if encoding_model:
            self._cache[model_name] = encoding_model
        return encoding_model
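
# Usage sketch (not part of this diff): with tiktoken installed, the tokenizer
# returns real per-prompt counts; without it, count_token falls back to -1 for
# each prompt. The printed values are illustrative.
#
#     tokenizer = TiktokenProxyTokenizer()
#     counts = tokenizer.count_token("gpt-3.5-turbo", ["hello world", "DB-GPT"])
#     print(counts)  # e.g. [2, 4] with tiktoken; [-1, -1] without it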

class ProxyLLMClient(LLMClient):
    """Proxy LLM client base class."""

    executor: Executor
    model_names: List[str]

    def __init__(
        self,
        model_names: List[str],
        context_length: int = 4096,
        executor: Optional[Executor] = None,
        proxy_tokenizer: Optional[ProxyTokenizer] = None,
    ):
        self.model_names = model_names
        self.context_length = context_length
        self.executor = executor or ThreadPoolExecutor()
        self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()

    @classmethod
    @abstractmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ProxyLLMClient":
        """Create a new client instance from model parameters.

        Args:
            model_params (ProxyModelParameters): model parameters
            default_executor (Optional[Executor], optional): default executor; if
                your model client is blocking, pass a ThreadPoolExecutor here.
        """
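
    # Hypothetical call-site sketch (not part of this diff): a model worker
    # could build a concrete client roughly like this. SomeProxyLLMClient and
    # the ProxyModelParameters field names are assumptions for illustration.
    #
    #     params = ProxyModelParameters(
    #         model_name="chatgpt_proxyllm", model_path="chatgpt_proxyllm"
    #     )
    #     client = SomeProxyLLMClient.new_client(params, ThreadPoolExecutor())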

    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        """Generate model output from a model request.

        We strongly recommend overriding this method instead of sync_generate
        for high performance.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message
                converter. Defaults to None.

        Returns:
            ModelOutput: model output
        """
        # Default implementation: offload the blocking sync_generate to the
        # executor so the event loop is not blocked.
        return await blocking_func_to_async(
            self.executor, self.sync_generate, request, message_converter
        )
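
    # A simplified stand-in for blocking_func_to_async, to show the mechanism
    # (an assumption for illustration, not the actual
    # dbgpt.util.executor_utils implementation):
    #
    #     async def blocking_func_to_async(executor, func, *args):
    #         loop = asyncio.get_running_loop()
    #         return await loop.run_in_executor(executor, lambda: func(*args))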

    def sync_generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        """Generate model output from a model request, synchronously.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message
                converter. Defaults to None.

        Returns:
            ModelOutput: model output
        """
        # Drain the synchronous stream and return the last (cumulative) output.
        output = None
        for out in self.sync_generate_stream(request, message_converter):
            output = out
        return output

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        """Generate a model output stream from a model request.

        We strongly recommend overriding this method instead of
        sync_generate_stream for high performance.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message
                converter. Defaults to None.

        Returns:
            AsyncIterator[ModelOutput]: model output stream
        """
        from starlette.concurrency import iterate_in_threadpool

        # Default implementation: iterate the blocking generator in a thread
        # pool so each step of the stream stays off the event loop.
        async for output in iterate_in_threadpool(
            self.sync_generate_stream(request, message_converter)
        ):
            yield output
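
    # Consumption sketch (not part of this diff), given a concrete subclass
    # instance `client` and a prepared `request`:
    #
    #     async for output in client.generate_stream(request):
    #         print(output.text)  # each ModelOutput carries the text so far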

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        """Generate a model output stream from a model request, synchronously.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message
                converter. Defaults to None.

        Returns:
            Iterator[ModelOutput]: model output stream
        """
        raise NotImplementedError()

    async def models(self) -> List[ModelMetadata]:
        """Get the model metadata list.

        Returns:
            List[ModelMetadata]: model metadata list
        """
        return self._models()

    @cache
    def _models(self) -> List[ModelMetadata]:
        # Cached per client instance via functools.cache.
        results = []
        for model in self.model_names:
            results.append(
                ModelMetadata(model=model, context_length=self.context_length)
            )
        return results

    def local_covert_message(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelRequest:
        """Convert messages locally before dispatching the request.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message
                converter. Defaults to None.

        Returns:
            ModelRequest: converted model request
        """
        if not message_converter:
            return request
        # _models is a method; it must be called to get the metadata list.
        metadata = self._models()[0].ext_metadata
        new_request = request.copy()
        new_messages = message_converter.convert(request.messages, metadata)
        new_request.messages = new_messages
        return new_request
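
    # Converter sketch (an assumption for illustration; it presumes
    # MessageConverter only requires the convert(messages, model_metadata)
    # method used above):
    #
    #     class UpperCaseConverter(MessageConverter):
    #         def convert(self, messages, model_metadata=None):
    #             for message in messages:
    #                 message.content = message.content.upper()
    #             return messages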

    async def count_token(self, model: str, prompt: str) -> int:
        """Count the tokens of the given prompt.

        Args:
            model (str): model name
            prompt (str): prompt to count tokens for

        Returns:
            int: token count, -1 if counting failed
        """
        # Await the coroutine first; subscripting it directly (as in
        # `await blocking_func_to_async(...)[0]`) would raise a TypeError.
        counts = await blocking_func_to_async(
            self.executor, self.proxy_tokenizer.count_token, model, [prompt]
        )
        return counts[0]
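
Putting the pieces together, a minimal runnable sketch of a concrete subclass (hypothetical, not part of this commit; the ProxyModelParameters field name and the ModelMessage construction are assumptions about the dbgpt.core API):

import asyncio
from concurrent.futures import Executor
from typing import Iterator, Optional

from dbgpt.core import MessageConverter, ModelMessage, ModelOutput, ModelRequest
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient


class EchoLLMClient(ProxyLLMClient):
    """Toy client that streams the last message back, word by word."""

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "EchoLLMClient":
        return cls(model_names=[model_params.model_name], executor=default_executor)

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        partial = ""
        for word in request.messages[-1].content.split():
            partial = f"{partial} {word}".strip()
            # Yield cumulative text so sync_generate returns the full echo.
            yield ModelOutput(text=partial, error_code=0)


async def main():
    client = EchoLLMClient(model_names=["echo-model"])
    request = ModelRequest(
        model="echo-model",
        messages=[ModelMessage(role="human", content="hello proxy world")],
    )
    print(await client.generate(request))  # last cumulative ModelOutput
    async for output in client.generate_stream(request):
        print(output.text)  # "hello", "hello proxy", "hello proxy world"
    # Returns -1 when tiktoken is not installed.
    print(await client.count_token("echo-model", "hello proxy world"))


asyncio.run(main())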