Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-10 21:39:33 +00:00)
feat(model): Proxy model support count token (#996)
dbgpt/model/utils/token_utils.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:
    from dbgpt.core.interface.message import BaseMessage, ModelMessage

logger = logging.getLogger(__name__)


class ProxyTokenizerWrapper:
    def __init__(self) -> None:
        self._support_encoding = True
        self._encoding_model = None

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        model_name: Optional[str] = None,
    ) -> int:
        """Count tokens of the given messages.

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]):
                messages to count tokens for
            model_name (Optional[str], optional): model name. Defaults to None.

        Returns:
            int: token count, -1 if counting failed
        """
        # Import at call time: the TYPE_CHECKING import above exists only for
        # type annotations and is not available for the isinstance checks below.
        from dbgpt.core.interface.message import BaseMessage, ModelMessage

        if not self._support_encoding:
            logger.warning(
                "model does not support encoding model, can't count token, returning -1"
            )
            return -1
        encoding = self._get_or_create_encoding_model(model_name)
        if encoding is None:
            # tiktoken is not installed; the reason was already logged and
            # self._support_encoding has been disabled.
            return -1
        cnt = 0
        if isinstance(messages, str):
            cnt = len(encoding.encode(messages, disallowed_special=()))
        elif isinstance(messages, (BaseMessage, ModelMessage)):
            cnt = len(encoding.encode(messages.content, disallowed_special=()))
        elif isinstance(messages, list):
            for message in messages:
                cnt += len(encoding.encode(message.content, disallowed_special=()))
        else:
            logger.warning(
                "unsupported type of messages, can't count token, returning -1"
            )
            return -1
        return cnt

    def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
        """Get or create the encoding model for the given model name.

        For details, see:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
        """
        if self._encoding_model:
            return self._encoding_model
        try:
            import tiktoken

            logger.info(
                "tiktoken installed, using it to count tokens; tiktoken downloads its "
                "tokenizer over the network, or you can download it manually and place "
                "it in the directory named by the TIKTOKEN_CACHE_DIR environment variable"
            )
        except ImportError:
            self._support_encoding = False
            logger.warning("tiktoken not installed, cannot count tokens, returning -1")
            return None
        try:
            if not model_name:
                model_name = "gpt-3.5-turbo"
            self._encoding_model = tiktoken.model.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the cl100k_base encoding.
            logger.warning(
                f"{model_name}'s tokenizer not found, using cl100k_base encoding."
            )
            self._encoding_model = tiktoken.get_encoding("cl100k_base")
        return self._encoding_model
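
A minimal usage sketch of the wrapper added by this commit, assuming tiktoken is installed and the import path matches the new file's location; the cache path is a placeholder:

import os

# Optional: point tiktoken at a pre-downloaded tokenizer cache to avoid
# fetching it over the network (see the logger.info hint in the code above).
os.environ["TIKTOKEN_CACHE_DIR"] = "/path/to/tiktoken_cache"

from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper

tokenizer = ProxyTokenizerWrapper()

# Plain strings are counted directly; model_name defaults to "gpt-3.5-turbo"
# when omitted.
n = tokenizer.count_token("Hello, DB-GPT!", model_name="gpt-3.5-turbo")
print(n)  # token count, or -1 if tiktoken is unavailable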