fix: Fix count token error

Fangyin Cheng 2025-03-31 06:56:08 +08:00
parent 3f513973bc
commit c5ef02bf91


@@ -15,10 +15,12 @@ from dbgpt.core.interface.parameter import (
     BaseDeployModelParameters,
     LLMDeployModelParameters,
 )
-from dbgpt.model.adapter.base import LLMModelAdapter, ModelType
+from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.loader import ModelLoader
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.proxy.base import TiktokenProxyTokenizer
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import _get_dict_from_obj
 from dbgpt.util.system_utils import get_system_info
@@ -43,6 +45,8 @@ class DefaultModelWorker(ModelWorker):
         self._support_generate_func = False
         self.context_len = 4096
         self._device = get_device()
+        # Use tiktoken to count tokens if the model doesn't support it
+        self._tiktoken = TiktokenProxyTokenizer()

     def load_worker(
         self, model_name: str, deploy_model_params: BaseDeployModelParameters, **kwargs
@@ -241,11 +245,9 @@
         return output

     def count_token(self, prompt: str) -> int:
-        return _try_to_count_token(prompt, self.tokenizer, self.model)
+        return _try_to_count_token(prompt, self.tokenizer, self.model, self._tiktoken)

     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run
-        # transformer _try_to_count_token to async
        from dbgpt.model.proxy.llms.proxy_model import ProxyModel

         if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
@@ -253,9 +255,10 @@
                 self.model.proxy_llm_client.default_model, prompt
             )
-        if self._model_params.provider == ModelType.VLLM:
-            return _try_to_count_token(prompt, self.tokenizer, self.model)
-        raise NotImplementedError
+        cnt = await blocking_func_to_async_no_executor(
+            _try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken
+        )
+        return cnt

     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         ext_metadata = ModelExtraMedata(
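
The removed branch handled only vLLM synchronously and raised NotImplementedError for everything else; the new code routes every non-proxy model through blocking_func_to_async_no_executor, so the potentially slow tokenizer call stays off the event loop. A minimal sketch of that offloading pattern, assuming the helper follows the standard run_in_executor idiom (the sketch name below is hypothetical, not DB-GPT's implementation):

import asyncio
import functools
from typing import Any, Callable


async def _sketch_blocking_to_async(func: Callable[..., Any], *args: Any) -> Any:
    # Run a blocking callable on the loop's default ThreadPoolExecutor
    # (executor=None) so the coroutine can await it without blocking.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args))

Under that assumption, the new body is equivalent to cnt = await _sketch_blocking_to_async(_try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken).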
@@ -597,7 +600,9 @@ def _new_metrics_from_model_output(
     return metrics


-def _try_to_count_token(prompt: str, tokenizer, model) -> int:
+def _try_to_count_token(
+    prompt: str, tokenizer, model, tiktoken: TiktokenProxyTokenizer
+) -> int:
     """Try to count token of prompt

     Args:
@@ -615,11 +620,11 @@ def _try_to_count_token(prompt: str, tokenizer, model) -> int:
         if isinstance(model, ProxyModel):
             return model.count_token(prompt)
-        # Only support huggingface model now
-        return len(tokenizer(prompt).input_ids[0])
-    except Exception as e:
-        logger.warning(f"Count token error, detail: {e}, return -1")
-        return -1
+        # Only support huggingface and vllm model now
+        return len(tokenizer([prompt]).input_ids[0])
+    except Exception as _e:
+        logger.warning("Failed to count token, try tiktoken")
+        return tiktoken.count_token("cl100k_base", [prompt])[0]


 def _try_import_torch():
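
The key behavioral change is in the except branch: instead of returning -1 when the model's tokenizer fails, the worker now falls back to a tiktoken-based count. A minimal sketch of what such a fallback plausibly does with the tiktoken package, assuming count_token(model_name, prompts) resolves an encoding and returns one count per prompt (the class below is a hypothetical stand-in, not DB-GPT's TiktokenProxyTokenizer):

from typing import List

import tiktoken


class SketchTiktokenTokenizer:
    """Hypothetical stand-in for TiktokenProxyTokenizer (assumption)."""

    def count_token(self, model_name: str, prompts: List[str]) -> List[int]:
        try:
            # Resolve a registered model name to its encoding.
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # "cl100k_base" is an encoding name, not a model name, so it
            # lands here; use it directly as the fallback encoding.
            encoding = tiktoken.get_encoding("cl100k_base")
        return [len(encoding.encode(p)) for p in prompts]


# Usage mirroring the new fallback:
# SketchTiktokenTokenizer().count_token("cl100k_base", [prompt])[0]

Returning an approximate tiktoken count instead of -1 keeps downstream prompt-length checks usable even when the model's own tokenizer raises, at the cost of an estimate rather than an exact count.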