import logging
import os
import time
import traceback
from typing import Dict, Iterator, List, Optional

from dbgpt.configs.model_config import get_device
from dbgpt.core import (
    ModelExtraMedata,
    ModelInferenceMetrics,
    ModelMetadata,
    ModelOutput,
)
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters
from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from dbgpt.util.system_utils import get_system_info
from dbgpt.util.tracer import SpanType, SpanTypeRunName, root_tracer

logger = logging.getLogger(__name__)

_torch_imported = False
torch = None


class DefaultModelWorker(ModelWorker):
    def __init__(self) -> None:
        self.model = None
        self.tokenizer = None
        self._model_params = None
        self.llm_adapter: LLMModelAdapter = None
        self._support_async = False

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        model_path = _get_model_real_path(model_name, model_path)
        self.model_name = model_name
        self.model_path = model_path

        model_type = kwargs.get("model_type")
        ### Temporary configuration, fastchat will be used by default in the future.
        use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"

        self.llm_adapter = get_llm_model_adapter(
            self.model_name,
            self.model_path,
            use_fastchat=use_fastchat,
            model_type=model_type,
        )
        model_type = self.llm_adapter.model_type()
        self.param_cls = self.llm_adapter.model_param_class(model_type)
        self._support_async = self.llm_adapter.support_async()

        logger.info(
            f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
        )

        self.ml: ModelLoader = ModelLoader(
            model_path=self.model_path, model_name=self.model_name
        )
        # Default model context len
        self.context_len = 2048

    def model_param_class(self) -> ModelParameters:
        return self.param_cls

    def support_async(self) -> bool:
        return self._support_async

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        param_cls = self.model_param_class()
        model_args = EnvArgumentParser()
        env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
        model_type = self.llm_adapter.model_type()
        model_params: ModelParameters = model_args.parse_args_into_dataclass(
            param_cls,
            env_prefixes=[env_prefix, "LLM_"],
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
            model_type=model_type,
        )
        if not model_params.device:
            model_params.device = get_device()
            logger.info(
                f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
            )
        return model_params

    def start(
        self, model_params: ModelParameters = None, command_args: List[str] = None
    ) -> None:
        # Lazy load torch
        _try_import_torch()
        if not model_params:
            model_params = self.parse_parameters(command_args)
        self._model_params = model_params
        logger.info(f"Begin load model, model params: {model_params}")
        metadata = {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "model_type": self.llm_adapter.model_type(),
            "llm_adapter": str(self.llm_adapter),
            "run_service": SpanTypeRunName.MODEL_WORKER,
            "params": _get_dict_from_obj(model_params),
            "sys_infos": _get_dict_from_obj(get_system_info()),
        }
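        # Load the model inside a tracing span so the load parameters and system
        # info in `metadata` are recorded; the context length is then taken from
        # the model itself, falling back to `max_context_size` if available.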
        with root_tracer.start_span(
            "DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
        ):
            self.model, self.tokenizer = self.ml.loader_with_params(
                model_params, self.llm_adapter
            )
            model_max_length = self.llm_adapter.parse_max_length(
                self.model, self.tokenizer
            )
            if model_max_length:
                logger.info(
                    f"Parse model max length {model_max_length} from model {self.model_name}."
                )
                self.context_len = model_max_length
            elif hasattr(model_params, "max_context_size"):
                self.context_len = model_params.max_context_size

    def stop(self) -> None:
        if not self.model:
            logger.warning("Model has been stopped!!")
            return
        del self.model
        del self.tokenizer
        self.model = None
        self.tokenizer = None
        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        span = root_tracer.start_span(
            "DefaultModelWorker.generate_stream", params.get("span_id")
        )
        try:
            (
                params,
                model_context,
                generate_stream_func,
                model_span,
            ) = self._prepare_generate_stream(
                params,
                span_operation_name="DefaultModelWorker_call.generate_stream_func",
            )

            previous_response = ""
            last_metrics = ModelInferenceMetrics.create_metrics()
            is_first_generate = True

            context_len = params.get("context_len") or self.context_len
            for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), context_len
            ):
                (
                    model_output,
                    incremental_output,
                    output_str,
                    current_metrics,
                ) = self._handle_output(
                    output,
                    previous_response,
                    model_context,
                    last_metrics,
                    is_first_generate,
                )
                if is_first_generate:
                    is_first_generate = False
                previous_response = output_str
                last_metrics = current_metrics
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
            model_span.end(metadata={"output": previous_response})
            span.end()
        except Exception as e:
            output = self._handle_exception(e)
            yield output
            span.end(metadata={"error": output.to_dict()})

    def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result"""
        output = None
        for out in self.generate_stream(params):
            output = out
        return output

    def count_token(self, prompt: str) -> int:
        return _try_to_count_token(prompt, self.tokenizer, self.model)

    async def async_count_token(self, prompt: str) -> int:
        # TODO: if the model is deployed with vllm this won't work; the
        #  transformers-based _try_to_count_token should be made async.
        from dbgpt.model.proxy.llms.proxy_model import ProxyModel

        if isinstance(self.model, ProxyModel) and self.model.proxy_llm_client:
            return await self.model.proxy_llm_client.count_token(
                self.model.proxy_llm_client.default_model, prompt
            )
        raise NotImplementedError

    def get_model_metadata(self, params: Dict) -> ModelMetadata:
        ext_metadata = ModelExtraMedata(
            prompt_roles=self.llm_adapter.get_prompt_roles(),
            prompt_sep=self.llm_adapter.get_default_message_separator(),
        )
        return ModelMetadata(
            model=self.model_name,
            context_length=self.context_len,
            ext_metadata=ext_metadata,
        )

    async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
        return self.get_model_metadata(params)

    def embeddings(self, params: Dict) -> List[List[float]]:
        raise NotImplementedError
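    # The async variant below mirrors generate_stream(); the only difference is
    # that it consumes the adapter's asynchronous stream function with `async for`.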
    async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        span = root_tracer.start_span(
            "DefaultModelWorker.async_generate_stream", params.get("span_id")
        )
        try:
            (
                params,
                model_context,
                generate_stream_func,
                model_span,
            ) = self._prepare_generate_stream(
                params,
                span_operation_name="DefaultModelWorker_call.generate_stream_func",
            )

            previous_response = ""
            context_len = params.get("context_len") or self.context_len

            last_metrics = ModelInferenceMetrics.create_metrics()
            is_first_generate = True
            async for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), context_len
            ):
                (
                    model_output,
                    incremental_output,
                    output_str,
                    current_metrics,
                ) = self._handle_output(
                    output,
                    previous_response,
                    model_context,
                    last_metrics,
                    is_first_generate,
                )
                if is_first_generate:
                    is_first_generate = False
                previous_response = output_str
                last_metrics = current_metrics
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
            model_span.end(metadata={"output": previous_response})
            span.end()
        except Exception as e:
            output = self._handle_exception(e)
            yield output
            span.end(metadata={"error": output.to_dict()})

    async def async_generate(self, params: Dict) -> ModelOutput:
        output = None
        async for out in self.async_generate_stream(params):
            output = out
        return output

    def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
        params, model_context = self.llm_adapter.model_adaptation(
            params,
            self.model_name,
            self.model_path,
            self.tokenizer,
            prompt_template=self.ml.prompt_template,
        )
        stream_type = ""
        if self.support_async():
            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
                self.model, self.model_path
            )
            stream_type = "async "
            logger.info(
                "current generate stream function is asynchronous stream function"
            )
        else:
            generate_stream_func = self.llm_adapter.get_generate_stream_function(
                self.model, self.model_path
            )

        str_prompt = params.get("prompt")
        if not str_prompt:
            str_prompt = params.get("string_prompt")
        print(
            f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
        )

        generate_stream_func_str_name = "{}.{}".format(
            generate_stream_func.__module__, generate_stream_func.__name__
        )

        span_params = {k: v for k, v in params.items()}
        if "messages" in span_params:
            span_params["messages"] = list(
                map(lambda m: m.dict(), span_params["messages"])
            )

        metadata = {
            "is_async_func": self.support_async(),
            "llm_adapter": str(self.llm_adapter),
            "generate_stream_func": generate_stream_func_str_name,
        }
        metadata.update(span_params)
        metadata.update(model_context)
        metadata["prompt"] = str_prompt

        model_span = root_tracer.start_span(span_operation_name, metadata=metadata)

        return params, model_context, generate_stream_func, model_span

    def _handle_output(
        self,
        output,
        previous_response,
        model_context,
        last_metrics: ModelInferenceMetrics,
        is_first_generate: bool,
    ):
        finish_reason = None
        usage = None
        error_code = 0
        if isinstance(output, dict):
            finish_reason = output.get("finish_reason")
            usage = output.get("usage")
            output = output["text"]
            if finish_reason is not None:
                logger.info(f"finish_reason: {finish_reason}")
        elif isinstance(output, ModelOutput):
            finish_reason = output.finish_reason
            usage = output.usage
            error_code = output.error_code
            output = output.text
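        # Stream chunks are cumulative (each contains the full response so far),
        # so the newly generated piece is the suffix beyond the previous response.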
        incremental_output = output[len(previous_response) :]
        print(incremental_output, end="", flush=True)

        metrics = _new_metrics_from_model_output(last_metrics, is_first_generate, usage)
        model_output = ModelOutput(
            text=output,
            error_code=error_code,
            model_context=model_context,
            finish_reason=finish_reason,
            usage=usage,
            metrics=metrics,
        )
        return model_output, incremental_output, output, metrics

    def _handle_exception(self, e):
        # Check if the exception is a torch.cuda.CudaError and if torch was imported.
        if _torch_imported and isinstance(e, torch.cuda.CudaError):
            model_output = ModelOutput(
                text="**GPU OutOfMemory, Please Refresh.**", error_code=1
            )
        else:
            msg = traceback.format_exc()
            logger.error(f"Model inference error, detail: {msg}")
            model_output = ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )
        return model_output


def _parse_model_max_length(model, tokenizer) -> Optional[int]:
    if not (tokenizer or model):
        return None
    try:
        if tokenizer and hasattr(tokenizer, "model_max_length"):
            return tokenizer.model_max_length
        if model and hasattr(model, "config"):
            model_config = model.config
            if hasattr(model_config, "max_sequence_length"):
                return model_config.max_sequence_length
            if hasattr(model_config, "max_position_embeddings"):
                return model_config.max_position_embeddings
    except Exception:
        return None


def _new_metrics_from_model_output(
    last_metric: ModelInferenceMetrics,
    is_first_generate: bool,
    usage: Optional[Dict] = None,
) -> ModelInferenceMetrics:
    metrics = ModelInferenceMetrics.create_metrics(last_metric)
    metrics.collect_index = last_metric.collect_index + 1
    if is_first_generate:
        logger.info(f"is_first_generate, usage: {usage}")
        metrics.first_completion_time_ms = time.time_ns() // 1_000_000

    if not usage or not isinstance(usage, dict):
        return metrics
    prompt_tokens = usage.get("prompt_tokens")
    completion_tokens = usage.get("completion_tokens")
    total_tokens = usage.get("total_tokens")

    if prompt_tokens is None:
        prompt_tokens = metrics.prompt_tokens
    if completion_tokens is None:
        completion_tokens = metrics.completion_tokens
    if total_tokens is None:
        total_tokens = metrics.total_tokens

    if is_first_generate and (completion_tokens is not None):
        # completion_tokens == 0 is prefill
        metrics.first_completion_tokens = completion_tokens
        if completion_tokens == 1:
            metrics.first_token_time_ms = metrics.first_completion_time_ms
    if (
        not is_first_generate
        and metrics.first_token_time_ms is None
        and completion_tokens == 1
    ):
        # Case: first generate has 0 token, and second generate has 1 token
        metrics.first_token_time_ms = time.time_ns() // 1_000_000

    if prompt_tokens:
        metrics.prompt_tokens = prompt_tokens
    if completion_tokens:
        metrics.completion_tokens = completion_tokens
    if total_tokens:
        metrics.total_tokens = total_tokens
    elif prompt_tokens and completion_tokens:
        total_tokens = prompt_tokens + completion_tokens
        metrics.total_tokens = total_tokens

    if total_tokens:
        # time cost(seconds)
        duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
        metrics.speed_per_second = total_tokens / duration
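    # Track GPU memory as a running mean: with n = collect_index,
    # avg_n = (avg_{n-1} * (n - 1) + current) / n, so no per-step history is kept.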
    current_gpu_infos = _get_current_cuda_memory()
    metrics.current_gpu_infos = current_gpu_infos
    if not metrics.avg_gpu_infos:
        metrics.avg_gpu_infos = current_gpu_infos
    elif current_gpu_infos:
        for i, last_avg in enumerate(metrics.avg_gpu_infos):
            allocated_memory_gb = (
                last_avg.allocated_memory_gb * (metrics.collect_index - 1)
                + current_gpu_infos[i].allocated_memory_gb
            )
            metrics.avg_gpu_infos[i].allocated_memory_gb = (
                allocated_memory_gb / metrics.collect_index
            )
            metrics.avg_gpu_infos[i].total_memory_gb = current_gpu_infos[
                i
            ].total_memory_gb
            metrics.avg_gpu_infos[i].cached_memory_gb = current_gpu_infos[
                i
            ].cached_memory_gb
            metrics.avg_gpu_infos[i].available_memory_gb = current_gpu_infos[
                i
            ].available_memory_gb

    return metrics


def _try_to_count_token(prompt: str, tokenizer, model) -> int:
    """Try to count token of prompt

    Args:
        prompt (str): prompt
        tokenizer ([type]): tokenizer
        model ([type]): model

    Returns:
        int: token count, if error return -1

    TODO: More implementation
    """
    try:
        from dbgpt.model.proxy.llms.proxy_model import ProxyModel

        if isinstance(model, ProxyModel):
            return model.count_token(prompt)
        # Only support huggingface model now
        return len(tokenizer(prompt).input_ids[0])
    except Exception as e:
        logger.warning(f"Count token error, detail: {e}, return -1")
        return -1


def _try_import_torch():
    global torch
    global _torch_imported
    try:
        import torch

        _torch_imported = True
    except ImportError:
        pass
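

# Minimal usage sketch: the model name and path below are placeholders, and this
# assumes the DB-GPT runtime plus a locally available model.
if __name__ == "__main__":
    worker = DefaultModelWorker()
    worker.load_worker("vicuna-13b-v1.5", "/path/to/model")  # placeholder values
    worker.start()
    params = {"prompt": "Hello, who are you?", "temperature": 0.7, "max_new_tokens": 128}
    last = None
    for out in worker.generate_stream(params):
        last = out
    print(f"\nfinal output: {last.text if last else None}")
    worker.stop()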