Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-05 19:11:52 +00:00)
feat(model): Support Mixtral-8x7B (#959)
@@ -103,7 +103,8 @@ At present, we have introduced several key features to showcase our current capabilities:
 We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
 
 - News
-- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
 - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 - [More Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf)
 
@@ -111,7 +111,8 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
 Extensive model support: dozens of large language models from open source and API proxies, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and more. Currently supported models:
 
 - Newly supported models
-- 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+- 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
 - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 - [More open-source models](https://www.yuque.com/eosphoros/dbgpt-docs/iqaaqwriwhp6zslc#qQktR)
 
@@ -245,7 +245,7 @@ class WizardLMChatAdapter(BaseChatAdpter):
 
 
 class LlamaCppChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        from dbgpt.model.adapter import LlamaCppAdapater
+        from dbgpt.model.adapter.old_adapter import LlamaCppAdapater
 
         if "llama-cpp" == model_path:
             return True
@@ -113,7 +113,9 @@ LLM_MODEL_CONFIG = {
     # https://huggingface.co/microsoft/Orca-2-13b
     "orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
     # https://huggingface.co/openchat/openchat_3.5
-    "openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+    "openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+    # https://huggingface.co/openchat/openchat-3.5-1210
+    "openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"),
     # https://huggingface.co/hfl/chinese-alpaca-2-7b
     "chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
     # https://huggingface.co/hfl/chinese-alpaca-2-13b
@@ -124,6 +126,10 @@ LLM_MODEL_CONFIG = {
     "zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
     # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
     "mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
+    # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+    "mixtral-8x7b-instruct-v0.1": os.path.join(
+        MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1"
+    ),
     # https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
     "mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
     # https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
dbgpt/model/adapter/__init__.py (new file, 0 lines)

dbgpt/model/adapter/base.py (new file, 437 lines)
@@ -0,0 +1,437 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Tuple, Type, Callable
import logging
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
    BaseModelParameters,
    ModelParameters,
    LlamaCppModelParameters,
    ProxyModelParameters,
)
from dbgpt.model.adapter.template import (
    get_conv_template,
    ConversationAdapter,
    ConversationAdapterFactory,
)

logger = logging.getLogger(__name__)


class LLMModelAdapter(ABC):
    """New adapter for DB-GPT LLM models"""

    model_name: Optional[str] = None
    model_path: Optional[str] = None
    conv_factory: Optional[ConversationAdapterFactory] = None
    # TODO: more flexible quantization config
    support_4bit: bool = False
    support_8bit: bool = False
    support_system_message: bool = True

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} model_name={self.model_name} model_path={self.model_path}>"

    def __str__(self):
        return self.__repr__()

    @abstractmethod
    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
        """Create a new adapter instance

        Args:
            **kwargs: The parameters of the new adapter instance

        Returns:
            LLMModelAdapter: The new adapter instance
        """

    def use_fast_tokenizer(self) -> bool:
        """Whether to use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index)
        if it is supported for a given model.
        """
        return False

    def model_type(self) -> str:
        return ModelType.HF

    def model_param_class(self, model_type: str = None) -> Type[BaseModelParameters]:
        """Get the startup parameters class of the model

        Args:
            model_type (str, optional): The type of model. Defaults to None.

        Returns:
            Type[BaseModelParameters]: The startup parameters class of the model
        """
        model_type = model_type if model_type else self.model_type()
        if model_type == ModelType.LLAMA_CPP:
            return LlamaCppModelParameters
        elif model_type == ModelType.PROXY:
            return ProxyModelParameters
        return ModelParameters

    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        """Whether the model adapter can load the given model

        Args:
            model_type (str): The type of model
            model_name (Optional[str], optional): The name of model. Defaults to None.
            model_path (Optional[str], optional): The path of model. Defaults to None.
        """
        return False

    def support_quantization_4bit(self) -> bool:
        """Whether the model adapter can load a 4-bit model

        If it is True, we will load the 4-bit model with :meth:`~LLMModelAdapter.load`

        Returns:
            bool: Whether the model adapter can load a 4-bit model, default is False
        """
        return self.support_4bit

    def support_quantization_8bit(self) -> bool:
        """Whether the model adapter can load an 8-bit model

        If it is True, we will load the 8-bit model with :meth:`~LLMModelAdapter.load`

        Returns:
            bool: Whether the model adapter can load an 8-bit model, default is False
        """
        return self.support_8bit

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load model and tokenizer"""
        raise NotImplementedError

    def load_from_params(self, params):
        """Load the model and tokenizer according to the given parameters"""
        raise NotImplementedError

    def support_async(self) -> bool:
        """Whether the loaded model supports asynchronous calls"""
        return False

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        raise NotImplementedError

    def get_async_generate_stream_function(self, model, model_path: str):
        """Get the asynchronous generate stream function of the model"""
        raise NotImplementedError

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        """Get the default conversation template

        Args:
            model_name (str): The name of the model.
            model_path (str): The path of the model.

        Returns:
            Optional[ConversationAdapter]: The conversation template.
        """
        raise NotImplementedError

    def get_default_message_separator(self) -> str:
        """Get the default message separator"""
        try:
            conv_template = self.get_default_conv_template(
                self.model_name, self.model_path
            )
            return conv_template.sep
        except Exception:
            return "\n"

    def transform_model_messages(
        self, messages: List[ModelMessage]
    ) -> List[Dict[str, str]]:
        """Transform the model messages

        Default is the OpenAI format, example:
        .. code-block:: python
            return_messages = [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]

        But some models may need the messages transformed to another format
        (e.g. without a system message), such as:
        .. code-block:: python
            return_messages = [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]

        Args:
            messages (List[ModelMessage]): The model messages

        Returns:
            List[Dict[str, str]]: The transformed model messages
        """
        logger.info(f"support_system_message: {self.support_system_message}")
        if not self.support_system_message:
            return self._transform_to_no_system_messages(messages)
        else:
            return ModelMessage.to_openai_messages(messages)

    def _transform_to_no_system_messages(
        self, messages: List[ModelMessage]
    ) -> List[Dict[str, str]]:
        """Transform the model messages to messages without a system role

        Some open-source chat models have no system message, so we should merge the
        system messages into the last user message, example:
        .. code-block:: python
            return_messages = [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]
            # =>
            return_messages = [
                {"role": "user", "content": "You are a helpful assistant\nHello"},
                {"role": "assistant", "content": "Hi"},
            ]

        Args:
            messages (List[ModelMessage]): The model messages

        Returns:
            List[Dict[str, str]]: The transformed model messages
        """
        openai_messages = ModelMessage.to_openai_messages(messages)
        system_messages = []
        return_messages = []
        for message in openai_messages:
            if message["role"] == "system":
                system_messages.append(message["content"])
            else:
                return_messages.append(message)
        if len(system_messages) > 1:
            # More than one system message should be a warning
            logger.warning("Your system messages have more than one message")
        if system_messages:
            sep = self.get_default_message_separator()
            str_system_messages = ",".join(system_messages)
            # Update the last user message
            return_messages[-1]["content"] = (
                str_system_messages + sep + return_messages[-1]["content"]
            )
        return return_messages

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Optional[str]:
        """Get the string prompt from the given parameters and messages

        If the return value is not None, we will skip
        :meth:`~LLMModelAdapter.get_prompt_with_template` and use the return value.

        Args:
            params (Dict): The parameters
            messages (List[ModelMessage]): The model messages
            tokenizer (Any): The tokenizer of the model; for huggingface chat models we
                can create the prompt with the tokenizer
            prompt_template (str, optional): The prompt template. Defaults to None.

        Returns:
            Optional[str]: The string prompt
        """
        return None

    def get_prompt_with_template(
        self,
        params: Dict,
        messages: List[ModelMessage],
        model_name: str,
        model_path: str,
        model_context: Dict,
        prompt_template: str = None,
    ):
        conv: ConversationAdapter = self.get_default_conv_template(
            model_name, model_path
        )

        if prompt_template:
            logger.info(f"Use prompt template {prompt_template} from config")
            conv = get_conv_template(prompt_template)
        if not conv or not messages:
            # Nothing to do
            logger.info(
                f"No conv from model_path {model_path} or no messages in params, {self}"
            )
            return None, None, None

        conv = conv.copy()
        system_messages = []
        user_messages = []
        ai_messages = []

        for message in messages:
            if isinstance(message, ModelMessage):
                role = message.role
                content = message.content
            elif isinstance(message, dict):
                role = message["role"]
                content = message["content"]
            else:
                raise ValueError(f"Invalid message type: {message}")

            if role == ModelMessageRoleType.SYSTEM:
                # Support for multiple system messages
                system_messages.append(content)
            elif role == ModelMessageRoleType.HUMAN:
                # conv.append_message(conv.roles[0], content)
                user_messages.append(content)
            elif role == ModelMessageRoleType.AI:
                # conv.append_message(conv.roles[1], content)
                ai_messages.append(content)
            else:
                raise ValueError(f"Unknown role: {role}")

        can_use_systems: [] = []
        if system_messages:
            if len(system_messages) > 1:
                # Compatible with complex dbgpt scenarios: the last system message
                # carries the more complete information entered by the current user
                user_messages[-1] = system_messages[-1]
                can_use_systems = system_messages[:-1]
            else:
                can_use_systems = system_messages

        for i in range(len(user_messages)):
            conv.append_message(conv.roles[0], user_messages[i])
            if i < len(ai_messages):
                conv.append_message(conv.roles[1], ai_messages[i])

        # TODO joining all system messages may not be a good idea
        conv.set_system_message("".join(can_use_systems))
        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        new_prompt = conv.get_prompt()
        return new_prompt, conv.stop_str, conv.stop_token_ids

    def model_adaptation(
        self,
        params: Dict,
        model_name: str,
        model_path: str,
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Tuple[Dict, Dict]:
        """Params adaptation"""
        messages = params.get("messages")
        # Some model context for the dbgpt server
        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
        if messages:
            # Dict messages to ModelMessage
            messages = [
                m if isinstance(m, ModelMessage) else ModelMessage(**m)
                for m in messages
            ]
            params["messages"] = messages

        new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
        conv_stop_str, conv_stop_token_ids = None, None
        if not new_prompt:
            (
                new_prompt,
                conv_stop_str,
                conv_stop_token_ids,
            ) = self.get_prompt_with_template(
                params, messages, model_name, model_path, model_context, prompt_template
            )
        if not new_prompt:
            return params, model_context

        # Overwrite the original prompt
        # TODO remove bos token and eos token from tokenizer_config.json of model
        prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
        model_context["prompt_echo_len_char"] = prompt_echo_len_char
        model_context["echo"] = params.get("echo", True)
        model_context["has_format_prompt"] = True
        params["prompt"] = new_prompt

        custom_stop = params.get("stop")
        custom_stop_token_ids = params.get("stop_token_ids")

        # Prefer the values passed in from the input parameters
        params["stop"] = custom_stop or conv_stop_str
        params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids

        return params, model_context


class AdapterEntry:
    """The entry of a model adapter"""

    def __init__(
        self,
        model_adapter: LLMModelAdapter,
        match_funcs: List[Callable[[str, str, str], bool]] = None,
    ):
        self.model_adapter = model_adapter
        self.match_funcs = match_funcs or []


model_adapters: List[AdapterEntry] = []


def register_model_adapter(
    model_adapter_cls: Type[LLMModelAdapter],
    match_funcs: List[Callable[[str, str, str], bool]] = None,
) -> None:
    """Register a model adapter.

    Args:
        model_adapter_cls (Type[LLMModelAdapter]): The model adapter class.
        match_funcs (List[Callable[[str, str, str], bool]], optional): The match functions. Defaults to None.
    """
    model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs))


def get_model_adapter(
    model_type: str,
    model_name: str,
    model_path: str,
    conv_factory: Optional[ConversationAdapterFactory] = None,
) -> Optional[LLMModelAdapter]:
    """Get a model adapter.

    Args:
        model_type (str): The type of the model.
        model_name (str): The name of the model.
        model_path (str): The path of the model.
        conv_factory (Optional[ConversationAdapterFactory], optional): The conversation factory. Defaults to None.

    Returns:
        Optional[LLMModelAdapter]: The model adapter.
    """
    adapter = None
    # First find the adapter by model_name
    for adapter_entry in model_adapters:
        if adapter_entry.model_adapter.match(model_type, model_name, None):
            adapter = adapter_entry.model_adapter
            break
    for adapter_entry in model_adapters:
        if adapter_entry.model_adapter.match(model_type, None, model_path):
            adapter = adapter_entry.model_adapter
            break
    if adapter:
        new_adapter = adapter.new_adapter()
        new_adapter.model_name = model_name
        new_adapter.model_path = model_path
        if conv_factory:
            new_adapter.conv_factory = conv_factory
        return new_adapter
    return None
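Editor's note: the message transformation in LLMModelAdapter above is what makes models without a system role (such as Mixtral-8x7B-Instruct) usable. A minimal, self-contained sketch of the same merge behaviour, using plain OpenAI-style dicts instead of dbgpt's ModelMessage (the function name and separator default are illustrative, not part of this commit):

    from typing import Dict, List


    def merge_system_into_last_user(
        messages: List[Dict[str, str]], sep: str = "\n"
    ) -> List[Dict[str, str]]:
        """Mimic _transform_to_no_system_messages: drop system entries and
        prepend their content to the last remaining message."""
        systems = [m["content"] for m in messages if m["role"] == "system"]
        others = [dict(m) for m in messages if m["role"] != "system"]
        if systems and others:
            others[-1]["content"] = ",".join(systems) + sep + others[-1]["content"]
        return others


    if __name__ == "__main__":
        msgs = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"},
        ]
        print(merge_system_into_last_user(msgs))
        # [{'role': 'user', 'content': 'You are a helpful assistant\nHello'}]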
dbgpt/model/adapter/fschat_adapter.py (new file, 262 lines)
@@ -0,0 +1,262 @@
"""Adapter for fastchat

You can import fastchat only in this file, so that users do not need to
install fastchat if they do not use it.
"""
import os
import threading
import logging
from functools import cache
from typing import TYPE_CHECKING, Callable, Tuple, List, Optional

try:
    from fastchat.conversation import (
        Conversation,
        register_conv_template,
        SeparatorStyle,
    )
except ImportError as exc:
    raise ValueError(
        "Could not import python package: fschat "
        "Please install fastchat with the command `pip install fschat` "
    ) from exc

from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.adapter.base import LLMModelAdapter

if TYPE_CHECKING:
    from fastchat.model.model_adapter import BaseModelAdapter
    from torch.nn import Module as TorchNNModule

logger = logging.getLogger(__name__)

thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"

# If a model is not yet in the blacklist but still affects the loading of DB-GPT,
# you can add it here.
__BLACK_LIST_MODEL_PROMPT = []


class FschatConversationAdapter(ConversationAdapter):
    """The conversation adapter for fschat."""

    def __init__(self, conv: Conversation):
        self._conv = conv

    @property
    def prompt_type(self) -> PromptType:
        return PromptType.FSCHAT

    @property
    def roles(self) -> Tuple[str]:
        return self._conv.roles

    @property
    def sep(self) -> Optional[str]:
        return self._conv.sep

    @property
    def stop_str(self) -> str:
        return self._conv.stop_str

    @property
    def stop_token_ids(self) -> Optional[List[int]]:
        return self._conv.stop_token_ids

    def get_prompt(self) -> str:
        """Get the prompt string."""
        return self._conv.get_prompt()

    def set_system_message(self, system_message: str) -> None:
        """Set the system message."""
        self._conv.set_system_message(system_message)

    def append_message(self, role: str, message: str) -> None:
        """Append a new message.

        Args:
            role (str): The role of the message.
            message (str): The message content.
        """
        self._conv.append_message(role, message)

    def update_last_message(self, message: str) -> None:
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.

        Args:
            message (str): The message content.
        """
        self._conv.update_last_message(message)

    def copy(self) -> "ConversationAdapter":
        """Copy the conversation."""
        return FschatConversationAdapter(self._conv.copy())


class FastChatLLMModelAdapterWrapper(LLMModelAdapter):
    """Wrapping fastchat adapter"""

    def __init__(self, adapter: "BaseModelAdapter") -> None:
        self._adapter = adapter

    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
        return FastChatLLMModelAdapterWrapper(self._adapter)

    def use_fast_tokenizer(self) -> bool:
        return self._adapter.use_fast_tokenizer

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        return self._adapter.load_model(model_path, from_pretrained_kwargs)

    def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
        if _IS_BENCHMARK:
            from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
                generate_stream,
            )

            return generate_stream
        else:
            from fastchat.model.model_adapter import get_generate_stream_function

            return get_generate_stream_function(model, model_path)

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        conv_template = self._adapter.get_default_conv_template(model_path)
        return FschatConversationAdapter(conv_template) if conv_template else None

    def __str__(self) -> str:
        return "{}({}.{})".format(
            self.__class__.__name__,
            self._adapter.__class__.__module__,
            self._adapter.__class__.__name__,
        )


def _get_fastchat_model_adapter(
    model_name: str,
    model_path: str,
    caller: Callable[[str], None] = None,
    use_fastchat_monkey_patch: bool = False,
):
    from fastchat.model import model_adapter

    _bak_get_model_adapter = model_adapter.get_model_adapter
    try:
        if use_fastchat_monkey_patch:
            model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
        thread_local.model_name = model_name
        _remove_black_list_model_of_fastchat()
        if caller:
            return caller(model_path)
    finally:
        del thread_local.model_name
        model_adapter.get_model_adapter = _bak_get_model_adapter


def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
    if not model_name:
        if not hasattr(thread_local, "model_name"):
            raise RuntimeError("fastchat get adapter monkey patch needs model_name")
        model_name = thread_local.model_name
    from fastchat.model.model_adapter import model_adapters

    for adapter in model_adapters:
        if adapter.match(model_name):
            logger.info(
                f"Found llm model adapter with model name: {model_name}, {adapter}"
            )
            return adapter

    model_path_basename = (
        None if not model_path else os.path.basename(os.path.normpath(model_path))
    )
    for adapter in model_adapters:
        if model_path_basename and adapter.match(model_path_basename):
            logger.info(
                f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
            )
            return adapter

    for adapter in model_adapters:
        if model_path and adapter.match(model_path):
            logger.info(
                f"Found llm model adapter with model path: {model_path}, {adapter}"
            )
            return adapter

    raise ValueError(
        f"Invalid model adapter for model name {model_name} and model path {model_path}"
    )


@cache
def _remove_black_list_model_of_fastchat():
    from fastchat.model.model_adapter import model_adapters

    black_list_models = []
    for adapter in model_adapters:
        try:
            if (
                adapter.get_default_conv_template("/data/not_exist_model_path").name
                in __BLACK_LIST_MODEL_PROMPT
            ):
                black_list_models.append(adapter)
        except Exception:
            pass
    for adapter in black_list_models:
        model_adapters.remove(adapter)


# Overriding the configuration of fastchat; we will regularly feed the code here back to fastchat.
# We also recommend that you modify it directly in the fastchat repository.

# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
    Conversation(
        name="aquila-legacy",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("### Human: ", "### Assistant: ", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="\n",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
    Conversation(
        name="aquila",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="###",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
    Conversation(
        name="aquila-v1",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>"],
    ),
    override=True,
)
dbgpt/model/adapter/hf_adapter.py (new file, 136 lines)
@@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Any
import logging

from dbgpt.core import ModelMessage
from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter

logger = logging.getLogger(__name__)


class NewHFChatModelAdapter(LLMModelAdapter, ABC):
    """Model adapter for new huggingface chat models

    See https://huggingface.co/docs/transformers/main/en/chat_templating

    We can pass the inference chat messages to the chat model directly instead of
    creating a prompt template for this model
    """

    def new_adapter(self, **kwargs) -> "NewHFChatModelAdapter":
        return self.__class__()

    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        if model_type != ModelType.HF:
            return False
        if model_name is None and model_path is None:
            return False
        model_name = model_name.lower() if model_name else None
        model_path = model_path.lower() if model_path else None
        return self.do_match(model_name) or self.do_match(model_path)

    @abstractmethod
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        raise NotImplementedError()

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        try:
            import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
        except ImportError as exc:
            raise ValueError(
                "Could not import the required python package. "
                "Please install it with `pip install transformers`."
            ) from exc
        if not transformers.__version__ >= "4.34.0":
            raise ValueError(
                "The current model (loaded by NewHFChatModelAdapter) requires transformers.__version__ >= 4.34.0"
            )
        revision = from_pretrained_kwargs.get("revision", "main")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=self.use_fast_tokenizer(),
                revision=revision,
                trust_remote_code=True,
            )
        except TypeError:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, use_fast=False, revision=revision, trust_remote_code=True
            )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        except NameError:
            model = AutoModel.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        # tokenizer.use_default_system_prompt = False
        return model, tokenizer

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream

        return huggingface_chat_generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Optional[str]:
        from transformers import AutoTokenizer

        if not tokenizer:
            raise ValueError("tokenizer is None")
        tokenizer: AutoTokenizer = tokenizer

        messages = self.transform_model_messages(messages)
        logger.debug(f"The messages after transform: \n{messages}")
        str_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return str_prompt


class YiAdapter(NewHFChatModelAdapter):
    support_4bit: bool = True
    support_8bit: bool = True
    support_system_message: bool = True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "yi-" in lower_model_name_or_path
            and "chat" in lower_model_name_or_path
        )


class Mixtral8x7BAdapter(NewHFChatModelAdapter):
    """
    https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
    """

    support_4bit: bool = True
    support_8bit: bool = True
    support_system_message: bool = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "mixtral" in lower_model_name_or_path
            and "8x7b" in lower_model_name_or_path
        )


register_model_adapter(YiAdapter)
register_model_adapter(Mixtral8x7BAdapter)
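Editor's note: Mixtral8x7BAdapter above is the whole integration surface for a chat-template model: subclass NewHFChatModelAdapter, implement do_match, and register it. A hedged sketch of wiring in another model the same way (the class name and "example-chat" match substring are invented for illustration and are not part of this commit):

    from typing import Optional

    from dbgpt.model.adapter.base import register_model_adapter
    from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter


    class ExampleChatAdapter(NewHFChatModelAdapter):
        """Hypothetical adapter for an 'example-chat' model."""

        support_4bit: bool = True
        support_8bit: bool = True
        # Set False if the model's chat template rejects a system role, as Mixtral does.
        support_system_message: bool = False

        def do_match(self, lower_model_name_or_path: Optional[str] = None):
            # Match on the lowercased model name or local path.
            return (
                lower_model_name_or_path
                and "example-chat" in lower_model_name_or_path
            )


    register_model_adapter(ExampleChatAdapter)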
dbgpt/model/adapter/model_adapter.py (new file, 166 lines)
@@ -0,0 +1,166 @@
from __future__ import annotations

from typing import (
    List,
    Type,
    Optional,
)
import logging
import threading
import os
from functools import cache
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import BaseModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
from dbgpt.model.adapter.template import (
    ConversationAdapter,
    ConversationAdapterFactory,
)

logger = logging.getLogger(__name__)

thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


_OLD_MODELS = [
    "llama-cpp",
    "proxyllm",
    "gptj-6b",
    "codellama-13b-sql-sft",
    "codellama-7b",
    "codellama-7b-sql-sft",
    "codellama-13b",
]


@cache
def get_llm_model_adapter(
    model_name: str,
    model_path: str,
    use_fastchat: bool = True,
    use_fastchat_monkey_patch: bool = False,
    model_type: str = None,
) -> LLMModelAdapter:
    conv_factory = DefaultConversationAdapterFactory()
    if model_type == ModelType.VLLM:
        logger.info("Current model type is vllm, return VLLMModelAdapterWrapper")
        from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper

        return VLLMModelAdapterWrapper(conv_factory)

    # Import NewHFChatModelAdapter so that its subclasses are registered
    from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter

    new_model_adapter = get_model_adapter(
        model_type, model_name, model_path, conv_factory
    )
    if new_model_adapter:
        logger.info(f"Current model {model_name} uses new adapter {new_model_adapter}")
        return new_model_adapter

    must_use_old = any(m in model_name for m in _OLD_MODELS)
    result_adapter: Optional[LLMModelAdapter] = None
    if use_fastchat and not must_use_old:
        logger.info("Use fastchat adapter")
        from dbgpt.model.adapter.fschat_adapter import (
            _get_fastchat_model_adapter,
            _fastchat_get_adapter_monkey_patch,
            FastChatLLMModelAdapterWrapper,
        )

        adapter = _get_fastchat_model_adapter(
            model_name,
            model_path,
            _fastchat_get_adapter_monkey_patch,
            use_fastchat_monkey_patch=use_fastchat_monkey_patch,
        )
        if adapter:
            result_adapter = FastChatLLMModelAdapterWrapper(adapter)

    else:
        from dbgpt.model.adapter.old_adapter import (
            get_llm_model_adapter as _old_get_llm_model_adapter,
            OldLLMModelAdapterWrapper,
        )
        from dbgpt.app.chat_adapter import get_llm_chat_adapter

        logger.info("Use DB-GPT old adapter")
        result_adapter = OldLLMModelAdapterWrapper(
            _old_get_llm_model_adapter(model_name, model_path),
            get_llm_chat_adapter(model_name, model_path),
        )
    if result_adapter:
        result_adapter.model_name = model_name
        result_adapter.model_path = model_path
        result_adapter.conv_factory = conv_factory
        return result_adapter
    else:
        raise ValueError(f"Can not find adapter for model {model_name}")


@cache
def _auto_get_conv_template(
    model_name: str, model_path: str
) -> Optional[ConversationAdapter]:
    """Automatically get the conversation template.

    Args:
        model_name (str): The name of the model.
        model_path (str): The path of the model.

    Returns:
        Optional[ConversationAdapter]: The conversation template.
    """
    try:
        adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
        return adapter.get_default_conv_template(model_name, model_path)
    except Exception as e:
        logger.debug(f"Failed to get conv template for {model_name} {model_path}: {e}")
        return None


class DefaultConversationAdapterFactory(ConversationAdapterFactory):
    def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
        """Get a conversation adapter by model.

        Args:
            model_name (str): The name of the model.
            model_path (str): The path of the model.
        Returns:
            ConversationAdapter: The conversation adapter.
        """
        return _auto_get_conv_template(model_name, model_path)


def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
    """Dynamic model parser; parse the model parameters from the command line arguments.

    Returns:
        Optional[List[Type[BaseModelParameters]]]: The model parameter class list.
    """
    from dbgpt.util.parameter_utils import _SimpleArgParser
    from dbgpt.model.parameter import (
        EmbeddingModelParameters,
        WorkerType,
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    )

    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    worker_type = pre_args.get("worker_type")
    model_type = pre_args.get("model_type")
    if model_name is None and model_type != ModelType.VLLM:
        return None
    if worker_type == WorkerType.TEXT2VEC:
        return [
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                model_name, EmbeddingModelParameters
            )
        ]

    llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
    param_class = llm_adapter.model_param_class()
    return [param_class]
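Editor's note: putting the new modules together, worker code resolves an adapter by model name/path and lets it rewrite the request parameters. A rough usage sketch under the assumption that a Mixtral checkpoint sits at a local path (the path, revision, and message content are placeholders, not values from this commit):

    from dbgpt.model.adapter.model_adapter import get_llm_model_adapter

    model_name = "mixtral-8x7b-instruct-v0.1"
    model_path = "/data/models/Mixtral-8x7B-Instruct-v0.1"  # placeholder path

    # Resolve the registered adapter (Mixtral8x7BAdapter matches on the name).
    adapter = get_llm_model_adapter(model_name, model_path)
    model, tokenizer = adapter.load(model_path, {"revision": "main"})

    params = {"messages": [{"role": "human", "content": "Hello"}]}
    # model_adaptation builds the final prompt (via the tokenizer's chat template here)
    # and fills in stop strings / stop token ids from the conversation template.
    params, model_context = adapter.model_adaptation(
        params, model_name, model_path, tokenizer, prompt_template=None
    )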
@@ -9,7 +9,7 @@ import os
 import re
 import logging
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, TYPE_CHECKING, Optional
 from functools import cache
 from transformers import (
     AutoModel,
@@ -17,6 +17,9 @@ from transformers import (
     AutoTokenizer,
     LlamaTokenizer,
 )
 
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
+
 from dbgpt.model.parameter import (
@@ -24,9 +27,13 @@ from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ProxyModelParameters,
 )
+from dbgpt.model.conversation import Conversation
 from dbgpt.configs.model_config import get_device
 from dbgpt._private.config import Config
 
+if TYPE_CHECKING:
+    from dbgpt.app.chat_adapter import BaseChatAdpter
+
 logger = logging.getLogger(__name__)
 
 CFG = Config()
@@ -92,17 +99,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
     )
 
 
-def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
-    try:
-        llm_adapter = get_llm_model_adapter(model_name, model_path)
-        return llm_adapter.model_param_class()
-    except Exception as e:
-        logger.warn(
-            f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
-        )
-        return ModelParameters
-
-
 # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
 
 
@@ -426,6 +422,87 @@ class InternLMAdapter(BaseLLMAdaper):
         return model, tokenizer
 
 
+class OldLLMModelAdapterWrapper(LLMModelAdapter):
+    """Wrapping old adapter, which may be removed later"""
+
+    def __init__(self, adapter: BaseLLMAdaper, chat_adapter: "BaseChatAdpter") -> None:
+        self._adapter = adapter
+        self._chat_adapter = chat_adapter
+
+    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
+        return OldLLMModelAdapterWrapper(self._adapter, self._chat_adapter)
+
+    def use_fast_tokenizer(self) -> bool:
+        return self._adapter.use_fast_tokenizer()
+
+    def model_type(self) -> str:
+        return self._adapter.model_type()
+
+    def model_param_class(self, model_type: str = None) -> ModelParameters:
+        return self._adapter.model_param_class(model_type)
+
+    def get_default_conv_template(
+        self, model_name: str, model_path: str
+    ) -> Optional[ConversationAdapter]:
+        conv_template = self._chat_adapter.get_conv_template(model_path)
+        return OldConversationAdapter(conv_template) if conv_template else None
+
+    def load(self, model_path: str, from_pretrained_kwargs: dict):
+        return self._adapter.loader(model_path, from_pretrained_kwargs)
+
+    def get_generate_stream_function(self, model, model_path: str):
+        return self._chat_adapter.get_generate_stream_func(model_path)
+
+    def __str__(self) -> str:
+        return "{}({}.{})".format(
+            self.__class__.__name__,
+            self._adapter.__class__.__module__,
+            self._adapter.__class__.__name__,
+        )
+
+
+class OldConversationAdapter(ConversationAdapter):
+    """Wrapping old Conversation, which may be removed later"""
+
+    def __init__(self, conv: Conversation) -> None:
+        self._conv = conv
+
+    @property
+    def prompt_type(self) -> PromptType:
+        return PromptType.DBGPT
+
+    @property
+    def roles(self) -> Tuple[str]:
+        return self._conv.roles
+
+    @property
+    def sep(self) -> Optional[str]:
+        return self._conv.sep
+
+    @property
+    def stop_str(self) -> str:
+        return self._conv.stop_str
+
+    @property
+    def stop_token_ids(self) -> Optional[List[int]]:
+        return self._conv.stop_token_ids
+
+    def get_prompt(self) -> str:
+        return self._conv.get_prompt()
+
+    def set_system_message(self, system_message: str) -> None:
+        self._conv.update_system_message(system_message)
+
+    def append_message(self, role: str, message: str) -> None:
+        self._conv.append_message(role, message)
+
+    def update_last_message(self, message: str) -> None:
+        self._conv.update_last_message(message)
+
+    def copy(self) -> "ConversationAdapter":
+        return OldConversationAdapter(self._conv.copy())
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
dbgpt/model/adapter/template.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Optional, Tuple, Union, List

if TYPE_CHECKING:
    from fastchat.conversation import Conversation


class PromptType(str, Enum):
    """Prompt type."""

    FSCHAT: str = "fschat"
    DBGPT: str = "dbgpt"


class ConversationAdapter(ABC):
    """The conversation adapter."""

    @property
    def prompt_type(self) -> PromptType:
        return PromptType.FSCHAT

    @property
    @abstractmethod
    def roles(self) -> Tuple[str]:
        """Get the roles of the conversation.

        Returns:
            Tuple[str]: The roles of the conversation.
        """

    @property
    def sep(self) -> Optional[str]:
        """Get the separator between messages."""
        return "\n"

    @property
    def stop_str(self) -> Optional[Union[str, List[str]]]:
        """Get the stop criteria."""
        return None

    @property
    def stop_token_ids(self) -> Optional[List[int]]:
        """Stop generation when meeting any token in this list"""
        return None

    @abstractmethod
    def get_prompt(self) -> str:
        """Get the prompt string.

        Returns:
            str: The prompt string.
        """

    @abstractmethod
    def set_system_message(self, system_message: str) -> None:
        """Set the system message."""

    @abstractmethod
    def append_message(self, role: str, message: str) -> None:
        """Append a new message.

        Args:
            role (str): The role of the message.
            message (str): The message content.
        """

    @abstractmethod
    def update_last_message(self, message: str) -> None:
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.

        Args:
            message (str): The message content.
        """

    @abstractmethod
    def copy(self) -> "ConversationAdapter":
        """Copy the conversation."""


class ConversationAdapterFactory(ABC):
    """The conversation adapter factory."""

    def get_by_name(
        self,
        template_name: str,
        prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
    ) -> ConversationAdapter:
        """Get a conversation adapter by name.

        Args:
            template_name (str): The name of the template.
            prompt_template_type (Optional[PromptType]): The type of the prompt template, defaults to FSCHAT.

        Returns:
            ConversationAdapter: The conversation adapter.
        """
        raise NotImplementedError()

    def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
        """Get a conversation adapter by model.

        Args:
            model_name (str): The name of the model.
            model_path (str): The path of the model.

        Returns:
            ConversationAdapter: The conversation adapter.
        """
        raise NotImplementedError()


def get_conv_template(name: str) -> ConversationAdapter:
    """Get a conversation template.

    Just return the fastchat conversation template for now.
    # TODO: More templates should be supported.

    Args:
        name (str): The name of the template.

    Returns:
        ConversationAdapter: The conversation template.
    """
    from fastchat.conversation import get_conv_template
    from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter

    conv_template = get_conv_template(name)
    return FschatConversationAdapter(conv_template)
dbgpt/model/adapter/vllm_adapter.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import dataclasses
import logging

from dbgpt.model.base import ModelType
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.parameter import BaseModelParameters
from dbgpt.util.parameter_utils import (
    _extract_parameter_details,
    _build_parameter_class,
    _get_dataclass_print_str,
)

logger = logging.getLogger(__name__)


class VLLMModelAdapterWrapper(LLMModelAdapter):
    """Wrapping vllm engine"""

    def __init__(self, conv_factory: ConversationAdapterFactory):
        self.conv_factory = conv_factory

    def new_adapter(self, **kwargs) -> "VLLMModelAdapterWrapper":
        return VLLMModelAdapterWrapper(self.conv_factory)

    def model_type(self) -> str:
        return ModelType.VLLM

    def model_param_class(self, model_type: str = None) -> BaseModelParameters:
        import argparse

        from vllm.engine.arg_utils import AsyncEngineArgs

        parser = argparse.ArgumentParser()
        parser = AsyncEngineArgs.add_cli_args(parser)
        parser.add_argument("--model_name", type=str, help="model name")
        parser.add_argument(
            "--model_path",
            type=str,
            help="local model path of the huggingface model to use",
        )
        parser.add_argument("--model_type", type=str, help="model type")
        parser.add_argument("--device", type=str, default=None, help="device")
        # TODO parse prompt template from `model_name` and `model_path`
        parser.add_argument(
            "--prompt_template",
            type=str,
            default=None,
            help="Prompt template. If None, the prompt template is automatically determined from model path",
        )

        descs = _extract_parameter_details(
            parser,
            "dbgpt.model.parameter.VLLMModelParameters",
            skip_names=["model"],
            overwrite_default_values={"trust_remote_code": True},
        )
        return _build_parameter_class(descs)

    def load_from_params(self, params):
        from vllm import AsyncLLMEngine
        from vllm.engine.arg_utils import AsyncEngineArgs
        import torch

        num_gpus = torch.cuda.device_count()
        if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
            setattr(params, "tensor_parallel_size", num_gpus)
        logger.info(
            f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
        )

        params = dataclasses.asdict(params)
        params["model"] = params["model_path"]
        attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
        vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
        # Set the attributes from the parsed arguments.
        engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        return engine, engine.engine.tokenizer

    def support_async(self) -> bool:
        return True

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.llm_out.vllm_llm import generate_stream

        return generate_stream

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> ConversationAdapter:
        return self.conv_factory.get_by_model(model_name, model_path)

    def __str__(self) -> str:
        return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
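The interesting part of load_from_params is how it bridges DB-GPT's flat parameter dataclass to vLLM's AsyncEngineArgs: dump the dataclass to a dict, alias model_path to model, then keep only the keys AsyncEngineArgs actually declares. A self-contained sketch of that filtering pattern, using made-up stand-in dataclasses rather than the real vLLM classes:

# Self-contained sketch of the field-filtering trick used in load_from_params above.
# `ToyEngineArgs` and `ToyParams` are invented stand-ins for AsyncEngineArgs and the
# generated VLLM parameter class; only the dataclasses logic is the point here.
import dataclasses


@dataclasses.dataclass
class ToyEngineArgs:
    model: str
    tensor_parallel_size: int = 1


@dataclasses.dataclass
class ToyParams:
    model_name: str
    model_path: str
    tensor_parallel_size: int = 1
    device: str = "cuda"  # extra field the engine does not know about


params = dataclasses.asdict(ToyParams(model_name="demo", model_path="/models/demo"))
params["model"] = params["model_path"]  # vLLM expects `model`, DB-GPT passes `model_path`
engine_fields = [f.name for f in dataclasses.fields(ToyEngineArgs)]
engine_kwargs = {name: params.get(name) for name in engine_fields}
print(ToyEngineArgs(**engine_kwargs))  # `device` is silently dropped, `model` is filled in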
@@ -405,7 +405,7 @@ def stop_model_controller(port: int):


 def _model_dynamic_factory() -> Callable[[None], List[Type]]:
-    from dbgpt.model.model_adapter import _dynamic_model_parser
+    from dbgpt.model.adapter.model_adapter import _dynamic_model_parser

     param_class = _dynamic_model_parser()
     fix_class = [ModelWorkerParameters]
@@ -6,7 +6,8 @@ import time
 import traceback

 from dbgpt.configs.model_config import get_device
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.core import ModelOutput, ModelInferenceMetrics
 from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
@@ -27,7 +28,7 @@ class DefaultModelWorker(ModelWorker):
         self.model = None
         self.tokenizer = None
         self._model_params = None
-        self.llm_adapter: LLMModelAdaper = None
+        self.llm_adapter: LLMModelAdapter = None
         self._support_async = False

     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -37,7 +37,7 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
-    from dbgpt.model.model_adapter import get_llm_model_adapter
+    from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
     from dbgpt.model.loader import _get_model_real_path

     ret = []
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-from typing import Optional, Dict
-from dataclasses import asdict
+from typing import Optional, Dict, Any
 import logging

 from dbgpt.configs.model_config import get_device
 from dbgpt.model.base import ModelType
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.parameter import (
     ModelParameters,
     LlamaCppModelParameters,
@@ -117,7 +117,7 @@ class ModelLoader:
             raise Exception(f"Unkown model type {model_type}")

     def loader_with_params(
-        self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
+        self, model_params: ModelParameters, llm_adapter: LLMModelAdapter
     ):
         model_type = llm_adapter.model_type()
         self.prompt_template = model_params.prompt_template
@@ -133,7 +133,7 @@ class ModelLoader:
             raise Exception(f"Unkown model type {model_type}")


-def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
+def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch
     from dbgpt.model.compression import compress_module

@@ -174,6 +174,12 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
     else:
         raise ValueError(f"Invalid device: {device}")

+    model, tokenizer = _try_load_default_quantization_model(
+        llm_adapter, device, num_gpus, model_params, kwargs
+    )
+    if model:
+        return model, tokenizer
+
     can_quantization = _check_quantization(model_params)

     if can_quantization and (num_gpus > 1 or model_params.load_4bit):
@@ -192,6 +198,46 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
         # TODO merge current code into `load_huggingface_quantization_model`
         compress_module(model, model_params.device)

+    return _handle_model_and_tokenizer(model, tokenizer, device, num_gpus, model_params)
+
+
+def _try_load_default_quantization_model(
+    llm_adapter: LLMModelAdapter,
+    device: str,
+    num_gpus: int,
+    model_params: ModelParameters,
+    kwargs: Dict[str, Any],
+):
+    """Try load default quantization model(Support by huggingface default)"""
+    cloned_kwargs = {k: v for k, v in kwargs.items()}
+    try:
+        model, tokenizer = None, None
+        if device != "cuda":
+            return None, None
+        elif model_params.load_8bit and llm_adapter.support_8bit:
+            cloned_kwargs["load_in_8bit"] = True
+            model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+        elif model_params.load_4bit and llm_adapter.support_4bit:
+            cloned_kwargs["load_in_4bit"] = True
+            model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+        if model:
+            logger.info(
+                f"Load default quantization model {model_params.model_name} success"
+            )
+            return _handle_model_and_tokenizer(
+                model, tokenizer, device, num_gpus, model_params
+            )
+        return None, None
+    except Exception as e:
+        logger.warning(
+            f"Load default quantization model {model_params.model_name} failed, error: {str(e)}"
+        )
+        return None, None
+
+
+def _handle_model_and_tokenizer(
+    model, tokenizer, device: str, num_gpus: int, model_params: ModelParameters
+):
     if (
         (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
         or device == "mps"
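The new fallback relies on the stock huggingface loading path: when the adapter reports support_8bit or support_4bit, the cloned kwargs simply gain load_in_8bit / load_in_4bit before being handed to llm_adapter.load(). A rough sketch of what that amounts to for a plain transformers-backed adapter (assuming transformers, accelerate and bitsandbytes are installed and a CUDA device is available; the checkpoint name is only an example, not taken from this commit):

# Hedged sketch, not from the diff: the kind of call an adapter's load() ends up making
# once _try_load_default_quantization_model injects the 8-bit flag.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "facebook/opt-125m"  # small stand-in checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_8bit=True,       # the flag set in cloned_kwargs above
    low_cpu_mem_usage=True,
    device_map="auto",       # requires accelerate
)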
@@ -209,7 +255,7 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter


 def load_huggingface_quantization_model(
-    llm_adapter: LLMModelAdaper,
+    llm_adapter: LLMModelAdapter,
     model_params: ModelParameters,
     kwargs: Dict,
     max_memory: Dict[int, str],
@@ -344,7 +390,9 @@ def load_huggingface_quantization_model(
     return model, tokenizer


-def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
+def llamacpp_loader(
+    llm_adapter: LLMModelAdapter, model_params: LlamaCppModelParameters
+):
     try:
         from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel
     except ImportError as exc:
@@ -358,7 +406,7 @@ def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelPara
     return model, tokenizer


-def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
+def proxyllm_loader(llm_adapter: LLMModelAdapter, model_params: ProxyModelParameters):
     from dbgpt.model.proxy.llms.proxy_model import ProxyModel

     logger.info("Load proxyllm")
@@ -1,660 +0,0 @@
from __future__ import annotations

from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
import dataclasses
import logging
import threading
import os
from functools import cache
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
    ProxyModelParameters,
)
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util.parameter_utils import (
    _extract_parameter_details,
    _build_parameter_class,
    _get_dataclass_print_str,
)

try:
    from fastchat.conversation import (
        Conversation,
        register_conv_template,
        SeparatorStyle,
    )
except ImportError as exc:
    raise ValueError(
        "Could not import python package: fschat "
        "Please install fastchat by command `pip install fschat` "
    ) from exc

if TYPE_CHECKING:
    from fastchat.model.model_adapter import BaseModelAdapter
    from dbgpt.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
    from torch.nn import Module as TorchNNModule

logger = logging.getLogger(__name__)

thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


_OLD_MODELS = [
    "llama-cpp",
    "proxyllm",
    "gptj-6b",
    "codellama-13b-sql-sft",
    "codellama-7b",
    "codellama-7b-sql-sft",
    "codellama-13b",
]

_NEW_HF_CHAT_MODELS = [
    "yi-34b",
    "yi-6b",
]

# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]


class LLMModelAdaper:
    """New Adapter for DB-GPT LLM models"""

    def use_fast_tokenizer(self) -> bool:
        """Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
        for a given model.
        """
        return False

    def model_type(self) -> str:
        return ModelType.HF

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        """Get the startup parameters instance of the model"""
        model_type = model_type if model_type else self.model_type()
        if model_type == ModelType.LLAMA_CPP:
            return LlamaCppModelParameters
        elif model_type == ModelType.PROXY:
            return ProxyModelParameters
        return ModelParameters

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load model and tokenizer"""
        raise NotImplementedError

    def load_from_params(self, params):
        """Load the model and tokenizer according to the given parameters"""
        raise NotImplementedError

    def support_async(self) -> bool:
        """Whether the loaded model supports asynchronous calls"""
        return False

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        raise NotImplementedError

    def get_async_generate_stream_function(self, model, model_path: str):
        """Get the asynchronous generate stream function of the model"""
        raise NotImplementedError

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        """Get the default conv template"""
        raise NotImplementedError

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Optional[str]:
        return None

    def get_prompt_with_template(
        self,
        params: Dict,
        messages: List[ModelMessage],
        model_name: str,
        model_path: str,
        model_context: Dict,
        prompt_template: str = None,
    ):
        conv = self.get_default_conv_template(model_name, model_path)

        if prompt_template:
            logger.info(f"Use prompt template {prompt_template} from config")
            conv = get_conv_template(prompt_template)
        if not conv or not messages:
            # Nothing to do
            logger.info(
                f"No conv from model_path {model_path} or no messages in params, {self}"
            )
            return None, None, None

        conv = conv.copy()
        system_messages = []
        user_messages = []
        ai_messages = []

        for message in messages:
            role, content = None, None
            if isinstance(message, ModelMessage):
                role = message.role
                content = message.content
            elif isinstance(message, dict):
                role = message["role"]
                content = message["content"]
            else:
                raise ValueError(f"Invalid message type: {message}")

            if role == ModelMessageRoleType.SYSTEM:
                # Support for multiple system messages
                system_messages.append(content)
            elif role == ModelMessageRoleType.HUMAN:
                # conv.append_message(conv.roles[0], content)
                user_messages.append(content)
            elif role == ModelMessageRoleType.AI:
                # conv.append_message(conv.roles[1], content)
                ai_messages.append(content)
            else:
                raise ValueError(f"Unknown role: {role}")

        can_use_systems: [] = []
        if system_messages:
            if len(system_messages) > 1:
                ## Compatible with dbgpt complex scenarios, the last system will protect more complete information entered by the current user
                user_messages[-1] = system_messages[-1]
                can_use_systems = system_messages[:-1]
            else:
                can_use_systems = system_messages

        for i in range(len(user_messages)):
            conv.append_message(conv.roles[0], user_messages[i])
            if i < len(ai_messages):
                conv.append_message(conv.roles[1], ai_messages[i])

        if isinstance(conv, Conversation):
            conv.set_system_message("".join(can_use_systems))
        else:
            conv.update_system_message("".join(can_use_systems))

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        new_prompt = conv.get_prompt()
        return new_prompt, conv.stop_str, conv.stop_token_ids

    def model_adaptation(
        self,
        params: Dict,
        model_name: str,
        model_path: str,
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Tuple[Dict, Dict]:
        """Params adaptation"""
        messages = params.get("messages")
        # Some model scontext to dbgpt server
        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
        if messages:
            # Dict message to ModelMessage
            messages = [
                m if isinstance(m, ModelMessage) else ModelMessage(**m)
                for m in messages
            ]
            params["messages"] = messages

        new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
        conv_stop_str, conv_stop_token_ids = None, None
        if not new_prompt:
            (
                new_prompt,
                conv_stop_str,
                conv_stop_token_ids,
            ) = self.get_prompt_with_template(
                params, messages, model_name, model_path, model_context, prompt_template
            )
            if not new_prompt:
                return params, model_context

        # Overwrite the original prompt
        # TODO remote bos token and eos token from tokenizer_config.json of model
        prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
        model_context["prompt_echo_len_char"] = prompt_echo_len_char
        model_context["echo"] = params.get("echo", True)
        model_context["has_format_prompt"] = True
        params["prompt"] = new_prompt

        custom_stop = params.get("stop")
        custom_stop_token_ids = params.get("stop_token_ids")

        # Prefer the value passed in from the input parameter
        params["stop"] = custom_stop or conv_stop_str
        params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids

        return params, model_context
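For reference, the conversation assembly that get_prompt_with_template performs boils down to the standard fastchat Conversation API; a minimal sketch, assuming fastchat is installed and using its built-in "vicuna_v1.1" template:

# Illustrative sketch (not part of the diff): fold a short history into a fastchat
# Conversation and render the final prompt string, as the method above does.
from fastchat.conversation import get_conv_template

conv = get_conv_template("vicuna_v1.1").copy()
conv.set_system_message("You are a helpful database assistant.")
conv.append_message(conv.roles[0], "Show me all tables in the sales schema.")
conv.append_message(conv.roles[1], None)  # blank slot for the assistant reply
prompt = conv.get_prompt()
print(prompt)  # the string handed to the model, with the template's separators and stops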
class OldLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping old adapter, which may be removed later"""

    def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
        self._adapter = adapter
        self._chat_adapter = chat_adapter

    def use_fast_tokenizer(self) -> bool:
        return self._adapter.use_fast_tokenizer()

    def model_type(self) -> str:
        return self._adapter.model_type()

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        return self._adapter.model_param_class(model_type)

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return self._chat_adapter.get_conv_template(model_path)

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        return self._adapter.loader(model_path, from_pretrained_kwargs)

    def get_generate_stream_function(self, model, model_path: str):
        return self._chat_adapter.get_generate_stream_func(model_path)

    def __str__(self) -> str:
        return "{}({}.{})".format(
            self.__class__.__name__,
            self._adapter.__class__.__module__,
            self._adapter.__class__.__name__,
        )


class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping fastchat adapter"""

    def __init__(self, adapter: "BaseModelAdapter") -> None:
        self._adapter = adapter

    def use_fast_tokenizer(self) -> bool:
        return self._adapter.use_fast_tokenizer

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        return self._adapter.load_model(model_path, from_pretrained_kwargs)

    def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
        if _IS_BENCHMARK:
            from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
                generate_stream,
            )

            return generate_stream
        else:
            from fastchat.model.model_adapter import get_generate_stream_function

            return get_generate_stream_function(model, model_path)

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return self._adapter.get_default_conv_template(model_path)

    def __str__(self) -> str:
        return "{}({}.{})".format(
            self.__class__.__name__,
            self._adapter.__class__.__module__,
            self._adapter.__class__.__name__,
        )


class NewHFChatModelAdapter(LLMModelAdaper):
    def load(self, model_path: str, from_pretrained_kwargs: dict):
        try:
            import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
        except ImportError as exc:
            raise ValueError(
                "Could not import depend python package "
                "Please install it with `pip install transformers`."
            ) from exc
        if not transformers.__version__ >= "4.34.0":
            raise ValueError(
                "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
            )
        revision = from_pretrained_kwargs.get("revision", "main")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=self.use_fast_tokenizer,
                revision=revision,
                trust_remote_code=True,
            )
        except TypeError:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, use_fast=False, revision=revision, trust_remote_code=True
            )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        except NameError:
            model = AutoModel.from_pretrained(
                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
            )
        # tokenizer.use_default_system_prompt = False
        return model, tokenizer

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream

        return huggingface_chat_generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
    ) -> Optional[str]:
        from transformers import AutoTokenizer

        if not tokenizer:
            raise ValueError("tokenizer is is None")
        tokenizer: AutoTokenizer = tokenizer

        messages = ModelMessage.to_openai_messages(messages)
        str_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return str_prompt
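get_str_prompt above defers prompt construction to the tokenizer's own chat template, which is why transformers>=4.34 is required. A short, self-contained illustration of the same call (the checkpoint name is only an example of a model that ships a chat template, not something used by this commit):

# Hedged illustration of tokenizer.apply_chat_template with OpenAI-style messages.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "List three SQL aggregate functions."},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # the chat-formatted string the model worker would stream against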
def get_conv_template(name: str) -> "Conversation":
    """Get a conversation template."""
    from fastchat.conversation import get_conv_template

    return get_conv_template(name)


@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
    try:
        adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
        return adapter.get_default_conv_template(model_name, model_path)
    except Exception:
        return None


@cache
def get_llm_model_adapter(
    model_name: str,
    model_path: str,
    use_fastchat: bool = True,
    use_fastchat_monkey_patch: bool = False,
    model_type: str = None,
) -> LLMModelAdaper:
    if model_type == ModelType.VLLM:
        logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
        return VLLMModelAdaperWrapper()

    use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
    if use_new_hf_chat_models:
        logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
        return NewHFChatModelAdapter()

    must_use_old = any(m in model_name for m in _OLD_MODELS)
    if use_fastchat and not must_use_old:
        logger.info("Use fastcat adapter")
        adapter = _get_fastchat_model_adapter(
            model_name,
            model_path,
            _fastchat_get_adapter_monkey_patch,
            use_fastchat_monkey_patch=use_fastchat_monkey_patch,
        )
        return FastChatLLMModelAdaperWrapper(adapter)
    else:
        from dbgpt.model.adapter import (
            get_llm_model_adapter as _old_get_llm_model_adapter,
        )
        from dbgpt.app.chat_adapter import get_llm_chat_adapter

        logger.info("Use DB-GPT old adapter")
        return OldLLMModelAdaperWrapper(
            _old_get_llm_model_adapter(model_name, model_path),
            get_llm_chat_adapter(model_name, model_path),
        )


def _get_fastchat_model_adapter(
    model_name: str,
    model_path: str,
    caller: Callable[[str], None] = None,
    use_fastchat_monkey_patch: bool = False,
):
    from fastchat.model import model_adapter

    _bak_get_model_adapter = model_adapter.get_model_adapter
    try:
        if use_fastchat_monkey_patch:
            model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
        thread_local.model_name = model_name
        _remove_black_list_model_of_fastchat()
        if caller:
            return caller(model_path)
    finally:
        del thread_local.model_name
        model_adapter.get_model_adapter = _bak_get_model_adapter


def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
    if not model_name:
        if not hasattr(thread_local, "model_name"):
            raise RuntimeError("fastchat get adapter monkey path need model_name")
        model_name = thread_local.model_name
    from fastchat.model.model_adapter import model_adapters

    for adapter in model_adapters:
        if adapter.match(model_name):
            logger.info(
                f"Found llm model adapter with model name: {model_name}, {adapter}"
            )
            return adapter

    model_path_basename = (
        None if not model_path else os.path.basename(os.path.normpath(model_path))
    )
    for adapter in model_adapters:
        if model_path_basename and adapter.match(model_path_basename):
            logger.info(
                f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
            )
            return adapter

    for adapter in model_adapters:
        if model_path and adapter.match(model_path):
            logger.info(
                f"Found llm model adapter with model path: {model_path}, {adapter}"
            )
            return adapter

    raise ValueError(
        f"Invalid model adapter for model name {model_name} and model path {model_path}"
    )
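The three matching passes above fall back from the configured model name to the basename of the local model path; a tiny runnable illustration of that normalization (the paths are hypothetical):

# Why the basename pass exists: normpath + basename turn a local checkout path,
# trailing slash or not, into the directory name fastchat adapters match against.
import os

for p in ("/data/models/vicuna-13b-v1.5/", "/data/models/vicuna-13b-v1.5"):
    print(os.path.basename(os.path.normpath(p)))  # -> "vicuna-13b-v1.5" for both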
@cache
def _remove_black_list_model_of_fastchat():
    from fastchat.model.model_adapter import model_adapters

    black_list_models = []
    for adapter in model_adapters:
        try:
            if (
                adapter.get_default_conv_template("/data/not_exist_model_path").name
                in _BLACK_LIST_MODLE_PROMPT
            ):
                black_list_models.append(adapter)
        except Exception:
            pass
    for adapter in black_list_models:
        model_adapters.remove(adapter)


def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from dbgpt.util.parameter_utils import _SimpleArgParser
    from dbgpt.model.parameter import (
        EmbeddingModelParameters,
        WorkerType,
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    )

    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    worker_type = pre_args.get("worker_type")
    model_type = pre_args.get("model_type")
    if model_name is None and model_type != ModelType.VLLM:
        return None
    if worker_type == WorkerType.TEXT2VEC:
        return [
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                model_name, EmbeddingModelParameters
            )
        ]

    llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
    param_class = llm_adapter.model_param_class()
    return [param_class]
class VLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping vllm engine"""

    def model_type(self) -> str:
        return ModelType.VLLM

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        import argparse
        from vllm.engine.arg_utils import AsyncEngineArgs

        parser = argparse.ArgumentParser()
        parser = AsyncEngineArgs.add_cli_args(parser)
        parser.add_argument("--model_name", type=str, help="model name")
        parser.add_argument(
            "--model_path",
            type=str,
            help="local model path of the huggingface model to use",
        )
        parser.add_argument("--model_type", type=str, help="model type")
        parser.add_argument("--device", type=str, default=None, help="device")
        # TODO parse prompt templete from `model_name` and `model_path`
        parser.add_argument(
            "--prompt_template",
            type=str,
            default=None,
            help="Prompt template. If None, the prompt template is automatically determined from model path",
        )

        descs = _extract_parameter_details(
            parser,
            "dbgpt.model.parameter.VLLMModelParameters",
            skip_names=["model"],
            overwrite_default_values={"trust_remote_code": True},
        )
        return _build_parameter_class(descs)

    def load_from_params(self, params):
        from vllm import AsyncLLMEngine
        from vllm.engine.arg_utils import AsyncEngineArgs
        import torch

        num_gpus = torch.cuda.device_count()
        if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
            setattr(params, "tensor_parallel_size", num_gpus)
        logger.info(
            f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
        )

        params = dataclasses.asdict(params)
        params["model"] = params["model_path"]
        attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
        vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
        # Set the attributes from the parsed arguments.
        engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        return engine, engine.engine.tokenizer

    def support_async(self) -> bool:
        return True

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.llm_out.vllm_llm import generate_stream

        return generate_stream

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return _auto_get_conv_template(model_name, model_path)

    def __str__(self) -> str:
        return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)


# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
# We also recommend that you modify it directly in the fastchat repository.

# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
    Conversation(
        name="aquila-legacy",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("### Human: ", "### Assistant: ", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="\n",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
    Conversation(
        name="aquila",
        system_message="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant", "System"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="###",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
    Conversation(
        name="aquila-v1",
        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>"],
    ),
    override=True,
)