DB-GPT/pilot/model/cluster/worker/default_worker.py
import os
import logging
from typing import AsyncIterator, Dict, Iterator, List

from pilot.configs.model_config import get_device
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info

logger = logging.getLogger(__name__)

_torch_imported = False
try:
    import torch

    _torch_imported = True
except ImportError:
    pass
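# DefaultModelWorker wraps a single local LLM for the DB-GPT model cluster: it
# resolves a model adapter, loads the model/tokenizer through ModelLoader, and
# exposes synchronous and asynchronous streaming generation.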
class DefaultModelWorker(ModelWorker):
    def __init__(self) -> None:
        self.model = None
        self.tokenizer = None
        self._model_params = None
        self.llm_adapter: LLMModelAdaper = None
        self._support_async = False

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        model_path = _get_model_real_path(model_name, model_path)
        self.model_name = model_name
        self.model_path = model_path
        model_type = kwargs.get("model_type")
        # Temporary configuration; FastChat will be used by default in the future.
        use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"
        self.llm_adapter = get_llm_model_adapter(
            self.model_name,
            self.model_path,
            use_fastchat=use_fastchat,
            model_type=model_type,
        )
        model_type = self.llm_adapter.model_type()
        self.param_cls = self.llm_adapter.model_param_class(model_type)
        self._support_async = self.llm_adapter.support_async()

        logger.info(
            f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
        )

        self.ml: ModelLoader = ModelLoader(
            model_path=self.model_path, model_name=self.model_name
        )
        # TODO read context len from model config
        self.context_len = 2048
    def model_param_class(self) -> ModelParameters:
        return self.param_cls

    def support_async(self) -> bool:
        return self._support_async
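    # Model parameters are parsed into the adapter's parameter dataclass from
    # command-line arguments and environment variables (the model-specific
    # prefix is tried first, then the generic "LLM_" prefix).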
    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        param_cls = self.model_param_class()
        model_args = EnvArgumentParser()
        env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
        model_type = self.llm_adapter.model_type()
        model_params: ModelParameters = model_args.parse_args_into_dataclass(
            param_cls,
            env_prefixes=[env_prefix, "LLM_"],
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
            model_type=model_type,
        )
        if not model_params.device:
            model_params.device = get_device()
            logger.info(
                f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
            )
        return model_params
    def start(
        self, model_params: ModelParameters = None, command_args: List[str] = None
    ) -> None:
        if not model_params:
            model_params = self.parse_parameters(command_args)
        self._model_params = model_params
        logger.info(f"Begin load model, model params: {model_params}")
        metadata = {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "model_type": self.llm_adapter.model_type(),
            "llm_adapter": str(self.llm_adapter),
            "run_service": SpanTypeRunName.MODEL_WORKER,
            "params": _get_dict_from_obj(model_params),
            "sys_infos": _get_dict_from_obj(get_system_info()),
        }
        with root_tracer.start_span(
            "DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
        ):
            self.model, self.tokenizer = self.ml.loader_with_params(
                model_params, self.llm_adapter
            )
    def stop(self) -> None:
        if not self.model:
            logger.warning("Model has been stopped!!")
            return
        del self.model
        del self.tokenizer
        self.model = None
        self.tokenizer = None
        _clear_model_cache(self._model_params.device)
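    # Streaming generation: each iteration of the adapter's stream function
    # returns the full response so far; _handle_output wraps it in a ModelOutput
    # and prints only the incremental delta to the console.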
    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        span = root_tracer.start_span(
            "DefaultModelWorker.generate_stream", params.get("span_id")
        )
        try:
            (
                params,
                model_context,
                generate_stream_func,
                model_span,
            ) = self._prepare_generate_stream(
                params,
                span_operation_name="DefaultModelWorker_call.generate_stream_func",
            )

            previous_response = ""
            for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), self.context_len
            ):
                model_output, incremental_output, output_str = self._handle_output(
                    output, previous_response, model_context
                )
                previous_response = output_str
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
            model_span.end(metadata={"output": previous_response})
            span.end()
        except Exception as e:
            output = self._handle_exception(e)
            yield output
            span.end(metadata={"error": output.to_dict()})
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
output = None
for out in self.generate_stream(params):
output = out
return output
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
    async def async_generate_stream(self, params: Dict) -> AsyncIterator[ModelOutput]:
        span = root_tracer.start_span(
            "DefaultModelWorker.async_generate_stream", params.get("span_id")
        )
        try:
            (
                params,
                model_context,
                generate_stream_func,
                model_span,
            ) = self._prepare_generate_stream(
                params,
                span_operation_name="DefaultModelWorker_call.generate_stream_func",
            )

            previous_response = ""
            async for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), self.context_len
            ):
                model_output, incremental_output, output_str = self._handle_output(
                    output, previous_response, model_context
                )
                previous_response = output_str
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
            model_span.end(metadata={"output": previous_response})
            span.end()
        except Exception as e:
            output = self._handle_exception(e)
            yield output
            span.end(metadata={"error": output.to_dict()})
    async def async_generate(self, params: Dict) -> ModelOutput:
        output = None
        async for out in self.async_generate_stream(params):
            output = out
        return output
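    # Shared setup for both streaming paths: apply the adapter's prompt/parameter
    # adaptation, select the synchronous or asynchronous stream function, and
    # open a trace span carrying the prompt and call metadata.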
    def _prepare_generate_stream(self, params: Dict, span_operation_name: str):
        params, model_context = self.llm_adapter.model_adaptation(
            params,
            self.model_name,
            self.model_path,
            prompt_template=self.ml.prompt_template,
        )
        stream_type = ""
        if self.support_async():
            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
                self.model, self.model_path
            )
            stream_type = "async "
            logger.info(
                "current generate stream function is asynchronous stream function"
            )
        else:
            generate_stream_func = self.llm_adapter.get_generate_stream_function(
                self.model, self.model_path
            )
        str_prompt = params.get("prompt")
        print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")

        generate_stream_func_str_name = "{}.{}".format(
            generate_stream_func.__module__, generate_stream_func.__name__
        )

        span_params = {k: v for k, v in params.items()}
        if "messages" in span_params:
            span_params["messages"] = list(
                map(lambda m: m.dict(), span_params["messages"])
            )

        model_span = root_tracer.start_span(
            span_operation_name,
            metadata={
                "prompt": str_prompt,
                "params": span_params,
                "is_async_func": self.support_async(),
                "llm_adapter": str(self.llm_adapter),
                "generate_stream_func": generate_stream_func_str_name,
                "model_context": model_context,
            },
        )

        return params, model_context, generate_stream_func, model_span
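    # The adapter's stream function may yield either a plain string or a dict
    # with "text", "finish_reason" and "usage"; both are normalized into a
    # ModelOutput, together with the incremental delta since the last yield.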
    def _handle_output(self, output, previous_response, model_context):
        finish_reason = None
        usage = None
        if isinstance(output, dict):
            finish_reason = output.get("finish_reason")
            usage = output.get("usage")
            output = output["text"]
            if finish_reason is not None:
                logger.info(f"finish_reason: {finish_reason}")
        incremental_output = output[len(previous_response) :]
        print(incremental_output, end="", flush=True)
        model_output = ModelOutput(
            text=output,
            error_code=0,
            model_context=model_context,
            finish_reason=finish_reason,
            usage=usage,
        )
        return model_output, incremental_output, output
    def _handle_exception(self, e):
        # Check if the exception is a torch.cuda.CudaError and if torch was imported.
        if _torch_imported and isinstance(e, torch.cuda.CudaError):
            model_output = ModelOutput(
                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
            )
        else:
            model_output = ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=0,
            )
        return model_output
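# Minimal usage sketch (illustrative only): the model name and path below are
# assumptions, not values required by this module, and the accepted keys in
# `params` depend on the concrete adapter returned by get_llm_model_adapter.
#
#     worker = DefaultModelWorker()
#     worker.load_worker(
#         model_name="vicuna-13b-v1.5",
#         model_path="/data/models/vicuna-13b-v1.5",
#     )
#     worker.start()
#     for out in worker.generate_stream({"prompt": "Hello, who are you?"}):
#         print(out.text)
#     worker.stop()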