feat(model): Proxy model supports token counting (#996)

This commit is contained in:
Fangyin Cheng
2023-12-29 12:01:31 +08:00
committed by GitHub
parent ba0599ebf4
commit 0cdc77abb2
16 changed files with 366 additions and 248 deletions

View File

@@ -189,7 +189,7 @@ class DefaultModelWorker(ModelWorker):
return output
def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer)
return _try_to_count_token(prompt, self.tokenizer, self.model)
async def async_count_token(self, prompt: str) -> int:
# TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async
@@ -454,12 +454,13 @@ def _new_metrics_from_model_output(
return metrics
def _try_to_count_token(prompt: str, tokenizer) -> int:
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
model ([type]): model
Returns:
int: token count, if error return -1
@@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int:
TODO: More implementation
"""
try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only support huggingface model now
return len(tokenizer(prompt).input_ids[0])
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")

View File

@@ -197,7 +197,7 @@ class LocalWorkerManager(WorkerManager):
return True
else:
# TODO Update worker
logger.warn(f"Instance {worker_key} exist")
logger.warning(f"Instance {worker_key} exist")
return False
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
@@ -229,7 +229,7 @@ class LocalWorkerManager(WorkerManager):
)
if not success:
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
logger.warn(f"{msg}, worker_params: {worker_params}")
logger.warning(f"{msg}, worker_params: {worker_params}")
self._remove_worker(worker_params)
raise Exception(msg)
supported_types = WorkerType.values()