mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-22 17:39:02 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			81 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			81 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| import logging
 | |
| from typing import TYPE_CHECKING, List, Optional, Union
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from dbgpt.core.interface.message import BaseMessage, ModelMessage
 | |
| 
 | |
logger = logging.getLogger(__name__)


class ProxyTokenizerWrapper:
    """Token counter backed by a lazily created, cached tiktoken encoding.

    The encoding is resolved on the first successful call and reused
    afterwards, so the ``model_name`` passed to later calls has no effect.
    If ``tiktoken`` cannot be imported, counting is disabled for this
    instance and every call returns ``-1``.
    """

    def __init__(self) -> None:
        # Set to False once importing tiktoken has failed, so later calls
        # return -1 immediately instead of retrying the import.
        self._support_encoding = True
        # Cached encoding object; created on first use.
        self._encoding_model = None

    def count_token(
        self,
        messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]],
        model_name: Optional[str] = None,
    ) -> int:
        """Count token of given messages

        Args:
            messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token
            model_name (Optional[str], optional): model name. Defaults to None.
                Only used when the encoding has not been created yet.

        Returns:
            int: token count, -1 if failed
        """
        if not self._support_encoding:
            logger.warning(
                "model does not support encoding model, can't count token, returning -1"
            )
            return -1
        encoding = self._get_or_create_encoding_model(model_name)
        if encoding is None:
            # tiktoken is unavailable; the helper has already logged the reason.
            return -1

        if isinstance(messages, str):
            return len(encoding.encode(messages, disallowed_special=()))
        if isinstance(messages, list):
            # Elements are not type-checked; each only needs a ``content`` attribute.
            return sum(
                len(encoding.encode(message.content, disallowed_special=()))
                for message in messages
            )

        # The message classes are imported at module level only under
        # TYPE_CHECKING, so they must be imported here before being used in a
        # runtime isinstance check.
        from dbgpt.core.interface.message import BaseMessage, ModelMessage

        if isinstance(messages, (BaseMessage, ModelMessage)):
            return len(encoding.encode(messages.content, disallowed_special=()))

        logger.warning("unsupported type of messages, can't count token, returning -1")
        return -1

    def _get_or_create_encoding_model(self, model_name: Optional[str] = None):
        """Get or create encoding model for given model name

        More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Args:
            model_name (Optional[str], optional): model whose tokenizer is
                looked up. Falls back to "gpt-3.5-turbo" when empty. Ignored
                once an encoding has been cached.

        Returns:
            The cached encoding object, or None when tiktoken is not installed.
        """
        if self._encoding_model is not None:
            return self._encoding_model
        try:
            import tiktoken
        except ImportError:
            self._support_encoding = False
            logger.warning("tiktoken not installed, cannot count tokens, returning -1")
            return None

        logger.info(
            "tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, "
            "also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR"
        )
        if not model_name:
            model_name = "gpt-3.5-turbo"
        try:
            self._encoding_model = tiktoken.model.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to a fixed default encoding.
            logger.warning(
                "%s's tokenizer not found, using cl100k_base encoding.", model_name
            )
            self._encoding_model = tiktoken.get_encoding("cl100k_base")
        return self._encoding_model
 |