# DB-GPT/pilot/model/model_adapter.py

from __future__ import annotations
from typing import Callable, List, Dict, Type, Tuple, Optional, TYPE_CHECKING
import dataclasses
import logging
import threading
import os
from functools import cache
from pilot.model.base import ModelType
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.utils.parameter_utils import (
_extract_parameter_details,
_build_parameter_class,
_get_dataclass_print_str,
)
try:
from fastchat.conversation import (
Conversation,
register_conv_template,
SeparatorStyle,
)
except ImportError as exc:
    raise ValueError(
        "Could not import python package: fschat. "
        "Please install it with `pip install fschat`."
    ) from exc
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
from pilot.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
from torch.nn import Module as TorchNNModule
logger = logging.getLogger(__name__)
thread_local = threading.local()
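# Model names that must still go through DB-GPT's legacy adapter path instead of fastchat.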
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
]
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
def use_fast_tokenizer(self) -> bool:
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
for a given model.
"""
return False
def model_type(self) -> str:
return ModelType.HF
def model_param_class(self, model_type: str = None) -> ModelParameters:
"""Get the startup parameters instance of the model"""
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def load(self, model_path: str, from_pretrained_kwargs: dict):
"""Load model and tokenizer"""
raise NotImplementedError
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
raise NotImplementedError
def get_async_generate_stream_function(self, model, model_path: str):
"""Get the asynchronous generate stream function of the model"""
raise NotImplementedError
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
"""Get the default conv template"""
raise NotImplementedError
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_default_conv_template(model_name, model_path)
messages = params.get("messages")
        # Some model context passed back to the dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
conv = conv.copy()
system_messages = []
for message in messages:
role, content = None, None
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
conv.append_message(conv.roles[0], content)
elif role == ModelMessageRoleType.AI:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if system_messages:
conv.set_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
# Overwrite the original prompt
        # TODO: remove the bos and eos tokens according to the model's tokenizer_config.json
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
# Overwrite model params:
params["stop"] = conv.stop_str
params["stop_token_ids"] = conv.stop_token_ids
return params, model_context
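# Example (illustrative sketch, not part of the original module): how `model_adaptation`
# turns chat messages into a single prompt plus stop configuration. The model name and
# path below are hypothetical.
#
#   adapter = get_llm_model_adapter("vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5")
#   params = {
#       "messages": [
#           {"role": ModelMessageRoleType.SYSTEM, "content": "You are a helpful assistant."},
#           {"role": ModelMessageRoleType.HUMAN, "content": "Hello!"},
#       ]
#   }
#   params, model_context = adapter.model_adaptation(
#       params, "vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5"
#   )
#   # params now carries "prompt", "stop" and "stop_token_ids" built from the conversation
#   # template, and model_context records "prompt_echo_len_char" for echo trimming.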
class OldLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping old adapter, which may be removed later"""
def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
self._adapter = adapter
self._chat_adapter = chat_adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer()
def model_type(self) -> str:
return self._adapter.model_type()
def model_param_class(self, model_type: str = None) -> ModelParameters:
return self._adapter.model_param_class(model_type)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._chat_adapter.get_conv_template(model_path)
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.loader(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model, model_path: str):
return self._chat_adapter.get_generate_stream_func(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping fastchat adapter"""
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
from fastchat.model.model_adapter import get_generate_stream_function
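        # Delegate streaming generation to whichever function fastchat selects for this model and path.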
return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._adapter.get_default_conv_template(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
return get_conv_template(name)
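# Cached helper: resolve the default conversation template via the fastchat adapter; returns None if resolution fails.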
@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
try:
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
return adapter.get_default_conv_template(model_name, model_path)
except Exception:
return None
@cache
def get_llm_model_adapter(
model_name: str,
model_path: str,
use_fastchat: bool = True,
use_fastchat_monkey_patch: bool = False,
model_type: str = None,
) -> LLMModelAdaper:
if model_type == ModelType.VLLM:
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()
must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
adapter = _get_fastchat_model_adapter(
model_name,
model_path,
_fastchat_get_adapter_monkey_patch,
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
)
return FastChatLLMModelAdaperWrapper(adapter)
else:
from pilot.model.adapter import (
get_llm_model_adapter as _old_get_llm_model_adapter,
)
from pilot.server.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
return OldLLMModelAdaperWrapper(
_old_get_llm_model_adapter(model_name, model_path),
get_llm_chat_adapter(model_name, model_path),
)
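# Example (illustrative sketch, not part of the original module): which wrapper is chosen.
# The model names and paths are hypothetical.
#
#   get_llm_model_adapter("vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5")
#   # -> FastChatLLMModelAdaperWrapper (matched by fastchat)
#   get_llm_model_adapter("proxyllm", None)
#   # -> OldLLMModelAdaperWrapper ("proxyllm" is listed in _OLD_MODELS)
#   get_llm_model_adapter("qwen-7b-chat", "/data/models/qwen-7b-chat", model_type=ModelType.VLLM)
#   # -> VLLMModelAdaperWrapper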
def _get_fastchat_model_adapter(
model_name: str,
model_path: str,
    caller: Callable[[str], "BaseModelAdapter"] = None,
use_fastchat_monkey_patch: bool = False,
):
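    # The configured model name is handed to `_fastchat_get_adapter_monkey_patch` through a
    # thread-local so adapter matching can prefer the model name over the model path; when
    # `use_fastchat_monkey_patch` is set, fastchat's own `get_model_adapter` is also replaced
    # for the duration of the call and restored afterwards.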
from fastchat.model import model_adapter
_bak_get_model_adapter = model_adapter.get_model_adapter
try:
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
if caller:
return caller(model_path)
finally:
del thread_local.model_name
model_adapter.get_model_adapter = _bak_get_model_adapter
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
if not model_name:
if not hasattr(thread_local, "model_name"):
raise RuntimeError("fastchat get adapter monkey path need model_name")
model_name = thread_local.model_name
from fastchat.model.model_adapter import model_adapters
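    # Try matching in order: the model name, the basename of the model path, then the full model path.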
for adapter in model_adapters:
if adapter.match(model_name):
logger.info(
f"Found llm model adapter with model name: {model_name}, {adapter}"
)
return adapter
model_path_basename = (
None if not model_path else os.path.basename(os.path.normpath(model_path))
)
for adapter in model_adapters:
if model_path_basename and adapter.match(model_path_basename):
logger.info(
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
)
return adapter
for adapter in model_adapters:
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
return adapter
raise ValueError(
f"Invalid model adapter for model name {model_name} and model path {model_path}"
)
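# Decide which parameter dataclass(es) to expose on the command line, based on the pre-parsed worker type and model type.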
def _dynamic_model_parser() -> Optional[List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
class VLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping vllm engine"""
def model_type(self) -> str:
return ModelType.VLLM
def model_param_class(self, model_type: str = None) -> ModelParameters:
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument(
"--model_path",
type=str,
help="local model path of the huggingface model to use",
)
parser.add_argument("--model_type", type=str, help="model type")
parser.add_argument("--device", type=str, default=None, help="device")
        # TODO: parse the prompt template from `model_name` and `model_path`
parser.add_argument(
"--prompt_template",
type=str,
default=None,
help="Prompt template. If None, the prompt template is automatically determined from model path",
)
descs = _extract_parameter_details(
parser,
"pilot.model.parameter.VLLMModelParameters",
skip_names=["model"],
overwrite_default_values={"trust_remote_code": True},
)
return _build_parameter_class(descs)
def load_from_params(self, params):
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
import torch
num_gpus = torch.cuda.device_count()
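        # Spread the model across all visible GPUs with tensor parallelism when more than one is available.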
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
setattr(params, "tensor_parallel_size", num_gpus)
logger.info(
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
)
params = dataclasses.asdict(params)
params["model"] = params["model_path"]
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
# Set the attributes from the parsed arguments.
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, engine.engine.tokenizer
def support_async(self) -> bool:
return True
def get_async_generate_stream_function(self, model, model_path: str):
from pilot.model.llm_out.vllm_llm import generate_stream
return generate_stream
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return _auto_get_conv_template(model_name, model_path)
def __str__(self) -> str:
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
# This overrides the fastchat configuration; we will regularly feed the code here back to fastchat.
# We also recommend contributing such changes directly to the fastchat repository.
register_conv_template(
Conversation(
name="internlm-chat",
system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
# TODO feedback stop_str to fastchat
stop_str="<eoa>",
),
override=True,
)