DB-GPT/dbgpt/model/utils/token_utils.py
2024-01-10 10:39:04 +08:00

81 lines
3.0 KiB
Python

from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING:
from dbgpt.core.interface.message import BaseMessage, ModelMessage
logger = logging.getLogger(__name__)
class ProxyTokenizerWrapper:
def __init__(self) -> None:
self._support_encoding = True
self._encoding_model = None
def count_token(
self,
messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
model_name: Optional[str] = None,
) -> int:
"""Count token of given messages
Args:
messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
model_name (Optional[str], optional): model name. Defaults to None.
Returns:
int: token count, -1 if failed
"""
if not self._support_encoding:
logger.warning(
"model does not support encoding model, can't count token, returning -1"
)
return -1
encoding = self._get_or_create_encoding_model(model_name)
cnt = 0
if isinstance(messages, str):
cnt = len(encoding.encode(messages, disallowed_special=()))
elif isinstance(messages, BaseMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, ModelMessage):
cnt = len(encoding.encode(messages.content, disallowed_special=()))
elif isinstance(messages, list):
for message in messages:
cnt += len(encoding.encode(message.content, disallowed_special=()))
else:
logger.warning(
"unsupported type of messages, can't count token, returning -1"
)
return -1
return cnt
def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
"""Get or create encoding model for given model name
More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if self._encoding_model:
return self._encoding_model
try:
import tiktoken
logger.info(
"tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
"also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
)
except ImportError:
self._support_encoding = False
logger.warn("tiktoken not installed, cannot count tokens, returning -1")
return -1
try:
if not model_name:
model_name = "gpt-3.5-turbo"
self._encoding_model = tiktoken.model.encoding_for_model(model_name)
except KeyError:
logger.warning(
f"{model_name}'s tokenizer not found, using cl100k_base encoding."
)
self._encoding_model = tiktoken.get_encoding("cl100k_base")
return self._encoding_model