feat(model): Support Mixtral-8x7B (#959)

Author: Fangyin Cheng
Date: 2023-12-21 16:46:29 +08:00
Committed by: GitHub
parent aec124a5f1
commit 6b982e2879
17 changed files with 1386 additions and 688 deletions

View File

@@ -103,7 +103,8 @@ At present, we have introduced several key features to showcase our current capa
 We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
 - News
-- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
 - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 - [More Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf)

View File

@@ -111,7 +111,8 @@ DB-GPT is an open-source large model framework for the database domain. Its goal is to build
 Extensive model support: dozens of large language models, both open-source and via API proxies, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and more. The following models are currently supported:
 - Newly supported models
-- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
 - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 - [More open-source models](https://www.yuque.com/eosphoros/dbgpt-docs/iqaaqwriwhp6zslc#qQktR)

View File

@@ -245,7 +245,7 @@ class WizardLMChatAdapter(BaseChatAdpter):
 class LlamaCppChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        from dbgpt.model.adapter import LlamaCppAdapater
+        from dbgpt.model.adapter.old_adapter import LlamaCppAdapater
         if "llama-cpp" == model_path:
             return True

View File

@@ -113,7 +113,9 @@ LLM_MODEL_CONFIG = {
     # https://huggingface.co/microsoft/Orca-2-13b
     "orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
     # https://huggingface.co/openchat/openchat_3.5
-    "openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+    "openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+    # https://huggingface.co/openchat/openchat-3.5-1210
+    "openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"),
     # https://huggingface.co/hfl/chinese-alpaca-2-7b
     "chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
     # https://huggingface.co/hfl/chinese-alpaca-2-13b
@@ -124,6 +126,10 @@ LLM_MODEL_CONFIG = {
     "zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
     # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
     "mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
+    # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+    "mixtral-8x7b-instruct-v0.1": os.path.join(
+        MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1"
+    ),
     # https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
     "mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
     # https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
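For orientation, a minimal sketch of what this config change amounts to: the new key maps the model name to a local weights directory under `MODEL_PATH`, and the server is then pointed at that key (commonly via the `LLM_MODEL` setting in `.env`). The default path and the environment variable below are assumptions about a typical DB-GPT setup, not part of this diff.

```python
# Minimal sketch (assumed setup): resolve the new Mixtral entry to a local directory.
import os

MODEL_PATH = os.getenv("MODEL_PATH", "./models")  # assumption: root dir for weights
LLM_MODEL_CONFIG = {
    "mixtral-8x7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1"),
}

# Selecting the model by name, e.g. LLM_MODEL=mixtral-8x7b-instruct-v0.1 in .env.
selected = os.getenv("LLM_MODEL", "mixtral-8x7b-instruct-v0.1")
print(f"{selected} -> {LLM_MODEL_CONFIG[selected]}")
```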

View File

dbgpt/model/adapter/base.py (new file, 437 lines)
View File

@@ -0,0 +1,437 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Tuple, Type, Callable
import logging
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
BaseModelParameters,
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from dbgpt.model.adapter.template import (
get_conv_template,
ConversationAdapter,
ConversationAdapterFactory,
)
logger = logging.getLogger(__name__)
class LLMModelAdapter(ABC):
"""New Adapter for DB-GPT LLM models"""
model_name: Optional[str] = None
model_path: Optional[str] = None
conv_factory: Optional[ConversationAdapterFactory] = None
# TODO: more flexible quantization config
support_4bit: bool = False
support_8bit: bool = False
support_system_message: bool = True
def __repr__(self) -> str:
return f"<{self.__class__.__name__} model_name={self.model_name} model_path={self.model_path}>"
def __str__(self):
return self.__repr__()
@abstractmethod
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
"""Create a new adapter instance
Args:
**kwargs: The parameters of the new adapter instance
Returns:
LLMModelAdapter: The new adapter instance
"""
def use_fast_tokenizer(self) -> bool:
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
for a given model.
"""
return False
def model_type(self) -> str:
return ModelType.HF
def model_param_class(self, model_type: str = None) -> Type[BaseModelParameters]:
"""Get the startup parameters instance of the model
Args:
model_type (str, optional): The type of model. Defaults to None.
Returns:
Type[BaseModelParameters]: The startup parameters instance of the model
"""
# """Get the startup parameters instance of the model"""
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def match(
self,
model_type: str,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
) -> bool:
"""Whether the model adapter can load the given model
Args:
model_type (str): The type of model
model_name (Optional[str], optional): The name of model. Defaults to None.
model_path (Optional[str], optional): The path of model. Defaults to None.
"""
return False
def support_quantization_4bit(self) -> bool:
"""Whether the model adapter can load 4bit model
If it is True, we will load the 4bit model with :meth:`~LLMModelAdapter.load`
Returns:
bool: Whether the model adapter can load 4bit model, default is False
"""
return self.support_4bit
def support_quantization_8bit(self) -> bool:
"""Whether the model adapter can load 8bit model
If it is True, we will load the 8bit model with :meth:`~LLMModelAdapter.load`
Returns:
bool: Whether the model adapter can load 8bit model, default is False
"""
return self.support_8bit
def load(self, model_path: str, from_pretrained_kwargs: dict):
"""Load model and tokenizer"""
raise NotImplementedError
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
raise NotImplementedError
def get_async_generate_stream_function(self, model, model_path: str):
"""Get the asynchronous generate stream function of the model"""
raise NotImplementedError
def get_default_conv_template(
self, model_name: str, model_path: str
) -> Optional[ConversationAdapter]:
"""Get the default conversation template
Args:
model_name (str): The name of the model.
model_path (str): The path of the model.
Returns:
Optional[ConversationAdapter]: The conversation template.
"""
raise NotImplementedError
def get_default_message_separator(self) -> str:
"""Get the default message separator"""
try:
conv_template = self.get_default_conv_template(
self.model_name, self.model_path
)
return conv_template.sep
except Exception:
return "\n"
def transform_model_messages(
self, messages: List[ModelMessage]
) -> List[Dict[str, str]]:
"""Transform the model messages
Default is the OpenAI format, example:
.. code-block:: python
return_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"}
]
But some models may need the messages transformed into another format (e.g. without a system message), such as:
.. code-block:: python
return_messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"}
]
Args:
messages (List[ModelMessage]): The model messages
Returns:
List[Dict[str, str]]: The transformed model messages
"""
logger.info(f"support_system_message: {self.support_system_message}")
if not self.support_system_message:
return self._transform_to_no_system_messages(messages)
else:
return ModelMessage.to_openai_messages(messages)
def _transform_to_no_system_messages(
self, messages: List[ModelMessage]
) -> List[Dict[str, str]]:
"""Transform the model messages to no system messages
Some opensource chat model no system messages, so wo should transform the messages to no system messages.
Merge the system messages to the last user message, example:
.. code-block:: python
return_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"}
]
=>
return_messages = [
{"role": "user", "content": "You are a helpful assistant\nHello"},
{"role": "assistant", "content": "Hi"}
]
Args:
messages (List[ModelMessage]): The model messages
Returns:
List[Dict[str, str]]: The transformed model messages
"""
openai_messages = ModelMessage.to_openai_messages(messages)
system_messages = []
return_messages = []
for message in openai_messages:
if message["role"] == "system":
system_messages.append(message["content"])
else:
return_messages.append(message)
if len(system_messages) > 1:
# More than one system message should trigger a warning
logger.warning("Your messages contain more than one system message")
if system_messages:
sep = self.get_default_message_separator()
str_system_messages = ",".join(system_messages)
# Update last user message
return_messages[-1]["content"] = (
str_system_messages + sep + return_messages[-1]["content"]
)
return return_messages
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
"""Get the string prompt from the given parameters and messages
If the return value is not None, we will skip :meth:`~LLMModelAdapter.get_prompt_with_template` and use it directly.
Args:
params (Dict): The parameters
messages (List[ModelMessage]): The model messages
tokenizer (Any): The tokenizer of the model. For Hugging Face chat models, the prompt can be built with the tokenizer.
prompt_template (str, optional): The prompt template. Defaults to None.
Returns:
Optional[str]: The string prompt
"""
return None
def get_prompt_with_template(
self,
params: Dict,
messages: List[ModelMessage],
model_name: str,
model_path: str,
model_context: Dict,
prompt_template: str = None,
):
conv: ConversationAdapter = self.get_default_conv_template(
model_name, model_path
)
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return None, None, None
conv = conv.copy()
system_messages = []
user_messages = []
ai_messages = []
for message in messages:
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
# conv.append_message(conv.roles[0], content)
user_messages.append(content)
elif role == ModelMessageRoleType.AI:
# conv.append_message(conv.roles[1], content)
ai_messages.append(content)
else:
raise ValueError(f"Unknown role: {role}")
can_use_systems: List[str] = []
if system_messages:
if len(system_messages) > 1:
# Compatible with complex DB-GPT scenarios: the last system message usually wraps the
# complete input entered by the current user, so it replaces the last user message
user_messages[-1] = system_messages[-1]
can_use_systems = system_messages[:-1]
else:
can_use_systems = system_messages
for i in range(len(user_messages)):
conv.append_message(conv.roles[0], user_messages[i])
if i < len(ai_messages):
conv.append_message(conv.roles[1], ai_messages[i])
# TODO: joining all system messages may not be a good idea
conv.set_system_message("".join(can_use_systems))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
tokenizer: Any,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
messages = params.get("messages")
# Some model context to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
conv_stop_str, conv_stop_token_ids = None, None
if not new_prompt:
(
new_prompt,
conv_stop_str,
conv_stop_token_ids,
) = self.get_prompt_with_template(
params, messages, model_name, model_path, model_context, prompt_template
)
if not new_prompt:
return params, model_context
# Overwrite the original prompt
# TODO: remove the bos token and eos token from the model's tokenizer_config.json
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
custom_stop = params.get("stop")
custom_stop_token_ids = params.get("stop_token_ids")
# Prefer the value passed in from the input parameter
params["stop"] = custom_stop or conv_stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
return params, model_context
class AdapterEntry:
"""The entry of model adapter"""
def __init__(
self,
model_adapter: LLMModelAdapter,
match_funcs: List[Callable[[str, str, str], bool]] = None,
):
self.model_adapter = model_adapter
self.match_funcs = match_funcs or []
model_adapters: List[AdapterEntry] = []
def register_model_adapter(
model_adapter_cls: Type[LLMModelAdapter],
match_funcs: List[Callable[[str, str, str], bool]] = None,
) -> None:
"""Register a model adapter.
Args:
model_adapter_cls (Type[LLMModelAdapter]): The model adapter class.
match_funcs (List[Callable[[str, str, str], bool]], optional): The match functions. Defaults to None.
"""
model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs))
def get_model_adapter(
model_type: str,
model_name: str,
model_path: str,
conv_factory: Optional[ConversationAdapterFactory] = None,
) -> Optional[LLMModelAdapter]:
"""Get a model adapter.
Args:
model_type (str): The type of the model.
model_name (str): The name of the model.
model_path (str): The path of the model.
conv_factory (Optional[ConversationAdapterFactory], optional): The conversation factory. Defaults to None.
Returns:
Optional[LLMModelAdapter]: The model adapter.
"""
adapter = None
# First find adapter by model_name
for adapter_entry in model_adapters:
if adapter_entry.model_adapter.match(model_type, model_name, None):
adapter = adapter_entry.model_adapter
break
for adapter_entry in model_adapters:
if adapter_entry.model_adapter.match(model_type, None, model_path):
adapter = adapter_entry.model_adapter
break
if adapter:
new_adapter = adapter.new_adapter()
new_adapter.model_name = model_name
new_adapter.model_path = model_path
if conv_factory:
new_adapter.conv_factory = conv_factory
return new_adapter
return None
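To show how the registry introduced in this file is meant to be used, here is a small sketch that registers a hypothetical adapter and resolves it with `register_model_adapter`/`get_model_adapter`. The class `MyChatAdapter` and the name `my-chat-7b` are invented for illustration only.

```python
# Hypothetical adapter to illustrate the registry API above (not part of the commit).
from typing import Optional

from dbgpt.model.adapter.base import (
    LLMModelAdapter,
    get_model_adapter,
    register_model_adapter,
)
from dbgpt.model.base import ModelType


class MyChatAdapter(LLMModelAdapter):
    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
        return MyChatAdapter()

    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        # Match by model name or by the local path, whichever is provided.
        name_or_path = (model_name or model_path or "").lower()
        return model_type == ModelType.HF and "my-chat" in name_or_path


register_model_adapter(MyChatAdapter)

adapter = get_model_adapter(ModelType.HF, "my-chat-7b", "/data/models/my-chat-7b")
print(adapter)  # <MyChatAdapter model_name=my-chat-7b model_path=/data/models/my-chat-7b>
```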

View File

@@ -0,0 +1,262 @@
"""Adapter for fastchat
fastchat should be imported only in this file, so that users who do not use it do not need to install it.
"""
import os
import threading
import logging
from functools import cache
from typing import TYPE_CHECKING, Callable, Tuple, List, Optional
try:
from fastchat.conversation import (
Conversation,
register_conv_template,
SeparatorStyle,
)
except ImportError as exc:
raise ValueError(
"Could not import python package: fschat "
"Please install fastchat by command `pip install fschat` "
) from exc
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.adapter.base import LLMModelAdapter
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
from torch.nn import Module as TorchNNModule
logger = logging.getLogger(__name__)
thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
# If a model affects the loading of DB-GPT even though it is not in the blacklist yet, you can add it here.
__BLACK_LIST_MODEL_PROMPT = []
class FschatConversationAdapter(ConversationAdapter):
"""The conversation adapter for fschat."""
def __init__(self, conv: Conversation):
self._conv = conv
@property
def prompt_type(self) -> PromptType:
return PromptType.FSCHAT
@property
def roles(self) -> Tuple[str]:
return self._conv.roles
@property
def sep(self) -> Optional[str]:
return self._conv.sep
@property
def stop_str(self) -> str:
return self._conv.stop_str
@property
def stop_token_ids(self) -> Optional[List[int]]:
return self._conv.stop_token_ids
def get_prompt(self) -> str:
"""Get the prompt string."""
return self._conv.get_prompt()
def set_system_message(self, system_message: str) -> None:
"""Set the system message."""
self._conv.set_system_message(system_message)
def append_message(self, role: str, message: str) -> None:
"""Append a new message.
Args:
role (str): The role of the message.
message (str): The message content.
"""
self._conv.append_message(role, message)
def update_last_message(self, message: str) -> None:
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
Args:
message (str): The message content.
"""
self._conv.update_last_message(message)
def copy(self) -> "ConversationAdapter":
"""Copy the conversation."""
return FschatConversationAdapter(self._conv.copy())
class FastChatLLMModelAdapterWrapper(LLMModelAdapter):
"""Wrapping fastchat adapter"""
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
return FastChatLLMModelAdapterWrapper(self._adapter)
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
if _IS_BENCHMARK:
from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
generate_stream,
)
return generate_stream
else:
from fastchat.model.model_adapter import get_generate_stream_function
return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> Optional[ConversationAdapter]:
conv_template = self._adapter.get_default_conv_template(model_path)
return FschatConversationAdapter(conv_template) if conv_template else None
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
def _get_fastchat_model_adapter(
model_name: str,
model_path: str,
caller: Callable[[str], None] = None,
use_fastchat_monkey_patch: bool = False,
):
from fastchat.model import model_adapter
_bak_get_model_adapter = model_adapter.get_model_adapter
try:
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
_remove_black_list_model_of_fastchat()
if caller:
return caller(model_path)
finally:
del thread_local.model_name
model_adapter.get_model_adapter = _bak_get_model_adapter
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
if not model_name:
if not hasattr(thread_local, "model_name"):
raise RuntimeError("fastchat get adapter monkey path need model_name")
model_name = thread_local.model_name
from fastchat.model.model_adapter import model_adapters
for adapter in model_adapters:
if adapter.match(model_name):
logger.info(
f"Found llm model adapter with model name: {model_name}, {adapter}"
)
return adapter
model_path_basename = (
None if not model_path else os.path.basename(os.path.normpath(model_path))
)
for adapter in model_adapters:
if model_path_basename and adapter.match(model_path_basename):
logger.info(
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
)
return adapter
for adapter in model_adapters:
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
return adapter
raise ValueError(
f"Invalid model adapter for model name {model_name} and model path {model_path}"
)
@cache
def _remove_black_list_model_of_fastchat():
from fastchat.model.model_adapter import model_adapters
black_list_models = []
for adapter in model_adapters:
try:
if (
adapter.get_default_conv_template("/data/not_exist_model_path").name
in __BLACK_LIST_MODEL_PROMPT
):
black_list_models.append(adapter)
except Exception:
pass
for adapter in black_list_models:
model_adapters.remove(adapter)
# Overriding the configuration of fastchat; we will regularly feed the code here back to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
Conversation(
name="aquila-legacy",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("### Human: ", "### Assistant: ", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
Conversation(
name="aquila",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
Conversation(
name="aquila-v1",
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
),
override=True,
)
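For reference, a small sketch of how `FschatConversationAdapter` is typically exercised: wrap one of fastchat's bundled conversation templates and build a prompt from it. It assumes fastchat is installed; `vicuna_v1.1` is just one example template name.

```python
# Sketch: wrap a bundled fastchat Conversation and build a prompt with the adapter above.
from fastchat.conversation import get_conv_template

from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter

conv = FschatConversationAdapter(get_conv_template("vicuna_v1.1"))
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "Hello")
conv.append_message(conv.roles[1], None)  # blank slot for the model's reply
print(conv.get_prompt())
print("stop_str:", conv.stop_str)
```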

View File

@@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Any
import logging
from dbgpt.core import ModelMessage
from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
logger = logging.getLogger(__name__)
class NewHFChatModelAdapter(LLMModelAdapter, ABC):
"""Model adapter for new huggingface chat models
See https://huggingface.co/docs/transformers/main/en/chat_templating
We can pass the inference chat messages to the model's chat template instead of creating
a prompt template for this model
"""
def new_adapter(self, **kwargs) -> "NewHFChatModelAdapter":
return self.__class__()
def match(
self,
model_type: str,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
) -> bool:
if model_type != ModelType.HF:
return False
if model_name is None and model_path is None:
return False
model_name = model_name.lower() if model_name else None
model_path = model_path.lower() if model_path else None
return self.do_match(model_name) or self.do_match(model_path)
@abstractmethod
def do_match(self, lower_model_name_or_path: Optional[str] = None):
raise NotImplementedError()
def load(self, model_path: str, from_pretrained_kwargs: dict):
try:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
except ImportError as exc:
raise ValueError(
"Could not import depend python package "
"Please install it with `pip install transformers`."
) from exc
if not transformers.__version__ >= "4.34.0":
raise ValueError(
"Current model (Load by NewHFChatModelAdapter) require transformers.__version__>=4.34.0"
)
revision = from_pretrained_kwargs.get("revision", "main")
try:
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=self.use_fast_tokenizer(),
revision=revision,
trust_remote_code=True,
)
except TypeError:
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, revision=revision, trust_remote_code=True
)
try:
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
except NameError:
model = AutoModel.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
# tokenizer.use_default_system_prompt = False
return model, tokenizer
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
return huggingface_chat_generate_stream
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
from transformers import AutoTokenizer
if not tokenizer:
raise ValueError("tokenizer is is None")
tokenizer: AutoTokenizer = tokenizer
messages = self.transform_model_messages(messages)
logger.debug(f"The messages after transform: \n{messages}")
str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return str_prompt
class YiAdapter(NewHFChatModelAdapter):
support_4bit: bool = True
support_8bit: bool = True
support_system_message: bool = True
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return (
lower_model_name_or_path
and "yi-" in lower_model_name_or_path
and "chat" in lower_model_name_or_path
)
class Mixtral8x7BAdapter(NewHFChatModelAdapter):
"""
https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
"""
support_4bit: bool = True
support_8bit: bool = True
support_system_message: bool = False
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return (
lower_model_name_or_path
and "mixtral" in lower_model_name_or_path
and "8x7b" in lower_model_name_or_path
)
register_model_adapter(YiAdapter)
register_model_adapter(Mixtral8x7BAdapter)
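The core of the Mixtral support is that `support_system_message` is False, so system messages are merged into the user messages before the tokenizer's chat template renders the prompt. Below is a rough sketch of that path; the messages are illustrative and the tokenizer download is only needed to see the final string.

```python
# Sketch of the Mixtral prompt path: merge the system message into the user message,
# then let the Hugging Face chat template render the final prompt string.
from transformers import AutoTokenizer

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.adapter.hf_adapter import Mixtral8x7BAdapter

adapter = Mixtral8x7BAdapter()
messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]

# support_system_message is False, so this yields a single user message with the
# system text merged in front of it.
print(adapter.transform_model_messages(messages))

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
prompt = adapter.get_str_prompt({}, messages, tokenizer)
print(prompt)  # roughly "<s>[INST] You are a helpful assistant.\nWhat is DB-GPT? [/INST]"
```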

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
from typing import (
List,
Type,
Optional,
)
import logging
import threading
import os
from functools import cache
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import BaseModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
from dbgpt.model.adapter.template import (
ConversationAdapter,
ConversationAdapterFactory,
)
logger = logging.getLogger(__name__)
thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
"codellama-13b-sql-sft",
"codellama-7b",
"codellama-7b-sql-sft",
"codellama-13b",
]
@cache
def get_llm_model_adapter(
model_name: str,
model_path: str,
use_fastchat: bool = True,
use_fastchat_monkey_patch: bool = False,
model_type: str = None,
) -> LLMModelAdapter:
conv_factory = DefaultConversationAdapterFactory()
if model_type == ModelType.VLLM:
logger.info("Current model type is vllm, return VLLMModelAdapterWrapper")
from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
return VLLMModelAdapterWrapper(conv_factory)
# Import NewHFChatModelAdapter so that its subclasses in hf_adapter are registered
from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
new_model_adapter = get_model_adapter(
model_type, model_name, model_path, conv_factory
)
if new_model_adapter:
logger.info(f"Current model {model_name} use new adapter {new_model_adapter}")
return new_model_adapter
must_use_old = any(m in model_name for m in _OLD_MODELS)
result_adapter: Optional[LLMModelAdapter] = None
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
from dbgpt.model.adapter.fschat_adapter import (
_get_fastchat_model_adapter,
_fastchat_get_adapter_monkey_patch,
FastChatLLMModelAdapterWrapper,
)
adapter = _get_fastchat_model_adapter(
model_name,
model_path,
_fastchat_get_adapter_monkey_patch,
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
)
if adapter:
result_adapter = FastChatLLMModelAdapterWrapper(adapter)
else:
from dbgpt.model.adapter.old_adapter import (
get_llm_model_adapter as _old_get_llm_model_adapter,
OldLLMModelAdapterWrapper,
)
from dbgpt.app.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
result_adapter = OldLLMModelAdapterWrapper(
_old_get_llm_model_adapter(model_name, model_path),
get_llm_chat_adapter(model_name, model_path),
)
if result_adapter:
result_adapter.model_name = model_name
result_adapter.model_path = model_path
result_adapter.conv_factory = conv_factory
return result_adapter
else:
raise ValueError(f"Can not find adapter for model {model_name}")
@cache
def _auto_get_conv_template(
model_name: str, model_path: str
) -> Optional[ConversationAdapter]:
"""Auto get the conversation template.
Args:
model_name (str): The name of the model.
model_path (str): The path of the model.
Returns:
Optional[ConversationAdapter]: The conversation template.
"""
try:
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
return adapter.get_default_conv_template(model_name, model_path)
except Exception as e:
logger.debug(f"Failed to get conv template for {model_name} {model_path}: {e}")
return None
class DefaultConversationAdapterFactory(ConversationAdapterFactory):
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
"""Get a conversation adapter by model.
Args:
model_name (str): The name of the model.
model_path (str): The path of the model.
Returns:
ConversationAdapter: The conversation adapter.
"""
return _auto_get_conv_template(model_name, model_path)
def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
"""Dynamic model parser, parse the model parameters from the command line arguments.
Returns:
Optional[List[Type[BaseModelParameters]]]: The model parameters class list.
"""
from dbgpt.util.parameter_utils import _SimpleArgParser
from dbgpt.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
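A usage sketch of the resolution order implemented above: the new Hugging Face chat adapters are tried first, then fastchat, then the old DB-GPT adapters. The local path below is a placeholder.

```python
# Sketch: resolve the adapter for a local Mixtral checkout (placeholder path).
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.base import ModelType

adapter = get_llm_model_adapter(
    "mixtral-8x7b-instruct-v0.1",
    "/data/models/Mixtral-8x7B-Instruct-v0.1",
    model_type=ModelType.HF,
)
# Expected to be the Mixtral8x7BAdapter registered in hf_adapter.py, because
# "mixtral" and "8x7b" both appear in the lowered model name.
print(adapter)
print("async:", adapter.support_async(), "type:", adapter.model_type())
```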

View File

@@ -9,7 +9,7 @@ import os
 import re
 import logging
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, TYPE_CHECKING, Optional
 from functools import cache
 from transformers import (
     AutoModel,
@@ -17,6 +17,9 @@ from transformers import (
     AutoTokenizer,
     LlamaTokenizer,
 )
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
 from dbgpt.model.parameter import (
@@ -24,9 +27,13 @@ from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ProxyModelParameters,
 )
+from dbgpt.model.conversation import Conversation
 from dbgpt.configs.model_config import get_device
 from dbgpt._private.config import Config
+if TYPE_CHECKING:
+    from dbgpt.app.chat_adapter import BaseChatAdpter
 logger = logging.getLogger(__name__)
 CFG = Config()
@@ -92,17 +99,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
     )
-def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
-    try:
-        llm_adapter = get_llm_model_adapter(model_name, model_path)
-        return llm_adapter.model_param_class()
-    except Exception as e:
-        logger.warn(
-            f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
-        )
-        return ModelParameters
 # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
@@ -426,6 +422,87 @@ class InternLMAdapter(BaseLLMAdaper):
         return model, tokenizer
+class OldLLMModelAdapterWrapper(LLMModelAdapter):
+    """Wrapping old adapter, which may be removed later"""
+    def __init__(self, adapter: BaseLLMAdaper, chat_adapter: "BaseChatAdpter") -> None:
+        self._adapter = adapter
+        self._chat_adapter = chat_adapter
+    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
+        return OldLLMModelAdapterWrapper(self._adapter, self._chat_adapter)
+    def use_fast_tokenizer(self) -> bool:
+        return self._adapter.use_fast_tokenizer()
+    def model_type(self) -> str:
+        return self._adapter.model_type()
+    def model_param_class(self, model_type: str = None) -> ModelParameters:
+        return self._adapter.model_param_class(model_type)
+    def get_default_conv_template(
+        self, model_name: str, model_path: str
+    ) -> Optional[ConversationAdapter]:
+        conv_template = self._chat_adapter.get_conv_template(model_path)
+        return OldConversationAdapter(conv_template) if conv_template else None
+    def load(self, model_path: str, from_pretrained_kwargs: dict):
+        return self._adapter.loader(model_path, from_pretrained_kwargs)
+    def get_generate_stream_function(self, model, model_path: str):
+        return self._chat_adapter.get_generate_stream_func(model_path)
+    def __str__(self) -> str:
+        return "{}({}.{})".format(
+            self.__class__.__name__,
+            self._adapter.__class__.__module__,
+            self._adapter.__class__.__name__,
+        )
+class OldConversationAdapter(ConversationAdapter):
+    """Wrapping old Conversation, which may be removed later"""
+    def __init__(self, conv: Conversation) -> None:
+        self._conv = conv
+    @property
+    def prompt_type(self) -> PromptType:
+        return PromptType.DBGPT
+    @property
+    def roles(self) -> Tuple[str]:
+        return self._conv.roles
+    @property
+    def sep(self) -> Optional[str]:
+        return self._conv.sep
+    @property
+    def stop_str(self) -> str:
+        return self._conv.stop_str
+    @property
+    def stop_token_ids(self) -> Optional[List[int]]:
+        return self._conv.stop_token_ids
+    def get_prompt(self) -> str:
+        return self._conv.get_prompt()
+    def set_system_message(self, system_message: str) -> None:
+        self._conv.update_system_message(system_message)
+    def append_message(self, role: str, message: str) -> None:
+        self._conv.append_message(role, message)
+    def update_last_message(self, message: str) -> None:
+        self._conv.update_last_message(message)
+    def copy(self) -> "ConversationAdapter":
+        return OldConversationAdapter(self._conv.copy())
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)

View File

@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Optional, Tuple, Union, List
if TYPE_CHECKING:
from fastchat.conversation import Conversation
class PromptType(str, Enum):
"""Prompt type."""
FSCHAT: str = "fschat"
DBGPT: str = "dbgpt"
class ConversationAdapter(ABC):
"""The conversation adapter."""
@property
def prompt_type(self) -> PromptType:
return PromptType.FSCHAT
@property
@abstractmethod
def roles(self) -> Tuple[str]:
"""Get the roles of the conversation.
Returns:
Tuple[str]: The roles of the conversation.
"""
@property
def sep(self) -> Optional[str]:
"""Get the separator between messages."""
return "\n"
@property
def stop_str(self) -> Optional[Union[str, List[str]]]:
"""Get the stop criteria."""
return None
@property
def stop_token_ids(self) -> Optional[List[int]]:
"""Stops generation if meeting any token in this list"""
return None
@abstractmethod
def get_prompt(self) -> str:
"""Get the prompt string.
Returns:
str: The prompt string.
"""
@abstractmethod
def set_system_message(self, system_message: str) -> None:
"""Set the system message."""
@abstractmethod
def append_message(self, role: str, message: str) -> None:
"""Append a new message.
Args:
role (str): The role of the message.
message (str): The message content.
"""
@abstractmethod
def update_last_message(self, message: str) -> None:
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
Args:
message (str): The message content.
"""
@abstractmethod
def copy(self) -> "ConversationAdapter":
"""Copy the conversation."""
class ConversationAdapterFactory(ABC):
"""The conversation adapter factory."""
def get_by_name(
self,
template_name: str,
prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
) -> ConversationAdapter:
"""Get a conversation adapter by name.
Args:
template_name (str): The name of the template.
prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT.
Returns:
ConversationAdapter: The conversation adapter.
"""
raise NotImplementedError()
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
"""Get a conversation adapter by model.
Args:
model_name (str): The name of the model.
model_path (str): The path of the model.
Returns:
ConversationAdapter: The conversation adapter.
"""
raise NotImplementedError()
def get_conv_template(name: str) -> ConversationAdapter:
"""Get a conversation template.
Args:
name (str): The name of the template.
Just return the fastchat conversation template for now.
# TODO: More templates should be supported.
Returns:
ConversationAdapter: The conversation template.
"""
from fastchat.conversation import get_conv_template
from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
conv_template = get_conv_template(name)
return FschatConversationAdapter(conv_template)
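To make the `ConversationAdapter` contract concrete, here is a toy implementation that renders a simple role-prefixed prompt. It exists only to illustrate the interface and is not part of the commit.

```python
# Toy ConversationAdapter showing the minimal surface the rest of the code relies on.
import copy as _copy
from typing import List, Optional, Tuple

from dbgpt.model.adapter.template import ConversationAdapter, PromptType


class SimpleConversationAdapter(ConversationAdapter):
    def __init__(self) -> None:
        self._system = ""
        self._messages: List[Tuple[str, Optional[str]]] = []

    @property
    def prompt_type(self) -> PromptType:
        return PromptType.DBGPT

    @property
    def roles(self) -> Tuple[str]:
        return ("user", "assistant")

    def get_prompt(self) -> str:
        lines = [self._system] if self._system else []
        for role, message in self._messages:
            lines.append(f"{role}: {message or ''}")
        return "\n".join(lines)

    def set_system_message(self, system_message: str) -> None:
        self._system = system_message

    def append_message(self, role: str, message: str) -> None:
        self._messages.append((role, message))

    def update_last_message(self, message: str) -> None:
        role, _ = self._messages[-1]
        self._messages[-1] = (role, message)

    def copy(self) -> "ConversationAdapter":
        return _copy.deepcopy(self)


conv = SimpleConversationAdapter()
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "Hello")
conv.append_message(conv.roles[1], None)  # blank slot for the model's answer
print(conv.get_prompt())
```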

View File

@@ -0,0 +1,93 @@
import dataclasses
import logging
from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.parameter import BaseModelParameters
from dbgpt.util.parameter_utils import (
_extract_parameter_details,
_build_parameter_class,
_get_dataclass_print_str,
)
logger = logging.getLogger(__name__)
class VLLMModelAdapterWrapper(LLMModelAdapter):
"""Wrapping vllm engine"""
def __init__(self, conv_factory: ConversationAdapterFactory):
self.conv_factory = conv_factory
def new_adapter(self, **kwargs) -> "VLLMModelAdapterWrapper":
return VLLMModelAdapterWrapper(self.conv_factory)
def model_type(self) -> str:
return ModelType.VLLM
def model_param_class(self, model_type: str = None) -> BaseModelParameters:
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument(
"--model_path",
type=str,
help="local model path of the huggingface model to use",
)
parser.add_argument("--model_type", type=str, help="model type")
parser.add_argument("--device", type=str, default=None, help="device")
# TODO: parse the prompt template from `model_name` and `model_path`
parser.add_argument(
"--prompt_template",
type=str,
default=None,
help="Prompt template. If None, the prompt template is automatically determined from model path",
)
descs = _extract_parameter_details(
parser,
"dbgpt.model.parameter.VLLMModelParameters",
skip_names=["model"],
overwrite_default_values={"trust_remote_code": True},
)
return _build_parameter_class(descs)
def load_from_params(self, params):
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
import torch
num_gpus = torch.cuda.device_count()
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
setattr(params, "tensor_parallel_size", num_gpus)
logger.info(
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
)
params = dataclasses.asdict(params)
params["model"] = params["model_path"]
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
# Set the attributes from the parsed arguments.
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, engine.engine.tokenizer
def support_async(self) -> bool:
return True
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.llm_out.vllm_llm import generate_stream
return generate_stream
def get_default_conv_template(
self, model_name: str, model_path: str
) -> ConversationAdapter:
return self.conv_factory.get_by_model(model_name, model_path)
def __str__(self) -> str:
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)

View File

@@ -405,7 +405,7 @@ def stop_model_controller(port: int):
 def _model_dynamic_factory() -> Callable[[None], List[Type]]:
-    from dbgpt.model.model_adapter import _dynamic_model_parser
+    from dbgpt.model.adapter.model_adapter import _dynamic_model_parser
     param_class = _dynamic_model_parser()
     fix_class = [ModelWorkerParameters]

View File

@@ -6,7 +6,8 @@ import time
 import traceback
 from dbgpt.configs.model_config import get_device
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.core import ModelOutput, ModelInferenceMetrics
 from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
@@ -27,7 +28,7 @@ class DefaultModelWorker(ModelWorker):
         self.model = None
         self.tokenizer = None
         self._model_params = None
-        self.llm_adapter: LLMModelAdaper = None
+        self.llm_adapter: LLMModelAdapter = None
         self._support_async = False
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:

View File

@@ -37,7 +37,7 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
-    from dbgpt.model.model_adapter import get_llm_model_adapter
+    from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
     from dbgpt.model.loader import _get_model_real_path
     ret = []

View File

@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from typing import Optional, Dict
+from typing import Optional, Dict, Any
+from dataclasses import asdict
 import logging
 from dbgpt.configs.model_config import get_device
 from dbgpt.model.base import ModelType
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.parameter import (
     ModelParameters,
     LlamaCppModelParameters,
@@ -117,7 +117,7 @@ class ModelLoader:
         raise Exception(f"Unkown model type {model_type}")
     def loader_with_params(
-        self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
+        self, model_params: ModelParameters, llm_adapter: LLMModelAdapter
     ):
         model_type = llm_adapter.model_type()
         self.prompt_template = model_params.prompt_template
@@ -133,7 +133,7 @@ class ModelLoader:
         raise Exception(f"Unkown model type {model_type}")
-def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
+def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch
     from dbgpt.model.compression import compress_module
@@ -174,6 +174,12 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
     else:
         raise ValueError(f"Invalid device: {device}")
+    model, tokenizer = _try_load_default_quantization_model(
+        llm_adapter, device, num_gpus, model_params, kwargs
+    )
+    if model:
+        return model, tokenizer
     can_quantization = _check_quantization(model_params)
     if can_quantization and (num_gpus > 1 or model_params.load_4bit):
@@ -192,6 +198,46 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
         # TODO merge current code into `load_huggingface_quantization_model`
         compress_module(model, model_params.device)
+    return _handle_model_and_tokenizer(model, tokenizer, device, num_gpus, model_params)
+def _try_load_default_quantization_model(
+    llm_adapter: LLMModelAdapter,
+    device: str,
+    num_gpus: int,
+    model_params: ModelParameters,
+    kwargs: Dict[str, Any],
+):
+    """Try load default quantization model(Support by huggingface default)"""
+    cloned_kwargs = {k: v for k, v in kwargs.items()}
+    try:
+        model, tokenizer = None, None
+        if device != "cuda":
+            return None, None
+        elif model_params.load_8bit and llm_adapter.support_8bit:
+            cloned_kwargs["load_in_8bit"] = True
+            model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+        elif model_params.load_4bit and llm_adapter.support_4bit:
+            cloned_kwargs["load_in_4bit"] = True
+            model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+        if model:
+            logger.info(
+                f"Load default quantization model {model_params.model_name} success"
+            )
+            return _handle_model_and_tokenizer(
+                model, tokenizer, device, num_gpus, model_params
+            )
+        return None, None
+    except Exception as e:
+        logger.warning(
+            f"Load default quantization model {model_params.model_name} failed, error: {str(e)}"
+        )
+        return None, None
+def _handle_model_and_tokenizer(
+    model, tokenizer, device: str, num_gpus: int, model_params: ModelParameters
+):
     if (
         (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
         or device == "mps"
@@ -209,7 +255,7 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
 def load_huggingface_quantization_model(
-    llm_adapter: LLMModelAdaper,
+    llm_adapter: LLMModelAdapter,
     model_params: ModelParameters,
     kwargs: Dict,
     max_memory: Dict[int, str],
@@ -344,7 +390,9 @@ def load_huggingface_quantization_model(
     return model, tokenizer
-def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
+def llamacpp_loader(
+    llm_adapter: LLMModelAdapter, model_params: LlamaCppModelParameters
+):
     try:
         from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel
     except ImportError as exc:
@@ -358,7 +406,7 @@ def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelPara
     return model, tokenizer
-def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
+def proxyllm_loader(llm_adapter: LLMModelAdapter, model_params: ProxyModelParameters):
     from dbgpt.model.proxy.llms.proxy_model import ProxyModel
     logger.info("Load proxyllm")

View File

@@ -1,660 +0,0 @@
from __future__ import annotations
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
import dataclasses
import logging
import threading
import os
from functools import cache
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util.parameter_utils import (
_extract_parameter_details,
_build_parameter_class,
_get_dataclass_print_str,
)
try:
from fastchat.conversation import (
Conversation,
register_conv_template,
SeparatorStyle,
)
except ImportError as exc:
raise ValueError(
"Could not import python package: fschat "
"Please install fastchat by command `pip install fschat` "
) from exc
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
from dbgpt.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
from torch.nn import Module as TorchNNModule
logger = logging.getLogger(__name__)
thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
"codellama-13b-sql-sft",
"codellama-7b",
"codellama-7b-sql-sft",
"codellama-13b",
]
_NEW_HF_CHAT_MODELS = [
"yi-34b",
"yi-6b",
]
# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
def use_fast_tokenizer(self) -> bool:
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
for a given model.
"""
return False
def model_type(self) -> str:
return ModelType.HF
def model_param_class(self, model_type: str = None) -> ModelParameters:
"""Get the startup parameters instance of the model"""
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def load(self, model_path: str, from_pretrained_kwargs: dict):
"""Load model and tokenizer"""
raise NotImplementedError
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
raise NotImplementedError
def get_async_generate_stream_function(self, model, model_path: str):
"""Get the asynchronous generate stream function of the model"""
raise NotImplementedError
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
"""Get the default conv template"""
raise NotImplementedError
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
return None
def get_prompt_with_template(
self,
params: Dict,
messages: List[ModelMessage],
model_name: str,
model_path: str,
model_context: Dict,
prompt_template: str = None,
):
conv = self.get_default_conv_template(model_name, model_path)
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return None, None, None
conv = conv.copy()
system_messages = []
user_messages = []
ai_messages = []
for message in messages:
role, content = None, None
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
# conv.append_message(conv.roles[0], content)
user_messages.append(content)
elif role == ModelMessageRoleType.AI:
# conv.append_message(conv.roles[1], content)
ai_messages.append(content)
else:
raise ValueError(f"Unknown role: {role}")
can_use_systems: [] = []
if system_messages:
if len(system_messages) > 1:
## Compatible with dbgpt complex scenarios, the last system will protect more complete information entered by the current user
user_messages[-1] = system_messages[-1]
can_use_systems = system_messages[:-1]
else:
can_use_systems = system_messages
for i in range(len(user_messages)):
conv.append_message(conv.roles[0], user_messages[i])
if i < len(ai_messages):
conv.append_message(conv.roles[1], ai_messages[i])
if isinstance(conv, Conversation):
conv.set_system_message("".join(can_use_systems))
else:
conv.update_system_message("".join(can_use_systems))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
tokenizer: Any,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
messages = params.get("messages")
# Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
conv_stop_str, conv_stop_token_ids = None, None
if not new_prompt:
(
new_prompt,
conv_stop_str,
conv_stop_token_ids,
) = self.get_prompt_with_template(
params, messages, model_name, model_path, model_context, prompt_template
)
if not new_prompt:
return params, model_context
# Overwrite the original prompt
# TODO remote bos token and eos token from tokenizer_config.json of model
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
custom_stop = params.get("stop")
custom_stop_token_ids = params.get("stop_token_ids")
# Prefer the value passed in from the input parameter
params["stop"] = custom_stop or conv_stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
return params, model_context
class OldLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping old adapter, which may be removed later"""
def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
self._adapter = adapter
self._chat_adapter = chat_adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer()
def model_type(self) -> str:
return self._adapter.model_type()
def model_param_class(self, model_type: str = None) -> ModelParameters:
return self._adapter.model_param_class(model_type)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._chat_adapter.get_conv_template(model_path)
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.loader(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model, model_path: str):
return self._chat_adapter.get_generate_stream_func(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping fastchat adapter"""
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def use_fast_tokenizer(self) -> bool:
return self._adapter.use_fast_tokenizer
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
if _IS_BENCHMARK:
from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
generate_stream,
)
return generate_stream
else:
from fastchat.model.model_adapter import get_generate_stream_function
return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._adapter.get_default_conv_template(model_path)
def __str__(self) -> str:
return "{}({}.{})".format(
self.__class__.__name__,
self._adapter.__class__.__module__,
self._adapter.__class__.__name__,
)
class NewHFChatModelAdapter(LLMModelAdaper):
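    """Adapter for chat models that ship a chat template with their Hugging Face tokenizer.

    The prompt is built with ``tokenizer.apply_chat_template`` rather than a
    hand-written conversation template, which requires transformers>=4.34.0.
    """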
def load(self, model_path: str, from_pretrained_kwargs: dict):
try:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
except ImportError as exc:
raise ValueError(
"Could not import depend python package "
"Please install it with `pip install transformers`."
) from exc
if not transformers.__version__ >= "4.34.0":
raise ValueError(
"Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
)
revision = from_pretrained_kwargs.get("revision", "main")
try:
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=self.use_fast_tokenizer,
revision=revision,
trust_remote_code=True,
)
except TypeError:
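            # Some tokenizers do not support the fast implementation; retry with the slow one.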
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, revision=revision, trust_remote_code=True
)
try:
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
except NameError:
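            # Fall back to the generic AutoModel when AutoModelForCausalLM is unavailable.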
model = AutoModel.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
# tokenizer.use_default_system_prompt = False
return model, tokenizer
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
return huggingface_chat_generate_stream
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
from transformers import AutoTokenizer
if not tokenizer:
raise ValueError("tokenizer is is None")
tokenizer: AutoTokenizer = tokenizer
messages = ModelMessage.to_openai_messages(messages)
str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return str_prompt
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
return get_conv_template(name)
@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
try:
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
return adapter.get_default_conv_template(model_name, model_path)
except Exception:
return None
@cache
def get_llm_model_adapter(
model_name: str,
model_path: str,
use_fastchat: bool = True,
use_fastchat_monkey_patch: bool = False,
model_type: str = None,
) -> LLMModelAdaper:
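    """Resolve the LLM model adapter for the given model.

    Resolution order: vLLM, new HF chat-template models, fastchat adapters, and
    finally the old DB-GPT adapters.

    A minimal usage sketch (the model name and local path below are placeholders):

        adapter = get_llm_model_adapter("vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5")
        model, tokenizer = adapter.load("/data/models/vicuna-13b-v1.5", {"revision": "main"})
    """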
if model_type == ModelType.VLLM:
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()
use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
if use_new_hf_chat_models:
logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
return NewHFChatModelAdapter()
must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
adapter = _get_fastchat_model_adapter(
model_name,
model_path,
_fastchat_get_adapter_monkey_patch,
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
)
return FastChatLLMModelAdaperWrapper(adapter)
else:
        from dbgpt.model.adapter.old_adapter import (
            get_llm_model_adapter as _old_get_llm_model_adapter,
        )
from dbgpt.app.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
return OldLLMModelAdaperWrapper(
_old_get_llm_model_adapter(model_name, model_path),
get_llm_chat_adapter(model_name, model_path),
)
def _get_fastchat_model_adapter(
model_name: str,
model_path: str,
    caller: Callable[[str], Any] = None,
use_fastchat_monkey_patch: bool = False,
):
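    """Invoke `caller` against fastchat's model adapters.

    Optionally monkey patches fastchat's ``get_model_adapter`` so adapters are
    matched by model name; the original function is always restored afterwards.
    """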
from fastchat.model import model_adapter
_bak_get_model_adapter = model_adapter.get_model_adapter
try:
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
_remove_black_list_model_of_fastchat()
if caller:
return caller(model_path)
finally:
del thread_local.model_name
model_adapter.get_model_adapter = _bak_get_model_adapter
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
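    """Monkey patch for fastchat's ``get_model_adapter``.

    Match an adapter by model name first, then by the base name of the model path,
    and finally by the full model path.
    """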
if not model_name:
if not hasattr(thread_local, "model_name"):
raise RuntimeError("fastchat get adapter monkey path need model_name")
model_name = thread_local.model_name
from fastchat.model.model_adapter import model_adapters
for adapter in model_adapters:
if adapter.match(model_name):
logger.info(
f"Found llm model adapter with model name: {model_name}, {adapter}"
)
return adapter
model_path_basename = (
None if not model_path else os.path.basename(os.path.normpath(model_path))
)
for adapter in model_adapters:
if model_path_basename and adapter.match(model_path_basename):
logger.info(
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
)
return adapter
for adapter in model_adapters:
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
return adapter
raise ValueError(
f"Invalid model adapter for model name {model_name} and model path {model_path}"
)
@cache
def _remove_black_list_model_of_fastchat():
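    """Remove fastchat adapters whose default conversation template is in the black list."""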
from fastchat.model.model_adapter import model_adapters
black_list_models = []
for adapter in model_adapters:
try:
if (
adapter.get_default_conv_template("/data/not_exist_model_path").name
in _BLACK_LIST_MODLE_PROMPT
):
black_list_models.append(adapter)
except Exception:
pass
for adapter in black_list_models:
model_adapters.remove(adapter)
def _dynamic_model_parser() -> Optional[List[Type]]:
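    """Pre-parse the command line to select the parameter class for the target worker/model."""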
from dbgpt.util.parameter_utils import _SimpleArgParser
from dbgpt.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
class VLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping vllm engine"""
def model_type(self) -> str:
return ModelType.VLLM
def model_param_class(self, model_type: str = None) -> ModelParameters:
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument(
"--model_path",
type=str,
help="local model path of the huggingface model to use",
)
parser.add_argument("--model_type", type=str, help="model type")
parser.add_argument("--device", type=str, default=None, help="device")
        # TODO: parse the prompt template from `model_name` and `model_path`
parser.add_argument(
"--prompt_template",
type=str,
default=None,
help="Prompt template. If None, the prompt template is automatically determined from model path",
)
descs = _extract_parameter_details(
parser,
"dbgpt.model.parameter.VLLMModelParameters",
skip_names=["model"],
overwrite_default_values={"trust_remote_code": True},
)
return _build_parameter_class(descs)
def load_from_params(self, params):
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
import torch
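        # Use all visible GPUs for tensor parallelism when more than one is available.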
num_gpus = torch.cuda.device_count()
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
setattr(params, "tensor_parallel_size", num_gpus)
logger.info(
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
)
params = dataclasses.asdict(params)
params["model"] = params["model_path"]
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
# Set the attributes from the parsed arguments.
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, engine.engine.tokenizer
def support_async(self) -> bool:
return True
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.llm_out.vllm_llm import generate_stream
return generate_stream
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return _auto_get_conv_template(model_name, model_path)
def __str__(self) -> str:
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
# The conversation templates below override fastchat's configuration. We will regularly
# contribute the code here back to fastchat, and we also recommend modifying it directly
# in the fastchat repository.
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
Conversation(
name="aquila-legacy",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("### Human: ", "### Assistant: ", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
Conversation(
name="aquila",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
Conversation(
name="aquila-v1",
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
),
override=True,
)