mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
feat(model): Support Mixtral-8x7B (#959)
This commit is contained in:
@@ -103,7 +103,8 @@ At present, we have introduced several key features to showcase our current capa
|
||||
We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
|
||||
|
||||
- News
|
||||
- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
|
||||
- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
|
||||
- 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
|
||||
- [More Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf)
|
||||
|
||||
|
@@ -111,7 +111,8 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
|
||||
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型:
|
||||
|
||||
- 新增支持模型
|
||||
- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
|
||||
- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
|
||||
- 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
|
||||
- [更多开源模型](https://www.yuque.com/eosphoros/dbgpt-docs/iqaaqwriwhp6zslc#qQktR)
|
||||
|
||||
|
@@ -245,7 +245,7 @@ class WizardLMChatAdapter(BaseChatAdpter):
|
||||
|
||||
class LlamaCppChatAdapter(BaseChatAdpter):
|
||||
def match(self, model_path: str):
|
||||
from dbgpt.model.adapter import LlamaCppAdapater
|
||||
from dbgpt.model.adapter.old_adapter import LlamaCppAdapater
|
||||
|
||||
if "llama-cpp" == model_path:
|
||||
return True
|
||||
|
@@ -113,7 +113,9 @@ LLM_MODEL_CONFIG = {
|
||||
# https://huggingface.co/microsoft/Orca-2-13b
|
||||
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
|
||||
# https://huggingface.co/openchat/openchat_3.5
|
||||
"openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
|
||||
"openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
|
||||
# https://huggingface.co/openchat/openchat-3.5-1210
|
||||
"openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"),
|
||||
# https://huggingface.co/hfl/chinese-alpaca-2-7b
|
||||
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
|
||||
# https://huggingface.co/hfl/chinese-alpaca-2-13b
|
||||
@@ -124,6 +126,10 @@ LLM_MODEL_CONFIG = {
|
||||
"zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
|
||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
"mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
|
||||
# https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
"mixtral-8x7b-instruct-v0.1": os.path.join(
|
||||
MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1"
|
||||
),
|
||||
# https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
|
||||
"mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
|
||||
# https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
|
||||
|
0
dbgpt/model/adapter/__init__.py
Normal file
0
dbgpt/model/adapter/__init__.py
Normal file
437
dbgpt/model/adapter/base.py
Normal file
437
dbgpt/model/adapter/base.py
Normal file
@@ -0,0 +1,437 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any, Tuple, Type, Callable
|
||||
import logging
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import (
|
||||
BaseModelParameters,
|
||||
ModelParameters,
|
||||
LlamaCppModelParameters,
|
||||
ProxyModelParameters,
|
||||
)
|
||||
from dbgpt.model.adapter.template import (
|
||||
get_conv_template,
|
||||
ConversationAdapter,
|
||||
ConversationAdapterFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMModelAdapter(ABC):
|
||||
"""New Adapter for DB-GPT LLM models"""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
model_path: Optional[str] = None
|
||||
conv_factory: Optional[ConversationAdapterFactory] = None
|
||||
# TODO: more flexible quantization config
|
||||
support_4bit: bool = False
|
||||
support_8bit: bool = False
|
||||
support_system_message: bool = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} model_name={self.model_name} model_path={self.model_path}>"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
@abstractmethod
|
||||
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
|
||||
"""Create a new adapter instance
|
||||
|
||||
Args:
|
||||
**kwargs: The parameters of the new adapter instance
|
||||
|
||||
Returns:
|
||||
LLMModelAdapter: The new adapter instance
|
||||
"""
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
|
||||
for a given model.
|
||||
"""
|
||||
return False
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.HF
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> Type[BaseModelParameters]:
|
||||
"""Get the startup parameters instance of the model
|
||||
|
||||
Args:
|
||||
model_type (str, optional): The type of model. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Type[BaseModelParameters]: The startup parameters instance of the model
|
||||
"""
|
||||
# """Get the startup parameters instance of the model"""
|
||||
model_type = model_type if model_type else self.model_type()
|
||||
if model_type == ModelType.LLAMA_CPP:
|
||||
return LlamaCppModelParameters
|
||||
elif model_type == ModelType.PROXY:
|
||||
return ProxyModelParameters
|
||||
return ModelParameters
|
||||
|
||||
def match(
|
||||
self,
|
||||
model_type: str,
|
||||
model_name: Optional[str] = None,
|
||||
model_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Whether the model adapter can load the given model
|
||||
|
||||
Args:
|
||||
model_type (str): The type of model
|
||||
model_name (Optional[str], optional): The name of model. Defaults to None.
|
||||
model_path (Optional[str], optional): The path of model. Defaults to None.
|
||||
"""
|
||||
return False
|
||||
|
||||
def support_quantization_4bit(self) -> bool:
|
||||
"""Whether the model adapter can load 4bit model
|
||||
|
||||
If it is True, we will load the 4bit model with :meth:`~LLMModelAdapter.load`
|
||||
|
||||
Returns:
|
||||
bool: Whether the model adapter can load 4bit model, default is False
|
||||
"""
|
||||
return self.support_4bit
|
||||
|
||||
def support_quantization_8bit(self) -> bool:
|
||||
"""Whether the model adapter can load 8bit model
|
||||
|
||||
If it is True, we will load the 8bit model with :meth:`~LLMModelAdapter.load`
|
||||
|
||||
Returns:
|
||||
bool: Whether the model adapter can load 8bit model, default is False
|
||||
"""
|
||||
return self.support_8bit
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
"""Load model and tokenizer"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_from_params(self, params):
|
||||
"""Load the model and tokenizer according to the given parameters"""
|
||||
raise NotImplementedError
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether the loaded model supports asynchronous calls"""
|
||||
return False
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_async_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the asynchronous generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
"""Get the default conversation template
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
|
||||
Returns:
|
||||
Optional[ConversationAdapter]: The conversation template.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_message_separator(self) -> str:
|
||||
"""Get the default message separator"""
|
||||
try:
|
||||
conv_template = self.get_default_conv_template(
|
||||
self.model_name, self.model_path
|
||||
)
|
||||
return conv_template.sep
|
||||
except Exception:
|
||||
return "\n"
|
||||
|
||||
def transform_model_messages(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Transform the model messages
|
||||
|
||||
Default is the OpenAI format, example:
|
||||
.. code-block:: python
|
||||
return_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"}
|
||||
]
|
||||
|
||||
But some model may need to transform the messages to other format(e.g. There is no system message), such as:
|
||||
.. code-block:: python
|
||||
return_messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"}
|
||||
]
|
||||
Args:
|
||||
messages (List[ModelMessage]): The model messages
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: The transformed model messages
|
||||
"""
|
||||
logger.info(f"support_system_message: {self.support_system_message}")
|
||||
if not self.support_system_message:
|
||||
return self._transform_to_no_system_messages(messages)
|
||||
else:
|
||||
return ModelMessage.to_openai_messages(messages)
|
||||
|
||||
def _transform_to_no_system_messages(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Transform the model messages to no system messages
|
||||
|
||||
Some opensource chat model no system messages, so wo should transform the messages to no system messages.
|
||||
|
||||
Merge the system messages to the last user message, example:
|
||||
.. code-block:: python
|
||||
return_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"}
|
||||
]
|
||||
=>
|
||||
return_messages = [
|
||||
{"role": "user", "content": "You are a helpful assistant\nHello"},
|
||||
{"role": "assistant", "content": "Hi"}
|
||||
]
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The model messages
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: The transformed model messages
|
||||
"""
|
||||
openai_messages = ModelMessage.to_openai_messages(messages)
|
||||
system_messages = []
|
||||
return_messages = []
|
||||
for message in openai_messages:
|
||||
if message["role"] == "system":
|
||||
system_messages.append(message["content"])
|
||||
else:
|
||||
return_messages.append(message)
|
||||
if len(system_messages) > 1:
|
||||
# Too much system messages should be a warning
|
||||
logger.warning("Your system messages have more than one message")
|
||||
if system_messages:
|
||||
sep = self.get_default_message_separator()
|
||||
str_system_messages = ",".join(system_messages)
|
||||
# Update last user message
|
||||
return_messages[-1]["content"] = (
|
||||
str_system_messages + sep + return_messages[-1]["content"]
|
||||
)
|
||||
return return_messages
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
"""Get the string prompt from the given parameters and messages
|
||||
|
||||
If the value of return is not None, we will skip :meth:`~LLMModelAdapter.get_prompt_with_template` and use the value of return.
|
||||
|
||||
Args:
|
||||
params (Dict): The parameters
|
||||
messages (List[ModelMessage]): The model messages
|
||||
tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
|
||||
prompt_template (str, optional): The prompt template. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The string prompt
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_prompt_with_template(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
model_context: Dict,
|
||||
prompt_template: str = None,
|
||||
):
|
||||
conv: ConversationAdapter = self.get_default_conv_template(
|
||||
model_name, model_path
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
logger.info(f"Use prompt template {prompt_template} from config")
|
||||
conv = get_conv_template(prompt_template)
|
||||
if not conv or not messages:
|
||||
# Nothing to do
|
||||
logger.info(
|
||||
f"No conv from model_path {model_path} or no messages in params, {self}"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
conv = conv.copy()
|
||||
system_messages = []
|
||||
user_messages = []
|
||||
ai_messages = []
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, ModelMessage):
|
||||
role = message.role
|
||||
content = message.content
|
||||
elif isinstance(message, dict):
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
else:
|
||||
raise ValueError(f"Invalid message type: {message}")
|
||||
|
||||
if role == ModelMessageRoleType.SYSTEM:
|
||||
# Support for multiple system messages
|
||||
system_messages.append(content)
|
||||
elif role == ModelMessageRoleType.HUMAN:
|
||||
# conv.append_message(conv.roles[0], content)
|
||||
user_messages.append(content)
|
||||
elif role == ModelMessageRoleType.AI:
|
||||
# conv.append_message(conv.roles[1], content)
|
||||
ai_messages.append(content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {role}")
|
||||
|
||||
can_use_systems: [] = []
|
||||
if system_messages:
|
||||
if len(system_messages) > 1:
|
||||
# Compatible with dbgpt complex scenarios, the last system will protect more complete information
|
||||
# entered by the current user
|
||||
user_messages[-1] = system_messages[-1]
|
||||
can_use_systems = system_messages[:-1]
|
||||
else:
|
||||
can_use_systems = system_messages
|
||||
|
||||
for i in range(len(user_messages)):
|
||||
conv.append_message(conv.roles[0], user_messages[i])
|
||||
if i < len(ai_messages):
|
||||
conv.append_message(conv.roles[1], ai_messages[i])
|
||||
|
||||
# TODO join all system messages may not be a good idea
|
||||
conv.set_system_message("".join(can_use_systems))
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
new_prompt = conv.get_prompt()
|
||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||
|
||||
def model_adaptation(
|
||||
self,
|
||||
params: Dict,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Params adaptation"""
|
||||
messages = params.get("messages")
|
||||
# Some model context to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
if messages:
|
||||
# Dict message to ModelMessage
|
||||
messages = [
|
||||
m if isinstance(m, ModelMessage) else ModelMessage(**m)
|
||||
for m in messages
|
||||
]
|
||||
params["messages"] = messages
|
||||
|
||||
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
|
||||
conv_stop_str, conv_stop_token_ids = None, None
|
||||
if not new_prompt:
|
||||
(
|
||||
new_prompt,
|
||||
conv_stop_str,
|
||||
conv_stop_token_ids,
|
||||
) = self.get_prompt_with_template(
|
||||
params, messages, model_name, model_path, model_context, prompt_template
|
||||
)
|
||||
if not new_prompt:
|
||||
return params, model_context
|
||||
|
||||
# Overwrite the original prompt
|
||||
# TODO remote bos token and eos token from tokenizer_config.json of model
|
||||
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
|
||||
model_context["prompt_echo_len_char"] = prompt_echo_len_char
|
||||
model_context["echo"] = params.get("echo", True)
|
||||
model_context["has_format_prompt"] = True
|
||||
params["prompt"] = new_prompt
|
||||
|
||||
custom_stop = params.get("stop")
|
||||
custom_stop_token_ids = params.get("stop_token_ids")
|
||||
|
||||
# Prefer the value passed in from the input parameter
|
||||
params["stop"] = custom_stop or conv_stop_str
|
||||
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
|
||||
|
||||
return params, model_context
|
||||
|
||||
|
||||
class AdapterEntry:
|
||||
"""The entry of model adapter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_adapter: LLMModelAdapter,
|
||||
match_funcs: List[Callable[[str, str, str], bool]] = None,
|
||||
):
|
||||
self.model_adapter = model_adapter
|
||||
self.match_funcs = match_funcs or []
|
||||
|
||||
|
||||
model_adapters: List[AdapterEntry] = []
|
||||
|
||||
|
||||
def register_model_adapter(
|
||||
model_adapter_cls: Type[LLMModelAdapter],
|
||||
match_funcs: List[Callable[[str, str, str], bool]] = None,
|
||||
) -> None:
|
||||
"""Register a model adapter.
|
||||
|
||||
Args:
|
||||
model_adapter_cls (Type[LLMModelAdapter]): The model adapter class.
|
||||
match_funcs (List[Callable[[str, str, str], bool]], optional): The match functions. Defaults to None.
|
||||
"""
|
||||
model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs))
|
||||
|
||||
|
||||
def get_model_adapter(
|
||||
model_type: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
conv_factory: Optional[ConversationAdapterFactory] = None,
|
||||
) -> Optional[LLMModelAdapter]:
|
||||
"""Get a model adapter.
|
||||
|
||||
Args:
|
||||
model_type (str): The type of the model.
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
conv_factory (Optional[ConversationAdapterFactory], optional): The conversation factory. Defaults to None.
|
||||
Returns:
|
||||
Optional[LLMModelAdapter]: The model adapter.
|
||||
"""
|
||||
adapter = None
|
||||
# First find adapter by model_name
|
||||
for adapter_entry in model_adapters:
|
||||
if adapter_entry.model_adapter.match(model_type, model_name, None):
|
||||
adapter = adapter_entry.model_adapter
|
||||
break
|
||||
for adapter_entry in model_adapters:
|
||||
if adapter_entry.model_adapter.match(model_type, None, model_path):
|
||||
adapter = adapter_entry.model_adapter
|
||||
break
|
||||
if adapter:
|
||||
new_adapter = adapter.new_adapter()
|
||||
new_adapter.model_name = model_name
|
||||
new_adapter.model_path = model_path
|
||||
if conv_factory:
|
||||
new_adapter.conv_factory = conv_factory
|
||||
return new_adapter
|
||||
return None
|
262
dbgpt/model/adapter/fschat_adapter.py
Normal file
262
dbgpt/model/adapter/fschat_adapter.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Adapter for fastchat
|
||||
|
||||
You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it.
|
||||
"""
|
||||
import os
|
||||
import threading
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Callable, Tuple, List, Optional
|
||||
|
||||
try:
|
||||
from fastchat.conversation import (
|
||||
Conversation,
|
||||
register_conv_template,
|
||||
SeparatorStyle,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import python package: fschat "
|
||||
"Please install fastchat by command `pip install fschat` "
|
||||
) from exc
|
||||
|
||||
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastchat.model.model_adapter import BaseModelAdapter
|
||||
from torch.nn import Module as TorchNNModule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
thread_local = threading.local()
|
||||
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
|
||||
|
||||
# If some model is not in the blacklist, but it still affects the loading of DB-GPT, you can add it to the blacklist.
|
||||
__BLACK_LIST_MODEL_PROMPT = []
|
||||
|
||||
|
||||
class FschatConversationAdapter(ConversationAdapter):
|
||||
"""The conversation adapter for fschat."""
|
||||
|
||||
def __init__(self, conv: Conversation):
|
||||
self._conv = conv
|
||||
|
||||
@property
|
||||
def prompt_type(self) -> PromptType:
|
||||
return PromptType.FSCHAT
|
||||
|
||||
@property
|
||||
def roles(self) -> Tuple[str]:
|
||||
return self._conv.roles
|
||||
|
||||
@property
|
||||
def sep(self) -> Optional[str]:
|
||||
return self._conv.sep
|
||||
|
||||
@property
|
||||
def stop_str(self) -> str:
|
||||
return self._conv.stop_str
|
||||
|
||||
@property
|
||||
def stop_token_ids(self) -> Optional[List[int]]:
|
||||
return self._conv.stop_token_ids
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt string."""
|
||||
return self._conv.get_prompt()
|
||||
|
||||
def set_system_message(self, system_message: str) -> None:
|
||||
"""Set the system message."""
|
||||
self._conv.set_system_message(system_message)
|
||||
|
||||
def append_message(self, role: str, message: str) -> None:
|
||||
"""Append a new message.
|
||||
|
||||
Args:
|
||||
role (str): The role of the message.
|
||||
message (str): The message content.
|
||||
"""
|
||||
self._conv.append_message(role, message)
|
||||
|
||||
def update_last_message(self, message: str) -> None:
|
||||
"""Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
|
||||
Args:
|
||||
message (str): The message content.
|
||||
"""
|
||||
self._conv.update_last_message(message)
|
||||
|
||||
def copy(self) -> "ConversationAdapter":
|
||||
"""Copy the conversation."""
|
||||
return FschatConversationAdapter(self._conv.copy())
|
||||
|
||||
|
||||
class FastChatLLMModelAdapterWrapper(LLMModelAdapter):
|
||||
"""Wrapping fastchat adapter"""
|
||||
|
||||
def __init__(self, adapter: "BaseModelAdapter") -> None:
|
||||
self._adapter = adapter
|
||||
|
||||
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
|
||||
return FastChatLLMModelAdapterWrapper(self._adapter)
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
return self._adapter.use_fast_tokenizer
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
return self._adapter.load_model(model_path, from_pretrained_kwargs)
|
||||
|
||||
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
|
||||
if _IS_BENCHMARK:
|
||||
from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
|
||||
generate_stream,
|
||||
)
|
||||
|
||||
return generate_stream
|
||||
else:
|
||||
from fastchat.model.model_adapter import get_generate_stream_function
|
||||
|
||||
return get_generate_stream_function(model, model_path)
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
conv_template = self._adapter.get_default_conv_template(model_path)
|
||||
return FschatConversationAdapter(conv_template) if conv_template else None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}({}.{})".format(
|
||||
self.__class__.__name__,
|
||||
self._adapter.__class__.__module__,
|
||||
self._adapter.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
def _get_fastchat_model_adapter(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
caller: Callable[[str], None] = None,
|
||||
use_fastchat_monkey_patch: bool = False,
|
||||
):
|
||||
from fastchat.model import model_adapter
|
||||
|
||||
_bak_get_model_adapter = model_adapter.get_model_adapter
|
||||
try:
|
||||
if use_fastchat_monkey_patch:
|
||||
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
|
||||
thread_local.model_name = model_name
|
||||
_remove_black_list_model_of_fastchat()
|
||||
if caller:
|
||||
return caller(model_path)
|
||||
finally:
|
||||
del thread_local.model_name
|
||||
model_adapter.get_model_adapter = _bak_get_model_adapter
|
||||
|
||||
|
||||
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
|
||||
if not model_name:
|
||||
if not hasattr(thread_local, "model_name"):
|
||||
raise RuntimeError("fastchat get adapter monkey path need model_name")
|
||||
model_name = thread_local.model_name
|
||||
from fastchat.model.model_adapter import model_adapters
|
||||
|
||||
for adapter in model_adapters:
|
||||
if adapter.match(model_name):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model name: {model_name}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
model_path_basename = (
|
||||
None if not model_path else os.path.basename(os.path.normpath(model_path))
|
||||
)
|
||||
for adapter in model_adapters:
|
||||
if model_path_basename and adapter.match(model_path_basename):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
for adapter in model_adapters:
|
||||
if model_path and adapter.match(model_path):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model path: {model_path}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid model adapter for model name {model_name} and model path {model_path}"
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _remove_black_list_model_of_fastchat():
|
||||
from fastchat.model.model_adapter import model_adapters
|
||||
|
||||
black_list_models = []
|
||||
for adapter in model_adapters:
|
||||
try:
|
||||
if (
|
||||
adapter.get_default_conv_template("/data/not_exist_model_path").name
|
||||
in __BLACK_LIST_MODEL_PROMPT
|
||||
):
|
||||
black_list_models.append(adapter)
|
||||
except Exception:
|
||||
pass
|
||||
for adapter in black_list_models:
|
||||
model_adapters.remove(adapter)
|
||||
|
||||
|
||||
# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
|
||||
# We also recommend that you modify it directly in the fastchat repository.
|
||||
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila-legacy",
|
||||
system_message="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
roles=("### Human: ", "### Assistant: ", "System"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.NO_COLON_TWO,
|
||||
sep="\n",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "[UNK]"],
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila",
|
||||
system_message="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
roles=("Human", "Assistant", "System"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
||||
sep="###",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "[UNK]"],
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila-v1",
|
||||
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.NO_COLON_TWO,
|
||||
sep="",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "<|endoftext|>"],
|
||||
),
|
||||
override=True,
|
||||
)
|
136
dbgpt/model/adapter/hf_adapter.py
Normal file
136
dbgpt/model/adapter/hf_adapter.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, List, Any
|
||||
import logging
|
||||
|
||||
from dbgpt.core import ModelMessage
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NewHFChatModelAdapter(LLMModelAdapter, ABC):
|
||||
"""Model adapter for new huggingface chat models
|
||||
|
||||
See https://huggingface.co/docs/transformers/main/en/chat_templating
|
||||
|
||||
We can transform the inference chat messages to chat model instead of create a
|
||||
prompt template for this model
|
||||
"""
|
||||
|
||||
def new_adapter(self, **kwargs) -> "NewHFChatModelAdapter":
|
||||
return self.__class__()
|
||||
|
||||
def match(
|
||||
self,
|
||||
model_type: str,
|
||||
model_name: Optional[str] = None,
|
||||
model_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
if model_type != ModelType.HF:
|
||||
return False
|
||||
if model_name is None and model_path is None:
|
||||
return False
|
||||
model_name = model_name.lower() if model_name else None
|
||||
model_path = model_path.lower() if model_path else None
|
||||
return self.do_match(model_name) or self.do_match(model_path)
|
||||
|
||||
@abstractmethod
|
||||
def do_match(self, lower_model_name_or_path: Optional[str] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
try:
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import depend python package "
|
||||
"Please install it with `pip install transformers`."
|
||||
) from exc
|
||||
if not transformers.__version__ >= "4.34.0":
|
||||
raise ValueError(
|
||||
"Current model (Load by NewHFChatModelAdapter) require transformers.__version__>=4.34.0"
|
||||
)
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast_tokenizer(),
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except TypeError:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, use_fast=False, revision=revision, trust_remote_code=True
|
||||
)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
except NameError:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
# tokenizer.use_default_system_prompt = False
|
||||
return model, tokenizer
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
|
||||
|
||||
return huggingface_chat_generate_stream
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if not tokenizer:
|
||||
raise ValueError("tokenizer is is None")
|
||||
tokenizer: AutoTokenizer = tokenizer
|
||||
|
||||
messages = self.transform_model_messages(messages)
|
||||
logger.debug(f"The messages after transform: \n{messages}")
|
||||
str_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return str_prompt
|
||||
|
||||
|
||||
class YiAdapter(NewHFChatModelAdapter):
|
||||
support_4bit: bool = True
|
||||
support_8bit: bool = True
|
||||
support_system_message: bool = True
|
||||
|
||||
def do_match(self, lower_model_name_or_path: Optional[str] = None):
|
||||
return (
|
||||
lower_model_name_or_path
|
||||
and "yi-" in lower_model_name_or_path
|
||||
and "chat" in lower_model_name_or_path
|
||||
)
|
||||
|
||||
|
||||
class Mixtral8x7BAdapter(NewHFChatModelAdapter):
|
||||
"""
|
||||
https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
"""
|
||||
|
||||
support_4bit: bool = True
|
||||
support_8bit: bool = True
|
||||
support_system_message: bool = False
|
||||
|
||||
def do_match(self, lower_model_name_or_path: Optional[str] = None):
|
||||
return (
|
||||
lower_model_name_or_path
|
||||
and "mixtral" in lower_model_name_or_path
|
||||
and "8x7b" in lower_model_name_or_path
|
||||
)
|
||||
|
||||
|
||||
register_model_adapter(YiAdapter)
|
||||
register_model_adapter(Mixtral8x7BAdapter)
|
166
dbgpt/model/adapter/model_adapter.py
Normal file
166
dbgpt/model/adapter/model_adapter.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
List,
|
||||
Type,
|
||||
Optional,
|
||||
)
|
||||
import logging
|
||||
import threading
|
||||
import os
|
||||
from functools import cache
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import BaseModelParameters
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
|
||||
from dbgpt.model.adapter.template import (
|
||||
ConversationAdapter,
|
||||
ConversationAdapterFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
thread_local = threading.local()
|
||||
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
|
||||
|
||||
|
||||
_OLD_MODELS = [
|
||||
"llama-cpp",
|
||||
"proxyllm",
|
||||
"gptj-6b",
|
||||
"codellama-13b-sql-sft",
|
||||
"codellama-7b",
|
||||
"codellama-7b-sql-sft",
|
||||
"codellama-13b",
|
||||
]
|
||||
|
||||
|
||||
@cache
|
||||
def get_llm_model_adapter(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
use_fastchat: bool = True,
|
||||
use_fastchat_monkey_patch: bool = False,
|
||||
model_type: str = None,
|
||||
) -> LLMModelAdapter:
|
||||
conv_factory = DefaultConversationAdapterFactory()
|
||||
if model_type == ModelType.VLLM:
|
||||
logger.info("Current model type is vllm, return VLLMModelAdapterWrapper")
|
||||
from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
|
||||
|
||||
return VLLMModelAdapterWrapper(conv_factory)
|
||||
|
||||
# Import NewHFChatModelAdapter for it can be registered
|
||||
from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
|
||||
|
||||
new_model_adapter = get_model_adapter(
|
||||
model_type, model_name, model_path, conv_factory
|
||||
)
|
||||
if new_model_adapter:
|
||||
logger.info(f"Current model {model_name} use new adapter {new_model_adapter}")
|
||||
return new_model_adapter
|
||||
|
||||
must_use_old = any(m in model_name for m in _OLD_MODELS)
|
||||
result_adapter: Optional[LLMModelAdapter] = None
|
||||
if use_fastchat and not must_use_old:
|
||||
logger.info("Use fastcat adapter")
|
||||
from dbgpt.model.adapter.fschat_adapter import (
|
||||
_get_fastchat_model_adapter,
|
||||
_fastchat_get_adapter_monkey_patch,
|
||||
FastChatLLMModelAdapterWrapper,
|
||||
)
|
||||
|
||||
adapter = _get_fastchat_model_adapter(
|
||||
model_name,
|
||||
model_path,
|
||||
_fastchat_get_adapter_monkey_patch,
|
||||
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
|
||||
)
|
||||
if adapter:
|
||||
result_adapter = FastChatLLMModelAdapterWrapper(adapter)
|
||||
|
||||
else:
|
||||
from dbgpt.model.adapter.old_adapter import (
|
||||
get_llm_model_adapter as _old_get_llm_model_adapter,
|
||||
OldLLMModelAdapterWrapper,
|
||||
)
|
||||
from dbgpt.app.chat_adapter import get_llm_chat_adapter
|
||||
|
||||
logger.info("Use DB-GPT old adapter")
|
||||
result_adapter = OldLLMModelAdapterWrapper(
|
||||
_old_get_llm_model_adapter(model_name, model_path),
|
||||
get_llm_chat_adapter(model_name, model_path),
|
||||
)
|
||||
if result_adapter:
|
||||
result_adapter.model_name = model_name
|
||||
result_adapter.model_path = model_path
|
||||
result_adapter.conv_factory = conv_factory
|
||||
return result_adapter
|
||||
else:
|
||||
raise ValueError(f"Can not find adapter for model {model_name}")
|
||||
|
||||
|
||||
@cache
|
||||
def _auto_get_conv_template(
|
||||
model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
"""Auto get the conversation template.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
|
||||
Returns:
|
||||
Optional[ConversationAdapter]: The conversation template.
|
||||
"""
|
||||
try:
|
||||
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
|
||||
return adapter.get_default_conv_template(model_name, model_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get conv template for {model_name} {model_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class DefaultConversationAdapterFactory(ConversationAdapterFactory):
|
||||
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
|
||||
"""Get a conversation adapter by model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
Returns:
|
||||
ConversationAdapter: The conversation adapter.
|
||||
"""
|
||||
return _auto_get_conv_template(model_name, model_path)
|
||||
|
||||
|
||||
def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
|
||||
"""Dynamic model parser, parse the model parameters from the command line arguments.
|
||||
|
||||
Returns:
|
||||
Optional[List[Type[BaseModelParameters]]]: The model parameters class list.
|
||||
"""
|
||||
from dbgpt.util.parameter_utils import _SimpleArgParser
|
||||
from dbgpt.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
WorkerType,
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
)
|
||||
|
||||
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
|
||||
pre_args.parse()
|
||||
model_name = pre_args.get("model_name")
|
||||
model_path = pre_args.get("model_path")
|
||||
worker_type = pre_args.get("worker_type")
|
||||
model_type = pre_args.get("model_type")
|
||||
if model_name is None and model_type != ModelType.VLLM:
|
||||
return None
|
||||
if worker_type == WorkerType.TEXT2VEC:
|
||||
return [
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
model_name, EmbeddingModelParameters
|
||||
)
|
||||
]
|
||||
|
||||
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
|
||||
param_class = llm_adapter.model_param_class()
|
||||
return [param_class]
|
@@ -9,7 +9,7 @@ import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, TYPE_CHECKING, Optional
|
||||
from functools import cache
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
@@ -17,6 +17,9 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
|
||||
from dbgpt.model.base import ModelType
|
||||
|
||||
from dbgpt.model.parameter import (
|
||||
@@ -24,9 +27,13 @@ from dbgpt.model.parameter import (
|
||||
LlamaCppModelParameters,
|
||||
ProxyModelParameters,
|
||||
)
|
||||
from dbgpt.model.conversation import Conversation
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.app.chat_adapter import BaseChatAdpter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
@@ -92,17 +99,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
|
||||
)
|
||||
|
||||
|
||||
def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
|
||||
try:
|
||||
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
||||
return llm_adapter.model_param_class()
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
|
||||
)
|
||||
return ModelParameters
|
||||
|
||||
|
||||
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
|
||||
|
||||
|
||||
@@ -426,6 +422,87 @@ class InternLMAdapter(BaseLLMAdaper):
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class OldLLMModelAdapterWrapper(LLMModelAdapter):
|
||||
"""Wrapping old adapter, which may be removed later"""
|
||||
|
||||
def __init__(self, adapter: BaseLLMAdaper, chat_adapter: "BaseChatAdpter") -> None:
|
||||
self._adapter = adapter
|
||||
self._chat_adapter = chat_adapter
|
||||
|
||||
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
|
||||
return OldLLMModelAdapterWrapper(self._adapter, self._chat_adapter)
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
return self._adapter.use_fast_tokenizer()
|
||||
|
||||
def model_type(self) -> str:
|
||||
return self._adapter.model_type()
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> ModelParameters:
|
||||
return self._adapter.model_param_class(model_type)
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> Optional[ConversationAdapter]:
|
||||
conv_template = self._chat_adapter.get_conv_template(model_path)
|
||||
return OldConversationAdapter(conv_template) if conv_template else None
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
return self._adapter.loader(model_path, from_pretrained_kwargs)
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
return self._chat_adapter.get_generate_stream_func(model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}({}.{})".format(
|
||||
self.__class__.__name__,
|
||||
self._adapter.__class__.__module__,
|
||||
self._adapter.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class OldConversationAdapter(ConversationAdapter):
|
||||
"""Wrapping old Conversation, which may be removed later"""
|
||||
|
||||
def __init__(self, conv: Conversation) -> None:
|
||||
self._conv = conv
|
||||
|
||||
@property
|
||||
def prompt_type(self) -> PromptType:
|
||||
return PromptType.DBGPT
|
||||
|
||||
@property
|
||||
def roles(self) -> Tuple[str]:
|
||||
return self._conv.roles
|
||||
|
||||
@property
|
||||
def sep(self) -> Optional[str]:
|
||||
return self._conv.sep
|
||||
|
||||
@property
|
||||
def stop_str(self) -> str:
|
||||
return self._conv.stop_str
|
||||
|
||||
@property
|
||||
def stop_token_ids(self) -> Optional[List[int]]:
|
||||
return self._conv.stop_token_ids
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
return self._conv.get_prompt()
|
||||
|
||||
def set_system_message(self, system_message: str) -> None:
|
||||
self._conv.update_system_message(system_message)
|
||||
|
||||
def append_message(self, role: str, message: str) -> None:
|
||||
self._conv.append_message(role, message)
|
||||
|
||||
def update_last_message(self, message: str) -> None:
|
||||
self._conv.update_last_message(message)
|
||||
|
||||
def copy(self) -> "ConversationAdapter":
|
||||
return OldConversationAdapter(self._conv.copy())
|
||||
|
||||
|
||||
register_llm_model_adapters(VicunaLLMAdapater)
|
||||
register_llm_model_adapters(ChatGLMAdapater)
|
||||
register_llm_model_adapters(GuanacoAdapter)
|
130
dbgpt/model/adapter/template.py
Normal file
130
dbgpt/model/adapter/template.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastchat.conversation import Conversation
|
||||
|
||||
|
||||
class PromptType(str, Enum):
|
||||
"""Prompt type."""
|
||||
|
||||
FSCHAT: str = "fschat"
|
||||
DBGPT: str = "dbgpt"
|
||||
|
||||
|
||||
class ConversationAdapter(ABC):
|
||||
"""The conversation adapter."""
|
||||
|
||||
@property
|
||||
def prompt_type(self) -> PromptType:
|
||||
return PromptType.FSCHAT
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def roles(self) -> Tuple[str]:
|
||||
"""Get the roles of the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple[str]: The roles of the conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def sep(self) -> Optional[str]:
|
||||
"""Get the separator between messages."""
|
||||
return "\n"
|
||||
|
||||
@property
|
||||
def stop_str(self) -> Optional[Union[str, List[str]]]:
|
||||
"""Get the stop criteria."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def stop_token_ids(self) -> Optional[List[int]]:
|
||||
"""Stops generation if meeting any token in this list"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt string.
|
||||
|
||||
Returns:
|
||||
str: The prompt string.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_system_message(self, system_message: str) -> None:
|
||||
"""Set the system message."""
|
||||
|
||||
@abstractmethod
|
||||
def append_message(self, role: str, message: str) -> None:
|
||||
"""Append a new message.
|
||||
Args:
|
||||
role (str): The role of the message.
|
||||
message (str): The message content.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_last_message(self, message: str) -> None:
|
||||
"""Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
|
||||
Args:
|
||||
message (str): The message content.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self) -> "ConversationAdapter":
|
||||
"""Copy the conversation."""
|
||||
|
||||
|
||||
class ConversationAdapterFactory(ABC):
|
||||
"""The conversation adapter factory."""
|
||||
|
||||
def get_by_name(
|
||||
self,
|
||||
template_name: str,
|
||||
prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
|
||||
) -> ConversationAdapter:
|
||||
"""Get a conversation adapter by name.
|
||||
|
||||
Args:
|
||||
template_name (str): The name of the template.
|
||||
prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT.
|
||||
|
||||
Returns:
|
||||
ConversationAdapter: The conversation adapter.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
|
||||
"""Get a conversation adapter by model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
|
||||
Returns:
|
||||
ConversationAdapter: The conversation adapter.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> ConversationAdapter:
|
||||
"""Get a conversation template.
|
||||
|
||||
Args:
|
||||
name (str): The name of the template.
|
||||
|
||||
Just return the fastchat conversation template for now.
|
||||
# TODO: More templates should be supported.
|
||||
Returns:
|
||||
Conversation: The conversation template.
|
||||
"""
|
||||
from fastchat.conversation import get_conv_template
|
||||
from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
|
||||
|
||||
conv_template = get_conv_template(name)
|
||||
return FschatConversationAdapter(conv_template)
|
93
dbgpt/model/adapter/vllm_adapter.py
Normal file
93
dbgpt/model/adapter/vllm_adapter.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
|
||||
from dbgpt.model.parameter import BaseModelParameters
|
||||
from dbgpt.util.parameter_utils import (
|
||||
_extract_parameter_details,
|
||||
_build_parameter_class,
|
||||
_get_dataclass_print_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VLLMModelAdapterWrapper(LLMModelAdapter):
|
||||
"""Wrapping vllm engine"""
|
||||
|
||||
def __init__(self, conv_factory: ConversationAdapterFactory):
|
||||
self.conv_factory = conv_factory
|
||||
|
||||
def new_adapter(self, **kwargs) -> "VLLMModelAdapterWrapper":
|
||||
return VLLMModelAdapterWrapper(self.conv_factory)
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.VLLM
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> BaseModelParameters:
|
||||
import argparse
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--model_name", type=str, help="model name")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
help="local model path of the huggingface model to use",
|
||||
)
|
||||
parser.add_argument("--model_type", type=str, help="model type")
|
||||
parser.add_argument("--device", type=str, default=None, help="device")
|
||||
# TODO parse prompt templete from `model_name` and `model_path`
|
||||
parser.add_argument(
|
||||
"--prompt_template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template. If None, the prompt template is automatically determined from model path",
|
||||
)
|
||||
|
||||
descs = _extract_parameter_details(
|
||||
parser,
|
||||
"dbgpt.model.parameter.VLLMModelParameters",
|
||||
skip_names=["model"],
|
||||
overwrite_default_values={"trust_remote_code": True},
|
||||
)
|
||||
return _build_parameter_class(descs)
|
||||
|
||||
def load_from_params(self, params):
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
import torch
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
|
||||
setattr(params, "tensor_parallel_size", num_gpus)
|
||||
logger.info(
|
||||
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
|
||||
)
|
||||
|
||||
params = dataclasses.asdict(params)
|
||||
params["model"] = params["model_path"]
|
||||
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
|
||||
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
|
||||
# Set the attributes from the parsed arguments.
|
||||
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
return engine, engine.engine.tokenizer
|
||||
|
||||
def support_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_async_generate_stream_function(self, model, model_path: str):
|
||||
from dbgpt.model.llm_out.vllm_llm import generate_stream
|
||||
|
||||
return generate_stream
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> ConversationAdapter:
|
||||
return self.conv_factory.get_by_model(model_name, model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
|
@@ -405,7 +405,7 @@ def stop_model_controller(port: int):
|
||||
|
||||
|
||||
def _model_dynamic_factory() -> Callable[[None], List[Type]]:
|
||||
from dbgpt.model.model_adapter import _dynamic_model_parser
|
||||
from dbgpt.model.adapter.model_adapter import _dynamic_model_parser
|
||||
|
||||
param_class = _dynamic_model_parser()
|
||||
fix_class = [ModelWorkerParameters]
|
||||
|
@@ -6,7 +6,8 @@ import time
|
||||
import traceback
|
||||
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
|
||||
from dbgpt.core import ModelOutput, ModelInferenceMetrics
|
||||
from dbgpt.model.loader import ModelLoader, _get_model_real_path
|
||||
from dbgpt.model.parameter import ModelParameters
|
||||
@@ -27,7 +28,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self._model_params = None
|
||||
self.llm_adapter: LLMModelAdaper = None
|
||||
self.llm_adapter: LLMModelAdapter = None
|
||||
self._support_async = False
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
|
@@ -37,7 +37,7 @@ def list_supported_models():
|
||||
def _list_supported_models(
|
||||
worker_type: str, model_config: Dict[str, str]
|
||||
) -> List[SupportedModel]:
|
||||
from dbgpt.model.model_adapter import get_llm_model_adapter
|
||||
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
|
||||
from dbgpt.model.loader import _get_model_real_path
|
||||
|
||||
ret = []
|
||||
|
@@ -1,13 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Dict
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from dataclasses import asdict
|
||||
import logging
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
|
||||
from dbgpt.model.parameter import (
|
||||
ModelParameters,
|
||||
LlamaCppModelParameters,
|
||||
@@ -117,7 +117,7 @@ class ModelLoader:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
def loader_with_params(
|
||||
self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
|
||||
self, model_params: ModelParameters, llm_adapter: LLMModelAdapter
|
||||
):
|
||||
model_type = llm_adapter.model_type()
|
||||
self.prompt_template = model_params.prompt_template
|
||||
@@ -133,7 +133,7 @@ class ModelLoader:
|
||||
raise Exception(f"Unkown model type {model_type}")
|
||||
|
||||
|
||||
def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
|
||||
def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
|
||||
import torch
|
||||
from dbgpt.model.compression import compress_module
|
||||
|
||||
@@ -174,6 +174,12 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
|
||||
else:
|
||||
raise ValueError(f"Invalid device: {device}")
|
||||
|
||||
model, tokenizer = _try_load_default_quantization_model(
|
||||
llm_adapter, device, num_gpus, model_params, kwargs
|
||||
)
|
||||
if model:
|
||||
return model, tokenizer
|
||||
|
||||
can_quantization = _check_quantization(model_params)
|
||||
|
||||
if can_quantization and (num_gpus > 1 or model_params.load_4bit):
|
||||
@@ -192,6 +198,46 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
|
||||
# TODO merge current code into `load_huggingface_quantization_model`
|
||||
compress_module(model, model_params.device)
|
||||
|
||||
return _handle_model_and_tokenizer(model, tokenizer, device, num_gpus, model_params)
|
||||
|
||||
|
||||
def _try_load_default_quantization_model(
|
||||
llm_adapter: LLMModelAdapter,
|
||||
device: str,
|
||||
num_gpus: int,
|
||||
model_params: ModelParameters,
|
||||
kwargs: Dict[str, Any],
|
||||
):
|
||||
"""Try load default quantization model(Support by huggingface default)"""
|
||||
cloned_kwargs = {k: v for k, v in kwargs.items()}
|
||||
try:
|
||||
model, tokenizer = None, None
|
||||
if device != "cuda":
|
||||
return None, None
|
||||
elif model_params.load_8bit and llm_adapter.support_8bit:
|
||||
cloned_kwargs["load_in_8bit"] = True
|
||||
model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
|
||||
elif model_params.load_4bit and llm_adapter.support_4bit:
|
||||
cloned_kwargs["load_in_4bit"] = True
|
||||
model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
|
||||
if model:
|
||||
logger.info(
|
||||
f"Load default quantization model {model_params.model_name} success"
|
||||
)
|
||||
return _handle_model_and_tokenizer(
|
||||
model, tokenizer, device, num_gpus, model_params
|
||||
)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Load default quantization model {model_params.model_name} failed, error: {str(e)}"
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
def _handle_model_and_tokenizer(
|
||||
model, tokenizer, device: str, num_gpus: int, model_params: ModelParameters
|
||||
):
|
||||
if (
|
||||
(device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
|
||||
or device == "mps"
|
||||
@@ -209,7 +255,7 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
|
||||
|
||||
|
||||
def load_huggingface_quantization_model(
|
||||
llm_adapter: LLMModelAdaper,
|
||||
llm_adapter: LLMModelAdapter,
|
||||
model_params: ModelParameters,
|
||||
kwargs: Dict,
|
||||
max_memory: Dict[int, str],
|
||||
@@ -344,7 +390,9 @@ def load_huggingface_quantization_model(
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
|
||||
def llamacpp_loader(
|
||||
llm_adapter: LLMModelAdapter, model_params: LlamaCppModelParameters
|
||||
):
|
||||
try:
|
||||
from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel
|
||||
except ImportError as exc:
|
||||
@@ -358,7 +406,7 @@ def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelPara
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
|
||||
def proxyllm_loader(llm_adapter: LLMModelAdapter, model_params: ProxyModelParameters):
|
||||
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
|
||||
|
||||
logger.info("Load proxyllm")
|
||||
|
@@ -1,660 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
import os
|
||||
from functools import cache
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.model.parameter import (
|
||||
ModelParameters,
|
||||
LlamaCppModelParameters,
|
||||
ProxyModelParameters,
|
||||
)
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.util.parameter_utils import (
|
||||
_extract_parameter_details,
|
||||
_build_parameter_class,
|
||||
_get_dataclass_print_str,
|
||||
)
|
||||
|
||||
try:
|
||||
from fastchat.conversation import (
|
||||
Conversation,
|
||||
register_conv_template,
|
||||
SeparatorStyle,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import python package: fschat "
|
||||
"Please install fastchat by command `pip install fschat` "
|
||||
) from exc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastchat.model.model_adapter import BaseModelAdapter
|
||||
from dbgpt.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
|
||||
from torch.nn import Module as TorchNNModule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
thread_local = threading.local()
|
||||
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
|
||||
|
||||
|
||||
_OLD_MODELS = [
|
||||
"llama-cpp",
|
||||
"proxyllm",
|
||||
"gptj-6b",
|
||||
"codellama-13b-sql-sft",
|
||||
"codellama-7b",
|
||||
"codellama-7b-sql-sft",
|
||||
"codellama-13b",
|
||||
]
|
||||
|
||||
_NEW_HF_CHAT_MODELS = [
|
||||
"yi-34b",
|
||||
"yi-6b",
|
||||
]
|
||||
|
||||
# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
|
||||
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
|
||||
|
||||
|
||||
class LLMModelAdaper:
|
||||
"""New Adapter for DB-GPT LLM models"""
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
"""Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
|
||||
for a given model.
|
||||
"""
|
||||
return False
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.HF
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> ModelParameters:
|
||||
"""Get the startup parameters instance of the model"""
|
||||
model_type = model_type if model_type else self.model_type()
|
||||
if model_type == ModelType.LLAMA_CPP:
|
||||
return LlamaCppModelParameters
|
||||
elif model_type == ModelType.PROXY:
|
||||
return ProxyModelParameters
|
||||
return ModelParameters
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
"""Load model and tokenizer"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_from_params(self, params):
|
||||
"""Load the model and tokenizer according to the given parameters"""
|
||||
raise NotImplementedError
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether the loaded model supports asynchronous calls"""
|
||||
return False
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_async_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the asynchronous generate stream function of the model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> "Conversation":
|
||||
"""Get the default conv template"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def get_prompt_with_template(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
model_context: Dict,
|
||||
prompt_template: str = None,
|
||||
):
|
||||
conv = self.get_default_conv_template(model_name, model_path)
|
||||
|
||||
if prompt_template:
|
||||
logger.info(f"Use prompt template {prompt_template} from config")
|
||||
conv = get_conv_template(prompt_template)
|
||||
if not conv or not messages:
|
||||
# Nothing to do
|
||||
logger.info(
|
||||
f"No conv from model_path {model_path} or no messages in params, {self}"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
conv = conv.copy()
|
||||
system_messages = []
|
||||
user_messages = []
|
||||
ai_messages = []
|
||||
|
||||
for message in messages:
|
||||
role, content = None, None
|
||||
if isinstance(message, ModelMessage):
|
||||
role = message.role
|
||||
content = message.content
|
||||
elif isinstance(message, dict):
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
else:
|
||||
raise ValueError(f"Invalid message type: {message}")
|
||||
|
||||
if role == ModelMessageRoleType.SYSTEM:
|
||||
# Support for multiple system messages
|
||||
system_messages.append(content)
|
||||
elif role == ModelMessageRoleType.HUMAN:
|
||||
# conv.append_message(conv.roles[0], content)
|
||||
user_messages.append(content)
|
||||
elif role == ModelMessageRoleType.AI:
|
||||
# conv.append_message(conv.roles[1], content)
|
||||
ai_messages.append(content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {role}")
|
||||
|
||||
can_use_systems: [] = []
|
||||
if system_messages:
|
||||
if len(system_messages) > 1:
|
||||
## Compatible with dbgpt complex scenarios, the last system will protect more complete information entered by the current user
|
||||
user_messages[-1] = system_messages[-1]
|
||||
can_use_systems = system_messages[:-1]
|
||||
else:
|
||||
can_use_systems = system_messages
|
||||
|
||||
for i in range(len(user_messages)):
|
||||
conv.append_message(conv.roles[0], user_messages[i])
|
||||
if i < len(ai_messages):
|
||||
conv.append_message(conv.roles[1], ai_messages[i])
|
||||
|
||||
if isinstance(conv, Conversation):
|
||||
conv.set_system_message("".join(can_use_systems))
|
||||
else:
|
||||
conv.update_system_message("".join(can_use_systems))
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
new_prompt = conv.get_prompt()
|
||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||
|
||||
def model_adaptation(
|
||||
self,
|
||||
params: Dict,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Params adaptation"""
|
||||
messages = params.get("messages")
|
||||
# Some model scontext to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
if messages:
|
||||
# Dict message to ModelMessage
|
||||
messages = [
|
||||
m if isinstance(m, ModelMessage) else ModelMessage(**m)
|
||||
for m in messages
|
||||
]
|
||||
params["messages"] = messages
|
||||
|
||||
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
|
||||
conv_stop_str, conv_stop_token_ids = None, None
|
||||
if not new_prompt:
|
||||
(
|
||||
new_prompt,
|
||||
conv_stop_str,
|
||||
conv_stop_token_ids,
|
||||
) = self.get_prompt_with_template(
|
||||
params, messages, model_name, model_path, model_context, prompt_template
|
||||
)
|
||||
if not new_prompt:
|
||||
return params, model_context
|
||||
|
||||
# Overwrite the original prompt
|
||||
# TODO remote bos token and eos token from tokenizer_config.json of model
|
||||
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
|
||||
model_context["prompt_echo_len_char"] = prompt_echo_len_char
|
||||
model_context["echo"] = params.get("echo", True)
|
||||
model_context["has_format_prompt"] = True
|
||||
params["prompt"] = new_prompt
|
||||
|
||||
custom_stop = params.get("stop")
|
||||
custom_stop_token_ids = params.get("stop_token_ids")
|
||||
|
||||
# Prefer the value passed in from the input parameter
|
||||
params["stop"] = custom_stop or conv_stop_str
|
||||
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
|
||||
|
||||
return params, model_context
|
||||
|
||||
|
||||
class OldLLMModelAdaperWrapper(LLMModelAdaper):
|
||||
"""Wrapping old adapter, which may be removed later"""
|
||||
|
||||
def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
|
||||
self._adapter = adapter
|
||||
self._chat_adapter = chat_adapter
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
return self._adapter.use_fast_tokenizer()
|
||||
|
||||
def model_type(self) -> str:
|
||||
return self._adapter.model_type()
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> ModelParameters:
|
||||
return self._adapter.model_param_class(model_type)
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> "Conversation":
|
||||
return self._chat_adapter.get_conv_template(model_path)
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
return self._adapter.loader(model_path, from_pretrained_kwargs)
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
return self._chat_adapter.get_generate_stream_func(model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}({}.{})".format(
|
||||
self.__class__.__name__,
|
||||
self._adapter.__class__.__module__,
|
||||
self._adapter.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
|
||||
"""Wrapping fastchat adapter"""
|
||||
|
||||
def __init__(self, adapter: "BaseModelAdapter") -> None:
|
||||
self._adapter = adapter
|
||||
|
||||
def use_fast_tokenizer(self) -> bool:
|
||||
return self._adapter.use_fast_tokenizer
|
||||
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
return self._adapter.load_model(model_path, from_pretrained_kwargs)
|
||||
|
||||
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
|
||||
if _IS_BENCHMARK:
|
||||
from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
|
||||
generate_stream,
|
||||
)
|
||||
|
||||
return generate_stream
|
||||
else:
|
||||
from fastchat.model.model_adapter import get_generate_stream_function
|
||||
|
||||
return get_generate_stream_function(model, model_path)
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> "Conversation":
|
||||
return self._adapter.get_default_conv_template(model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}({}.{})".format(
|
||||
self.__class__.__name__,
|
||||
self._adapter.__class__.__module__,
|
||||
self._adapter.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class NewHFChatModelAdapter(LLMModelAdaper):
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
try:
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import depend python package "
|
||||
"Please install it with `pip install transformers`."
|
||||
) from exc
|
||||
if not transformers.__version__ >= "4.34.0":
|
||||
raise ValueError(
|
||||
"Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
|
||||
)
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast_tokenizer,
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except TypeError:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, use_fast=False, revision=revision, trust_remote_code=True
|
||||
)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
except NameError:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
# tokenizer.use_default_system_prompt = False
|
||||
return model, tokenizer
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
|
||||
|
||||
return huggingface_chat_generate_stream
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if not tokenizer:
|
||||
raise ValueError("tokenizer is is None")
|
||||
tokenizer: AutoTokenizer = tokenizer
|
||||
|
||||
messages = ModelMessage.to_openai_messages(messages)
|
||||
str_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return str_prompt
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> "Conversation":
|
||||
"""Get a conversation template."""
|
||||
from fastchat.conversation import get_conv_template
|
||||
|
||||
return get_conv_template(name)
|
||||
|
||||
|
||||
@cache
|
||||
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
|
||||
try:
|
||||
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
|
||||
return adapter.get_default_conv_template(model_name, model_path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def get_llm_model_adapter(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
use_fastchat: bool = True,
|
||||
use_fastchat_monkey_patch: bool = False,
|
||||
model_type: str = None,
|
||||
) -> LLMModelAdaper:
|
||||
if model_type == ModelType.VLLM:
|
||||
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
|
||||
return VLLMModelAdaperWrapper()
|
||||
|
||||
use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
|
||||
if use_new_hf_chat_models:
|
||||
logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
|
||||
return NewHFChatModelAdapter()
|
||||
|
||||
must_use_old = any(m in model_name for m in _OLD_MODELS)
|
||||
if use_fastchat and not must_use_old:
|
||||
logger.info("Use fastcat adapter")
|
||||
adapter = _get_fastchat_model_adapter(
|
||||
model_name,
|
||||
model_path,
|
||||
_fastchat_get_adapter_monkey_patch,
|
||||
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
|
||||
)
|
||||
return FastChatLLMModelAdaperWrapper(adapter)
|
||||
else:
|
||||
from dbgpt.model.adapter import (
|
||||
get_llm_model_adapter as _old_get_llm_model_adapter,
|
||||
)
|
||||
from dbgpt.app.chat_adapter import get_llm_chat_adapter
|
||||
|
||||
logger.info("Use DB-GPT old adapter")
|
||||
return OldLLMModelAdaperWrapper(
|
||||
_old_get_llm_model_adapter(model_name, model_path),
|
||||
get_llm_chat_adapter(model_name, model_path),
|
||||
)
|
||||
|
||||
|
||||
def _get_fastchat_model_adapter(
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
caller: Callable[[str], None] = None,
|
||||
use_fastchat_monkey_patch: bool = False,
|
||||
):
|
||||
from fastchat.model import model_adapter
|
||||
|
||||
_bak_get_model_adapter = model_adapter.get_model_adapter
|
||||
try:
|
||||
if use_fastchat_monkey_patch:
|
||||
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
|
||||
thread_local.model_name = model_name
|
||||
_remove_black_list_model_of_fastchat()
|
||||
if caller:
|
||||
return caller(model_path)
|
||||
finally:
|
||||
del thread_local.model_name
|
||||
model_adapter.get_model_adapter = _bak_get_model_adapter
|
||||
|
||||
|
||||
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
|
||||
if not model_name:
|
||||
if not hasattr(thread_local, "model_name"):
|
||||
raise RuntimeError("fastchat get adapter monkey path need model_name")
|
||||
model_name = thread_local.model_name
|
||||
from fastchat.model.model_adapter import model_adapters
|
||||
|
||||
for adapter in model_adapters:
|
||||
if adapter.match(model_name):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model name: {model_name}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
model_path_basename = (
|
||||
None if not model_path else os.path.basename(os.path.normpath(model_path))
|
||||
)
|
||||
for adapter in model_adapters:
|
||||
if model_path_basename and adapter.match(model_path_basename):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
for adapter in model_adapters:
|
||||
if model_path and adapter.match(model_path):
|
||||
logger.info(
|
||||
f"Found llm model adapter with model path: {model_path}, {adapter}"
|
||||
)
|
||||
return adapter
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid model adapter for model name {model_name} and model path {model_path}"
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _remove_black_list_model_of_fastchat():
|
||||
from fastchat.model.model_adapter import model_adapters
|
||||
|
||||
black_list_models = []
|
||||
for adapter in model_adapters:
|
||||
try:
|
||||
if (
|
||||
adapter.get_default_conv_template("/data/not_exist_model_path").name
|
||||
in _BLACK_LIST_MODLE_PROMPT
|
||||
):
|
||||
black_list_models.append(adapter)
|
||||
except Exception:
|
||||
pass
|
||||
for adapter in black_list_models:
|
||||
model_adapters.remove(adapter)
|
||||
|
||||
|
||||
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
|
||||
from dbgpt.util.parameter_utils import _SimpleArgParser
|
||||
from dbgpt.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
WorkerType,
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
)
|
||||
|
||||
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
|
||||
pre_args.parse()
|
||||
model_name = pre_args.get("model_name")
|
||||
model_path = pre_args.get("model_path")
|
||||
worker_type = pre_args.get("worker_type")
|
||||
model_type = pre_args.get("model_type")
|
||||
if model_name is None and model_type != ModelType.VLLM:
|
||||
return None
|
||||
if worker_type == WorkerType.TEXT2VEC:
|
||||
return [
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
model_name, EmbeddingModelParameters
|
||||
)
|
||||
]
|
||||
|
||||
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
|
||||
param_class = llm_adapter.model_param_class()
|
||||
return [param_class]
|
||||
|
||||
|
||||
class VLLMModelAdaperWrapper(LLMModelAdaper):
|
||||
"""Wrapping vllm engine"""
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.VLLM
|
||||
|
||||
def model_param_class(self, model_type: str = None) -> ModelParameters:
|
||||
import argparse
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--model_name", type=str, help="model name")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
help="local model path of the huggingface model to use",
|
||||
)
|
||||
parser.add_argument("--model_type", type=str, help="model type")
|
||||
parser.add_argument("--device", type=str, default=None, help="device")
|
||||
# TODO parse prompt templete from `model_name` and `model_path`
|
||||
parser.add_argument(
|
||||
"--prompt_template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prompt template. If None, the prompt template is automatically determined from model path",
|
||||
)
|
||||
|
||||
descs = _extract_parameter_details(
|
||||
parser,
|
||||
"dbgpt.model.parameter.VLLMModelParameters",
|
||||
skip_names=["model"],
|
||||
overwrite_default_values={"trust_remote_code": True},
|
||||
)
|
||||
return _build_parameter_class(descs)
|
||||
|
||||
def load_from_params(self, params):
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
import torch
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
|
||||
setattr(params, "tensor_parallel_size", num_gpus)
|
||||
logger.info(
|
||||
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
|
||||
)
|
||||
|
||||
params = dataclasses.asdict(params)
|
||||
params["model"] = params["model_path"]
|
||||
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
|
||||
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
|
||||
# Set the attributes from the parsed arguments.
|
||||
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
return engine, engine.engine.tokenizer
|
||||
|
||||
def support_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_async_generate_stream_function(self, model, model_path: str):
|
||||
from dbgpt.model.llm_out.vllm_llm import generate_stream
|
||||
|
||||
return generate_stream
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> "Conversation":
|
||||
return _auto_get_conv_template(model_name, model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
|
||||
|
||||
|
||||
# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
|
||||
# We also recommend that you modify it directly in the fastchat repository.
|
||||
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila-legacy",
|
||||
system_message="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||
roles=("### Human: ", "### Assistant: ", "System"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.NO_COLON_TWO,
|
||||
sep="\n",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "[UNK]"],
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila",
|
||||
system_message="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
roles=("Human", "Assistant", "System"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
||||
sep="###",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "[UNK]"],
|
||||
),
|
||||
override=True,
|
||||
)
|
||||
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="aquila-v1",
|
||||
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.NO_COLON_TWO,
|
||||
sep="",
|
||||
sep2="</s>",
|
||||
stop_str=["</s>", "<|endoftext|>"],
|
||||
),
|
||||
override=True,
|
||||
)
|
Reference in New Issue
Block a user