# DB-GPT/dbgpt/model/cluster/worker/default_worker.py
import logging
import os
import time
import traceback
from typing import AsyncIterator, Dict, Iterator, List, Optional
from dbgpt.configs.model_config import get_device
from dbgpt.core import (
ModelExtraMedata,
ModelInferenceMetrics,
ModelMetadata,
ModelOutput,
)
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters
from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer
logger = logging.getLogger(__name__)
_torch_imported = False
torch = None
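
# DefaultModelWorker loads a local (or proxy) LLM through the adapter/loader
# stack and serves streaming/non-streaming generation, token counting and
# model metadata queries for the model cluster.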
class DefaultModelWorker(ModelWorker):
def __init__(self) -> None:
self.model = None
self.tokenizer = None
self._model_params = None
self.llm_adapter: LLMModelAdapter = None
self._support_async = False
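
# Resolve the real model path, select an LLM adapter (FastChat-based by
# default, controlled by the USE_FASTCHAT environment variable) and prepare
# the ModelLoader. The model weights are only loaded later, in start().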
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
model_path = model_path[:-1]
model_path = _get_model_real_path(model_name, model_path)
self.model_name = model_name
self.model_path = model_path
model_type = kwargs.get("model_type")
# Temporary configuration; FastChat will be used by default in the future.
use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"
self.llm_adapter = get_llm_model_adapter(
self.model_name,
self.model_path,
use_fastchat=use_fastchat,
model_type=model_type,
)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)
self._support_async = self.llm_adapter.support_async()
logger.info(
f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
)
self.ml: ModelLoader = ModelLoader(
model_path=self.model_path, model_name=self.model_name
)
# Default model context len
self.context_len = 2048
def model_param_class(self) -> ModelParameters:
return self.param_cls
def support_async(self) -> bool:
return self._support_async
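
# Build ModelParameters from environment variables (prefixed with the model
# name or "LLM_") and optional command-line arguments, falling back to the
# auto-detected device when none is configured.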
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefixes=[env_prefix, "LLM_"],
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
model_type=model_type,
)
if not model_params.device:
model_params.device = get_device()
logger.info(
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
)
return model_params
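
# Load the model and tokenizer through the adapter inside a tracing span and
# try to infer the maximum context length, falling back to
# model_params.max_context_size or the 2048 default.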
def start(
self, model_params: ModelParameters = None, command_args: List[str] = None
) -> None:
# Lazy load torch
_try_import_torch()
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
logger.info(f"Begin load model, model params: {model_params}")
metadata = {
"model_name": self.model_name,
"model_path": self.model_path,
"model_type": self.llm_adapter.model_type(),
"llm_adapter": str(self.llm_adapter),
"run_service": SpanTypeRunName.MODEL_WORKER,
"params": _get_dict_from_obj(model_params),
"sys_infos": _get_dict_from_obj(get_system_info()),
}
with root_tracer.start_span(
"DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
):
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
model_max_length = self.llm_adapter.parse_max_length(
self.model, self.tokenizer
)
if model_max_length:
logger.info(
f"Parse model max length {model_max_length} from model {self.model_name}."
)
self.context_len = model_max_length
elif hasattr(model_params, "max_context_size"):
self.context_len = model_params.max_context_size
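
# Release the model and tokenizer and clear the device memory cache.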
def stop(self) -> None:
if not self.model:
logger.warn("Model has been stopped!!")
return
del self.model
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_model_cache(self._model_params.device)
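
# Synchronous streaming generation. Each yielded ModelOutput carries the full
# text produced so far (not just the delta) plus updated inference metrics;
# see _handle_output for the per-chunk processing.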
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
span = root_tracer.start_span(
"DefaultModelWorker.generate_stream", params.get("span_id")
)
try:
(
params,
model_context,
generate_stream_func,
model_span,
) = self._prepare_generate_stream(
params,
span_operation_name="DefaultModelWorker_call.generate_stream_func",
)
previous_response = ""
last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True
context_len = params.get("context_len") or self.context_len
for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len
):
(
model_output,
incremental_output,
output_str,
current_metrics,
) = self._handle_output(
output,
previous_response,
model_context,
last_metrics,
is_first_generate,
)
if is_first_generate:
is_first_generate = False
previous_response = output_str
last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
model_span.end(metadata={"output": previous_response})
span.end()
except Exception as e:
output = self._handle_exception(e)
yield output
span.end(metadata={"error": output.to_dict()})
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
output = None
for out in self.generate_stream(params):
output = out
return output
def count_token(self, prompt: str) -> int:
return _try_to_count_token(prompt, self.tokenizer, self.model)
async def async_count_token(self, prompt: str) -> int:
# TODO: this does not work when the model is deployed with vLLM; the
# transformers-based _try_to_count_token should be made async (e.g. run in an executor).
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
return await self.model.proxy_llm_client.count_token(
self.model.proxy_llm_client.default_model, prompt
)
raise NotImplementedError
def get_model_metadata(self, params: Dict) -> ModelMetadata:
ext_metadata = ModelExtraMedata(
prompt_roles=self.llm_adapter.get_prompt_roles(),
prompt_sep=self.llm_adapter.get_default_message_separator(),
)
return ModelMetadata(
model=self.model_name,
context_length=self.context_len,
ext_metadata=ext_metadata,
)
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
return self.get_model_metadata(params)
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
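
# Asynchronous counterpart of generate_stream, used when the adapter reports
# support_async() and provides an async generate stream function.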
async def async_generate_stream(self, params: Dict) -> AsyncIterator[ModelOutput]:
span = root_tracer.start_span(
"DefaultModelWorker.async_generate_stream", params.get("span_id")
)
try:
(
params,
model_context,
generate_stream_func,
model_span,
) = self._prepare_generate_stream(
params,
span_operation_name="DefaultModelWorker_call.generate_stream_func",
)
previous_response = ""
context_len = params.get("context_len") or self.context_len
last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True
async for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len
):
(
model_output,
incremental_output,
output_str,
current_metrics,
) = self._handle_output(
output,
previous_response,
model_context,
last_metrics,
is_first_generate,
)
if is_first_generate:
is_first_generate = False
previous_response = output_str
last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
model_span.end(metadata={"output": previous_response})
span.end()
except Exception as e:
output = self._handle_exception(e)
yield output
span.end(metadata={"error": output.to_dict()})
async def async_generate(self, params: Dict) -> ModelOutput:
output = None
async for out in self.async_generate_stream(params):
output = out
return output
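
# Apply the adapter's model adaptation to the request params, pick the sync or
# async generate stream function and open a tracing span describing the call.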
def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
params, model_context = self.llm_adapter.model_adaptation(
params,
self.model_name,
self.model_path,
self.tokenizer,
prompt_template=self.ml.prompt_template,
)
stream_type = ""
if self.support_async():
generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
self.model, self.model_path
)
stream_type = "async "
logger.info(
"current generate stream function is asynchronous stream function"
)
else:
generate_stream_func = self.llm_adapter.get_generate_stream_function(
self.model, self.model_path
)
str_prompt = params.get("prompt")
if not str_prompt:
str_prompt = params.get("string_prompt")
print(
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
)
generate_stream_func_str_name = "{}.{}".format(
generate_stream_func.__module__, generate_stream_func.__name__
)
span_params = {k: v for k, v in params.items()}
if "messages" in span_params:
span_params["messages"] = list(
map(lambda m: m.dict(), span_params["messages"])
)
metadata = {
"is_async_func": self.support_async(),
"llm_adapter": str(self.llm_adapter),
"generate_stream_func": generate_stream_func_str_name,
}
metadata.update(span_params)
metadata.update(model_context)
metadata["prompt"] = str_prompt
model_span = root_tracer.start_span(span_operation_name, metadata=metadata)
return params, model_context, generate_stream_func, model_span
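
# Normalize one chunk of raw output (a plain dict or a ModelOutput) into a
# ModelOutput, compute the incremental text against the previous response and
# update the inference metrics.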
def _handle_output(
self,
output,
previous_response,
model_context,
last_metrics: ModelInferenceMetrics,
is_first_generate: bool,
):
finish_reason = None
usage = None
error_code = 0
if isinstance(output, dict):
finish_reason = output.get("finish_reason")
usage = output.get("usage")
output = output["text"]
if finish_reason is not None:
logger.info(f"finish_reason: {finish_reason}")
elif isinstance(output, ModelOutput):
finish_reason = output.finish_reason
usage = output.usage
error_code = output.error_code
output = output.text
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
metrics = _new_metrics_from_model_output(last_metrics, is_first_generate, usage)
model_output = ModelOutput(
text=output,
error_code=error_code,
model_context=model_context,
finish_reason=finish_reason,
usage=usage,
metrics=metrics,
)
return model_output, incremental_output, output, metrics
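
# Convert inference exceptions into an error ModelOutput so the stream ends
# with a user-visible message instead of propagating the exception.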
def _handle_exception(self, e):
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if _torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=1
)
else:
msg = traceback.format_exc()
logger.error(f"Model inference error, detail: {msg}")
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
)
return model_output
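
# Best-effort lookup of the maximum sequence length from the tokenizer or the
# model config; returns None when it cannot be determined.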
def _parse_model_max_length(model, tokenizer) -> Optional[int]:
if not (tokenizer or model):
return None
try:
if tokenizer and hasattr(tokenizer, "model_max_length"):
return tokenizer.model_max_length
if model and hasattr(model, "config"):
model_config = model.config
if hasattr(model_config, "max_sequence_length"):
return model_config.max_sequence_length
if hasattr(model_config, "max_position_embeddings"):
return model_config.max_position_embeddings
except Exception:
return None
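
# Derive a new ModelInferenceMetrics snapshot from the previous one: first
# token/first completion timestamps, token usage, tokens-per-second throughput
# and (running average) GPU memory statistics.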
def _new_metrics_from_model_output(
last_metric: ModelInferenceMetrics,
is_first_generate: bool,
usage: Optional[Dict] = None,
) -> ModelInferenceMetrics:
metrics = ModelInferenceMetrics.create_metrics(last_metric)
metrics.collect_index = last_metric.collect_index + 1
if is_first_generate:
logger.info(f"is_first_generate, usage: {usage}")
metrics.first_completion_time_ms = time.time_ns() // 1_000_000
if not usage or not isinstance(usage, dict):
return metrics
prompt_tokens = usage.get("prompt_tokens")
completion_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")
if prompt_tokens is None:
prompt_tokens = metrics.prompt_tokens
if completion_tokens is None:
completion_tokens = metrics.completion_tokens
if total_tokens is None:
total_tokens = metrics.total_tokens
if is_first_generate and (completion_tokens is not None):
# completion_tokens == 0 means this chunk is still in the prefill phase
metrics.first_completion_tokens = completion_tokens
if completion_tokens == 1:
metrics.first_token_time_ms = metrics.first_completion_time_ms
if (
not is_first_generate
and metrics.first_token_time_ms is None
and completion_tokens == 1
):
# Case: the first generation step produced 0 tokens and a later step produced 1 token
metrics.first_token_time_ms = time.time_ns() // 1_000_000
if prompt_tokens:
metrics.prompt_tokens = prompt_tokens
if completion_tokens:
metrics.completion_tokens = completion_tokens
if total_tokens:
metrics.total_tokens = total_tokens
elif prompt_tokens and completion_tokens:
total_tokens = prompt_tokens + completion_tokens
metrics.total_tokens = total_tokens
if total_tokens:
# Time cost in seconds
duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
metrics.speed_per_second = total_tokens / duration
current_gpu_infos = _get_current_cuda_memory()
metrics.current_gpu_infos = current_gpu_infos
if not metrics.avg_gpu_infos:
metrics.avg_gpu_infos = current_gpu_infos
elif current_gpu_infos:
for i, last_avg in enumerate(metrics.avg_gpu_infos):
allocated_memory_gb = (
last_avg.allocated_memory_gb * (metrics.collect_index - 1)
+ current_gpu_infos[i].allocated_memory_gb
)
metrics.avg_gpu_infos[i].allocated_memory_gb = (
allocated_memory_gb / metrics.collect_index
)
metrics.avg_gpu_infos[i].total_memory_gb = current_gpu_infos[
i
].total_memory_gb
metrics.avg_gpu_infos[i].cached_memory_gb = current_gpu_infos[
i
].cached_memory_gb
metrics.avg_gpu_infos[i].available_memory_gb = current_gpu_infos[
i
].available_memory_gb
return metrics
def _try_to_count_token(prompt: str, tokenizer, model) -> int:
"""Try to count token of prompt
Args:
prompt (str): prompt
tokenizer ([type]): tokenizer
model ([type]): model
Returns:
int: token count, if error return -1
TODO: More implementation
"""
try:
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
if isinstance(model, ProxyModel):
return model.count_token(prompt)
# Only HuggingFace tokenizers are supported now; for a single string input
# the tokenizer returns a flat list of input ids.
return len(tokenizer(prompt).input_ids)
except Exception as e:
logger.warning(f"Count token error, detail: {e}, return -1")
return -1
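
# torch is imported lazily so this module can be imported even when torch is
# not installed (e.g. proxy-model-only deployments); _torch_imported guards
# the CUDA error check in _handle_exception.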
def _try_import_torch():
global torch
global _torch_imported
try:
import torch
_torch_imported = True
except ImportError:
pass
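
# Minimal usage sketch (illustrative only; in a real deployment the worker is
# typically created and driven by the cluster worker manager, and the model
# name/path below are placeholders):
#
#   worker = DefaultModelWorker()
#   worker.load_worker("vicuna-13b-v1.5", "/path/to/vicuna-13b-v1.5")
#   worker.start()
#   for out in worker.generate_stream({"prompt": "Hello", "temperature": 0.7}):
#       print(out.text)
#   worker.stop()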