mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 06:30:02 +00:00
feat(model): Proxy model support count token (#996)
This commit is contained in:
@@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
||||
|
||||
|
||||
def __convert_2_gpt_messages(messages: List[ModelMessage]):
|
||||
chat_round = 0
|
||||
gpt_messages = []
|
||||
last_usr_message = ""
|
||||
system_messages = []
|
||||
|
||||
# TODO: We can't change message order in low level
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN or message.role == "user":
|
||||
last_usr_message = message.content
|
||||
|
@@ -1,9 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
from dbgpt.model.parameter import ProxyModelParameters
|
||||
from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.core.interface.message import ModelMessage, BaseMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyModel:
    """Wrapper around a proxy (remote API) model and its configuration.

    Holds the :class:`ProxyModelParameters` the model was launched with and a
    :class:`ProxyTokenizerWrapper` used to approximate token counts locally,
    since proxy models cannot count tokens server-side before a request.
    """

    def __init__(self, model_params: ProxyModelParameters) -> None:
        self._model_params = model_params
        # Local tokenizer wrapper; proxy models delegate token counting to it.
        self._tokenizer = ProxyTokenizerWrapper()

    def get_params(self) -> ProxyModelParameters:
        """Return the parameters this proxy model was configured with."""
        return self._model_params

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        # Fixed: model names are strings (e.g. "gpt-3.5-turbo"), not ints.
        model_name: Optional[str] = None,
    ) -> int:
        """Count tokens of the given messages.

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]):
                messages to count tokens for
            model_name (Optional[str], optional): model name used to select the
                tokenizer/encoding. Defaults to None.

        Returns:
            int: token count, -1 if counting failed
        """
        return self._tokenizer.count_token(messages, model_name)
|
||||
|
Reference in New Issue
Block a user