DB-GPT/dbgpt/model/proxy/base.py

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import cache
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional

from dbgpt.core import (
    LLMClient,
    MessageConverter,
    ModelMetadata,
    ModelOutput,
    ModelRequest,
)
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.util.executor_utils import blocking_func_to_async

if TYPE_CHECKING:
    from tiktoken import Encoding

logger = logging.getLogger(__name__)


class ProxyTokenizer(ABC):
    """Abstract tokenizer used by proxy LLM clients to count prompt tokens."""

    @abstractmethod
    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        """Count tokens of the given prompts.

        Args:
            model_name (str): model name
            prompts (List[str]): prompts to count tokens for

        Returns:
            List[int]: token counts, -1 for a prompt if counting failed
        """


class TiktokenProxyTokenizer(ProxyTokenizer):
    """ProxyTokenizer implementation backed by tiktoken."""

    def __init__(self):
        self._cache = {}

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        encoding_model = self._get_or_create_encoding_model(model_name)
        if not encoding_model:
            return [-1] * len(prompts)
        return [
            len(encoding_model.encode(prompt, disallowed_special=()))
            for prompt in prompts
        ]

    def _get_or_create_encoding_model(self, model_name: str) -> Optional[Encoding]:
        if model_name in self._cache:
            return self._cache[model_name]
        encoding_model = None
        try:
            import tiktoken

            logger.info(
                "tiktoken installed, using it to count tokens. tiktoken downloads the "
                "tokenizer from the network; you can also download it manually and put it "
                "in the directory pointed to by the TIKTOKEN_CACHE_DIR environment variable."
            )
        except ImportError:
            self._support_encoding = False
            logger.warning("tiktoken not installed, cannot count tokens")
            return None
        try:
            if not model_name:
                model_name = "gpt-3.5-turbo"
            encoding_model = tiktoken.model.encoding_for_model(model_name)
        except KeyError:
            logger.warning(
                f"{model_name}'s tokenizer not found, using cl100k_base encoding."
            )
            # Fall back to the cl100k_base encoding announced in the warning above.
            encoding_model = tiktoken.get_encoding("cl100k_base")
        if encoding_model:
            self._cache[model_name] = encoding_model
        return encoding_model
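

# A hedged usage sketch for TiktokenProxyTokenizer (kept as a comment so that
# importing this module never triggers a tokenizer download); the model name and
# prompt are illustrative only:
#
#     tokenizer = TiktokenProxyTokenizer()
#     counts = tokenizer.count_token("gpt-3.5-turbo", ["Hello, world!"])
#     # counts[0] is the token count, or -1 if tiktoken is not installed.


# A minimal sketch of a custom ProxyTokenizer, assuming a rough whitespace-based
# count is acceptable when no real tokenizer is available; the class name
# _WhitespaceProxyTokenizer is hypothetical and only illustrates the
# count_token contract defined above.
class _WhitespaceProxyTokenizer(ProxyTokenizer):
    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        # Approximate one token per whitespace-separated word.
        return [len(prompt.split()) for prompt in prompts]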


class ProxyLLMClient(LLMClient):
    """Proxy LLM client base class."""

    executor: Executor
    model_names: List[str]

    def __init__(
        self,
        model_names: List[str],
        context_length: int = 4096,
        executor: Optional[Executor] = None,
        proxy_tokenizer: Optional[ProxyTokenizer] = None,
    ):
        self.model_names = model_names
        self.context_length = context_length
        self.executor = executor or ThreadPoolExecutor()
        self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()

    @classmethod
    @abstractmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ProxyLLMClient":
        """Create a new client instance from model parameters.

        Args:
            model_params (ProxyModelParameters): model parameters
            default_executor (Optional[Executor], optional): default executor. If your
                model call is blocking, pass a ThreadPoolExecutor.
        """

    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        """Generate model output from a model request.

        We strongly recommend implementing this method instead of sync_generate
        for better performance.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message converter.
                Defaults to None.

        Returns:
            ModelOutput: model output
        """
        return await blocking_func_to_async(
            self.executor, self.sync_generate, request, message_converter
        )

    def sync_generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        """Generate model output from a model request synchronously.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message converter.
                Defaults to None.

        Returns:
            ModelOutput: model output (the last chunk of the stream)
        """
        output = None
        for out in self.sync_generate_stream(request, message_converter):
            output = out
        return output

    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        """Generate a model output stream from a model request.

        We strongly recommend implementing this method instead of
        sync_generate_stream for better performance.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message converter.
                Defaults to None.

        Returns:
            AsyncIterator[ModelOutput]: model output stream
        """
        from starlette.concurrency import iterate_in_threadpool

        async for output in iterate_in_threadpool(
            self.sync_generate_stream(request, message_converter)
        ):
            yield output
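
    # A hedged consumption sketch for the streaming API above; ``client`` and
    # ``request`` are assumed to be an initialized subclass instance and a
    # ModelRequest built elsewhere:
    #
    #     async for output in client.generate_stream(request):
    #         print(output.text)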

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        """Generate model output stream from model request.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message converter.
                Defaults to None.

        Returns:
            Iterator[ModelOutput]: model output stream
        """
        raise NotImplementedError()

    async def models(self) -> List[ModelMetadata]:
        """Get model metadata list.

        Returns:
            List[ModelMetadata]: model metadata list
        """
        return self._models()

    @property
    def default_model(self) -> str:
        """Get default model name.

        Returns:
            str: default model name
        """
        return self.model_names[0]

    @cache
    def _models(self) -> List[ModelMetadata]:
        results = []
        for model in self.model_names:
            results.append(
                ModelMetadata(model=model, context_length=self.context_length)
            )
        return results

    def local_covert_message(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelRequest:
        """Convert messages locally.

        Args:
            request (ModelRequest): model request
            message_converter (Optional[MessageConverter], optional): message converter.
                Defaults to None.

        Returns:
            ModelRequest: converted model request
        """
        if not message_converter:
            return request
        # _models is a cached method, so call it to get the metadata list.
        metadata = self._models()[0].ext_metadata
        new_request = request.copy()
        new_messages = message_converter.convert(request.messages, metadata)
        new_request.messages = new_messages
        return new_request
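
    # A minimal sketch of a MessageConverter accepted by this method, kept as a
    # comment and assuming only the two-argument ``convert(messages, metadata)``
    # hook used above; the class name is hypothetical:
    #
    #     class UpperCaseMessageConverter(MessageConverter):
    #         def convert(self, messages, model_metadata=None):
    #             for message in messages:
    #                 message.content = message.content.upper()
    #             return messages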

    async def count_token(self, model: str, prompt: str) -> int:
        """Count tokens of the given prompt.

        Args:
            model (str): model name
            prompt (str): prompt to count tokens for

        Returns:
            int: token count, -1 if counting failed
        """
        counts = await blocking_func_to_async(
            self.executor, self.proxy_tokenizer.count_token, model, [prompt]
        )
        return counts[0]
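

# A hedged end-to-end sketch (illustrative only, not part of the original module).
# A minimal ProxyLLMClient subclass, assuming that implementing ``new_client`` and
# ``sync_generate_stream`` is enough: ``generate``, ``sync_generate`` and
# ``generate_stream`` are derived from them by the base class. The class name
# _EchoProxyLLMClient, the "echo" behaviour, and the assumption that ModelOutput
# accepts ``text``/``error_code`` keyword arguments (as other proxy clients in
# dbgpt do) are all illustrative.
class _EchoProxyLLMClient(ProxyLLMClient):
    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "ProxyLLMClient":
        # Assumption: ProxyModelParameters exposes a ``model_name`` field.
        return cls(model_names=[model_params.model_name], executor=default_executor)

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)
        # Echo the request messages back as a single streamed chunk.
        text = " ".join(str(getattr(m, "content", m)) for m in request.messages)
        yield ModelOutput(text=text, error_code=0)


# Usage would then look like (``request`` built elsewhere as a ModelRequest):
#
#     client = _EchoProxyLLMClient(model_names=["echo"])
#     output = asyncio.run(client.generate(request))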