From c5ef02bf91876d5ebe098e5e8d1341d2207b183b Mon Sep 17 00:00:00 2001
From: Fangyin Cheng
Date: Mon, 31 Mar 2025 06:56:08 +0800
Subject: [PATCH] fix: Fix count token error

---
 .../model/cluster/worker/default_worker.py | 31 +++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py b/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py
index 5c10b0843..61a1c75b9 100644
--- a/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py
+++ b/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py
@@ -15,10 +15,12 @@ from dbgpt.core.interface.parameter import (
     BaseDeployModelParameters,
     LLMDeployModelParameters,
 )
-from dbgpt.model.adapter.base import LLMModelAdapter, ModelType
+from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.loader import ModelLoader
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.proxy.base import TiktokenProxyTokenizer
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import _get_dict_from_obj
 from dbgpt.util.system_utils import get_system_info
@@ -43,6 +45,8 @@ class DefaultModelWorker(ModelWorker):
         self._support_generate_func = False
         self.context_len = 4096
         self._device = get_device()
+        # Use tiktoken to count token if model doesn't support
+        self._tiktoken = TiktokenProxyTokenizer()
 
     def load_worker(
         self, model_name: str, deploy_model_params: BaseDeployModelParameters, **kwargs
@@ -241,11 +245,9 @@ class DefaultModelWorker(ModelWorker):
         return output
 
     def count_token(self, prompt: str) -> int:
-        return _try_to_count_token(prompt, self.tokenizer, self.model)
+        return _try_to_count_token(prompt, self.tokenizer, self.model, self._tiktoken)
 
     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run
-        #  transformer _try_to_count_token to async
         from dbgpt.model.proxy.llms.proxy_model import ProxyModel
 
         if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
@@ -253,9 +255,10 @@
                 self.model.proxy_llm_client.default_model, prompt
             )
 
-        if self._model_params.provider == ModelType.VLLM:
-            return _try_to_count_token(prompt, self.tokenizer, self.model)
-        raise NotImplementedError
+        cnt = await blocking_func_to_async_no_executor(
+            _try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken
+        )
+        return cnt
 
     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         ext_metadata = ModelExtraMedata(
@@ -597,7 +600,9 @@ def _new_metrics_from_model_output(
     return metrics
 
 
-def _try_to_count_token(prompt: str, tokenizer, model) -> int:
+def _try_to_count_token(
+    prompt: str, tokenizer, model, tiktoken: TiktokenProxyTokenizer
+) -> int:
     """Try to count token of prompt
 
     Args:
@@ -615,11 +620,11 @@ def _try_to_count_token(prompt: str, tokenizer, model) -> int:
 
         if isinstance(model, ProxyModel):
             return model.count_token(prompt)
-        # Only support huggingface model now
-        return len(tokenizer(prompt).input_ids[0])
-    except Exception as e:
-        logger.warning(f"Count token error, detail: {e}, return -1")
-        return -1
+        # Only support huggingface and vllm model now
+        return len(tokenizer([prompt]).input_ids[0])
+    except Exception as _e:
+        logger.warning("Failed to count token, try tiktoken")
+        return tiktoken.count_token("cl100k_base", [prompt])[0]
 
 
 def _try_import_torch():
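Note on the fallback path: the sketch below shows the same "model tokenizer first, tiktoken as fallback" counting pattern in isolation. It is only an illustration under stated assumptions; it calls the tiktoken package directly instead of DB-GPT's TiktokenProxyTokenizer wrapper, and the helper name count_tokens_with_fallback is hypothetical, not part of the DB-GPT API.

    import logging

    import tiktoken

    logger = logging.getLogger(__name__)


    def count_tokens_with_fallback(prompt: str, tokenizer) -> int:
        """Count tokens with the model's own tokenizer, falling back to tiktoken.

        Hypothetical helper; mirrors the fallback added in _try_to_count_token.
        """
        try:
            # HuggingFace-style tokenizers return a BatchEncoding whose
            # input_ids is a list of token-id lists, one per input string.
            return len(tokenizer([prompt]).input_ids[0])
        except Exception:
            logger.warning("Failed to count token, try tiktoken")
            # cl100k_base only approximates non-OpenAI vocabularies, but a
            # rough count is better than failing with -1.
            encoding = tiktoken.get_encoding("cl100k_base")
            return len(encoding.encode(prompt))

In the worker itself the blocking count is additionally wrapped with blocking_func_to_async_no_executor in async_count_token, so CPU-bound tokenization no longer blocks the event loop and the old NotImplementedError path for non-vLLM local models goes away.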