mirror of https://github.com/csunny/DB-GPT.git, synced 2025-07-28 14:27:20 +00:00
fix: Fix count token error

parent 3f513973bc
commit c5ef02bf91
@@ -15,10 +15,12 @@ from dbgpt.core.interface.parameter import (
     BaseDeployModelParameters,
     LLMDeployModelParameters,
 )
-from dbgpt.model.adapter.base import LLMModelAdapter, ModelType
+from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.loader import ModelLoader
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
+from dbgpt.model.proxy.base import TiktokenProxyTokenizer
+from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import _get_dict_from_obj
 from dbgpt.util.system_utils import get_system_info
@@ -43,6 +45,8 @@ class DefaultModelWorker(ModelWorker):
         self._support_generate_func = False
         self.context_len = 4096
         self._device = get_device()
+        # Use tiktoken to count token if model doesn't support
+        self._tiktoken = TiktokenProxyTokenizer()

     def load_worker(
         self, model_name: str, deploy_model_params: BaseDeployModelParameters, **kwargs
@@ -241,11 +245,9 @@ class DefaultModelWorker(ModelWorker):
         return output

     def count_token(self, prompt: str) -> int:
-        return _try_to_count_token(prompt, self.tokenizer, self.model)
+        return _try_to_count_token(prompt, self.tokenizer, self.model, self._tiktoken)

     async def async_count_token(self, prompt: str) -> int:
-        # TODO if we deploy the model by vllm, it can't work, we should run
-        # transformer _try_to_count_token to async
         from dbgpt.model.proxy.llms.proxy_model import ProxyModel

         if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
@@ -253,9 +255,10 @@ class DefaultModelWorker(ModelWorker):
                 self.model.proxy_llm_client.default_model, prompt
             )

-        if self._model_params.provider == ModelType.VLLM:
-            return _try_to_count_token(prompt, self.tokenizer, self.model)
-        raise NotImplementedError
+        cnt = await blocking_func_to_async_no_executor(
+            _try_to_count_token, prompt, self.tokenizer, self.model, self._tiktoken
+        )
+        return cnt

     def get_model_metadata(self, params: Dict) -> ModelMetadata:
         ext_metadata = ModelExtraMedata(
@@ -597,7 +600,9 @@ def _new_metrics_from_model_output(
     return metrics


-def _try_to_count_token(prompt: str, tokenizer, model) -> int:
+def _try_to_count_token(
+    prompt: str, tokenizer, model, tiktoken: TiktokenProxyTokenizer
+) -> int:
     """Try to count token of prompt

     Args:
@@ -615,11 +620,11 @@ def _try_to_count_token(prompt: str, tokenizer, model) -> int:

         if isinstance(model, ProxyModel):
             return model.count_token(prompt)
-        # Only support huggingface model now
-        return len(tokenizer(prompt).input_ids[0])
-    except Exception as e:
-        logger.warning(f"Count token error, detail: {e}, return -1")
-        return -1
+        # Only support huggingface and vllm model now
+        return len(tokenizer([prompt]).input_ids[0])
+    except Exception as _e:
+        logger.warning("Failed to count token, try tiktoken")
+        return tiktoken.count_token("cl100k_base", [prompt])[0]


 def _try_import_torch():
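Taken together, the patched _try_to_count_token now tries the model's own tokenizer first and, instead of logging the error and returning -1, falls back to counting with tiktoken. A minimal standalone sketch of that fallback behaviour, assuming the TiktokenProxyTokenizer used above ultimately wraps tiktoken's cl100k_base encoding; count_tokens_with_fallback is an illustrative name, not part of the DB-GPT API:

# Illustrative sketch only: assumes tiktoken's "cl100k_base" encoding is an
# acceptable approximation when the model's own tokenizer is unavailable
# (e.g. a proxy or vLLM deployment). Not the actual DB-GPT implementation.
import logging

import tiktoken

logger = logging.getLogger(__name__)


def count_tokens_with_fallback(prompt: str, tokenizer) -> int:
    """Count tokens with the model tokenizer, falling back to tiktoken."""
    try:
        # Hugging Face tokenizers accept a batch; the first row of input_ids
        # corresponds to the single prompt passed in.
        return len(tokenizer([prompt]).input_ids[0])
    except Exception:
        logger.warning("Failed to count token, try tiktoken")
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(prompt))

Compared with the old behaviour, callers such as prompt-length checks now always receive a usable, if approximate, token count rather than -1.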
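The async path also changes: rather than raising NotImplementedError for non-proxy models, async_count_token now hands the synchronous counter to blocking_func_to_async_no_executor. A minimal sketch of that pattern, assuming the helper simply offloads the blocking call to the running loop's default thread pool; run_blocking below is an illustrative stand-in, not the DB-GPT helper:

# Illustrative sketch: offload CPU-bound token counting so it does not block
# the event loop; run_blocking approximates what a helper like
# blocking_func_to_async_no_executor is assumed to do.
import asyncio
import functools
from typing import Any, Callable


async def run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    loop = asyncio.get_running_loop()
    # None selects the loop's default ThreadPoolExecutor.
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


async def async_count_token(prompt: str, tokenizer) -> int:
    # Reuses the fallback counter sketched above; tokenizing a long prompt can
    # be slow, so keep it off the event loop thread.
    return await run_blocking(count_tokens_with_fallback, prompt, tokenizer)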