feat(model): Proxy model support count token (#996)

Fangyin Cheng
2023-12-29 12:01:31 +08:00
committed by GitHub
parent ba0599ebf4
commit 0cdc77abb2
16 changed files with 366 additions and 248 deletions


@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
def __convert_2_gpt_messages(messages: List[ModelMessage]):
    chat_round = 0
    gpt_messages = []
    last_usr_message = ""
    system_messages = []
    # TODO: We can't change the message order at this low level
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
            last_usr_message = message.content

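The hunk above is truncated after the human-message branch. For context, here is a hedged sketch of how the full conversion loop might pair turns into OpenAI-style messages; the handling of AI and system roles is an assumption about the truncated portion, not the committed code, and the Msg dataclass is a hypothetical stand-in for ModelMessage so the sketch stays self-contained.

from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class Msg:
    # Stand-in for dbgpt.core.interface.message.ModelMessage (assumption).
    role: str  # "system", "human"/"user", or "ai"/"assistant"
    content: str

def convert_2_gpt_messages_sketch(
    messages: List[Msg],
) -> Tuple[List[Dict[str, str]], List[str]]:
    """Sketch: buffer the last human turn, emit user/assistant pairs."""
    gpt_messages: List[Dict[str, str]] = []
    system_messages: List[str] = []
    last_usr_message = ""
    for message in messages:
        if message.role in ("human", "user"):
            last_usr_message = message.content
        elif message.role in ("ai", "assistant"):
            # Flush the buffered user turn together with the assistant reply.
            gpt_messages.append({"role": "user", "content": last_usr_message})
            gpt_messages.append({"role": "assistant", "content": message.content})
        elif message.role == "system":
            system_messages.append(message.content)
    return gpt_messages, system_messages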

@@ -1,9 +1,36 @@
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING
import logging

from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

if TYPE_CHECKING:
    from dbgpt.core.interface.message import ModelMessage, BaseMessage

logger = logging.getLogger(__name__)


class ProxyModel:
    def __init__(self, model_params: ProxyModelParameters) -> None:
        self._model_params = model_params
        self._tokenizer = ProxyTokenizerWrapper()

    def get_params(self) -> ProxyModelParameters:
        return self._model_params

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        model_name: Optional[str] = None,
    ) -> int:
        """Count the tokens in the given messages.

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]):
                messages to count tokens for
            model_name (Optional[str], optional): model name. Defaults to None.

        Returns:
            int: token count, -1 if counting failed
        """
        return self._tokenizer.count_token(messages, model_name)
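Taken together, the new API can be exercised like the sketch below. The ProxyModelParameters field values and the ProxyModel import path are assumptions for illustration; the hunk only confirms the constructor and the count_token signature.

from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.llms.proxy_model import ProxyModel  # import path is an assumption

# Hypothetical parameter values for illustration only.
params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    proxy_server_url="https://api.openai.com/v1/chat/completions",
    proxy_api_key="sk-...",
)
model = ProxyModel(params)

# Returns the token count, or -1 if counting fails (per the docstring above).
n = model.count_token("How many tokens is this?", model_name="gpt-3.5-turbo")
print(n)

Delegating to ProxyTokenizerWrapper keeps the counting logic in one place instead of duplicating it in each proxy LLM adapter.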