mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 23:18:20 +00:00
feat(model): Proxy model support count token (#996)
This commit is contained in:
@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
return output
|
||||
|
||||
def count_token(self, prompt: str) -> int:
|
||||
return _try_to_count_token(prompt, self.tokenizer)
|
||||
return _try_to_count_token(prompt, self.tokenizer, self.model)
|
||||
|
||||
async def async_count_token(self, prompt: str) -> int:
|
||||
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
|
||||
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
|
||||
return metrics
|
||||
|
||||
|
||||
def _try_to_count_token(prompt: str, tokenizer) -> int:
|
||||
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
|
||||
"""Try to count token of prompt
|
||||
|
||||
Args:
|
||||
prompt (str): prompt
|
||||
tokenizer ([type]): tokenizer
|
||||
model ([type]): model
|
||||
|
||||
Returns:
|
||||
int: token count, if error return -1
|
||||
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
|
||||
TODO: More implementation
|
||||
"""
|
||||
try:
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||
|
||||
if isinstance(model, ProxyModel):
|
||||
return model.count_token(prompt)
|
||||
# Only support huggingface model now
|
||||
return len(tokenizer(prompt).input_ids[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"Count token error, detail: {e}, return -1")
|
||||
|
@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
return True
|
||||
else:
|
||||
# TODO Update worker
|
||||
logger.warn(f"Instance {worker_key} exist")
|
||||
logger.warning(f"Instance {worker_key} exist")
|
||||
return False
|
||||
|
||||
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
|
||||
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
)
|
||||
if not success:
|
||||
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
||||
logger.warn(f"{msg}, worker_params: {worker_params}")
|
||||
logger.warning(f"{msg}, worker_params: {worker_params}")
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(msg)
|
||||
supported_types = WorkerType.values()
|
||||
|
Reference in New Issue
Block a user