mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
feat(model): Support Yi-34B-Chat (#837)
This commit is contained in:
@@ -135,6 +135,7 @@ At present, we have introduced several key features to showcase our current capa
|
||||
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
|
||||
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
|
||||
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
- [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
|
||||
|
||||
- Support API Proxy LLMs
|
||||
- [x] [ChatGPT](https://api.openai.com/)
|
||||
|
@@ -133,6 +133,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
|
||||
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
|
||||
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
|
||||
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
- [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
|
||||
|
||||
- 支持在线代理模型
|
||||
- [x] [ChatGPT](https://api.openai.com/)
|
||||
|
@@ -126,6 +126,9 @@ LLM_MODEL_CONFIG = {
|
||||
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
|
||||
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
|
||||
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
|
||||
# https://huggingface.co/01-ai/Yi-34B-Chat
|
||||
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
|
||||
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
|
||||
}
|
||||
|
||||
EMBEDDING_MODEL_CONFIG = {
|
||||
|
@@ -253,6 +253,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
params,
|
||||
self.model_name,
|
||||
self.model_path,
|
||||
self.tokenizer,
|
||||
prompt_template=self.ml.prompt_template,
|
||||
)
|
||||
stream_type = ""
|
||||
@@ -269,7 +270,9 @@ class DefaultModelWorker(ModelWorker):
|
||||
self.model, self.model_path
|
||||
)
|
||||
str_prompt = params.get("prompt")
|
||||
print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
|
||||
print(
|
||||
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
|
||||
)
|
||||
|
||||
generate_stream_func_str_name = "{}.{}".format(
|
||||
generate_stream_func.__module__, generate_stream_func.__name__
|
||||
|
54
pilot/model/llm_out/hf_chat_llm.py
Normal file
54
pilot/model/llm_out/hf_chat_llm.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import torch
|
||||
from threading import Thread
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def huggingface_chat_generate_stream(
|
||||
model: AutoModelForCausalLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
params,
|
||||
device,
|
||||
context_len=4096,
|
||||
):
|
||||
prompt = params["prompt"]
|
||||
temperature = float(params.get("temperature", 0.7))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
echo = params.get("echo", False)
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
# input_ids = input_ids.to(device)
|
||||
if model.config.is_encoder_decoder:
|
||||
max_src_len = context_len
|
||||
else: # truncate
|
||||
max_src_len = context_len - max_new_tokens - 1
|
||||
input_ids = input_ids[-max_src_len:]
|
||||
input_echo_len = len(input_ids)
|
||||
input_ids = torch.as_tensor([input_ids], device=device)
|
||||
|
||||
# messages = params["messages"]
|
||||
# messages = ModelMessage.to_openai_messages(messages)
|
||||
# input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
|
||||
# input_ids = input_ids.to(device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, skip_prompt=not echo, skip_special_tokens=True
|
||||
)
|
||||
generate_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"max_length": context_len,
|
||||
"temperature": temperature,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
thread.start()
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield out
|
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
|
||||
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
@@ -41,6 +41,7 @@ logger = logging.getLogger(__name__)
|
||||
thread_local = threading.local()
|
||||
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
|
||||
|
||||
|
||||
_OLD_MODELS = [
|
||||
"llama-cpp",
|
||||
"proxyllm",
|
||||
@@ -51,6 +52,14 @@ _OLD_MODELS = [
|
||||
"codellama-13b",
|
||||
]
|
||||
|
||||
_NEW_HF_CHAT_MODELS = [
|
||||
"yi-34b",
|
||||
"yi-6b",
|
||||
]
|
||||
|
||||
# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
|
||||
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
|
||||
|
||||
|
||||
class LLMModelAdaper:
|
||||
"""New Adapter for DB-GPT LLM models"""
|
||||
@@ -99,26 +108,25 @@ class LLMModelAdaper:
|
||||
"""Get the default conv template"""
|
||||
raise NotImplementedError
|
||||
|
||||
def model_adaptation(
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def get_prompt_with_template(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
model_context: Dict,
|
||||
prompt_template: str = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Params adaptation"""
|
||||
):
|
||||
conv = self.get_default_conv_template(model_name, model_path)
|
||||
messages = params.get("messages")
|
||||
# Some model scontext to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
|
||||
if messages:
|
||||
# Dict message to ModelMessage
|
||||
messages = [
|
||||
m if isinstance(m, ModelMessage) else ModelMessage(**m)
|
||||
for m in messages
|
||||
]
|
||||
params["messages"] = messages
|
||||
|
||||
if prompt_template:
|
||||
logger.info(f"Use prompt template {prompt_template} from config")
|
||||
@@ -128,7 +136,7 @@ class LLMModelAdaper:
|
||||
logger.info(
|
||||
f"No conv from model_path {model_path} or no messages in params, {self}"
|
||||
)
|
||||
return params, model_context
|
||||
return None, None, None
|
||||
|
||||
conv = conv.copy()
|
||||
system_messages = []
|
||||
@@ -180,6 +188,41 @@ class LLMModelAdaper:
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
new_prompt = conv.get_prompt()
|
||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||
|
||||
def model_adaptation(
|
||||
self,
|
||||
params: Dict,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Params adaptation"""
|
||||
messages = params.get("messages")
|
||||
# Some model scontext to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
if messages:
|
||||
# Dict message to ModelMessage
|
||||
messages = [
|
||||
m if isinstance(m, ModelMessage) else ModelMessage(**m)
|
||||
for m in messages
|
||||
]
|
||||
params["messages"] = messages
|
||||
|
||||
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
|
||||
conv_stop_str, conv_stop_token_ids = None, None
|
||||
if not new_prompt:
|
||||
(
|
||||
new_prompt,
|
||||
conv_stop_str,
|
||||
conv_stop_token_ids,
|
||||
) = self.get_prompt_with_template(
|
||||
params, messages, model_name, model_path, model_context, prompt_template
|
||||
)
|
||||
if not new_prompt:
|
||||
return params, model_context
|
||||
|
||||
# Overwrite the original prompt
|
||||
# TODO remote bos token and eos token from tokenizer_config.json of model
|
||||
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
|
||||
@@ -192,8 +235,8 @@ class LLMModelAdaper:
|
||||
custom_stop_token_ids = params.get("stop_token_ids")
|
||||
|
||||
# Prefer the value passed in from the input parameter
|
||||
params["stop"] = custom_stop or conv.stop_str
|
||||
params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
|
||||
params["stop"] = custom_stop or conv_stop_str
|
||||
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
|
||||
|
||||
return params, model_context
|
||||
|
||||
@@ -270,6 +313,69 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
|
||||
)
|
||||
|
||||
|
||||
class NewHFChatModelAdapter(LLMModelAdaper):
|
||||
def load(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
try:
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import depend python package "
|
||||
"Please install it with `pip install transformers`."
|
||||
) from exc
|
||||
if not transformers.__version__ >= "4.34.0":
|
||||
raise ValueError(
|
||||
"Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
|
||||
)
|
||||
revision = from_pretrained_kwargs.get("revision", "main")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=self.use_fast_tokenizer,
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except TypeError:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path, use_fast=False, revision=revision, trust_remote_code=True
|
||||
)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
except NameError:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||
)
|
||||
# tokenizer.use_default_system_prompt = False
|
||||
return model, tokenizer
|
||||
|
||||
def get_generate_stream_function(self, model, model_path: str):
|
||||
"""Get the generate stream function of the model"""
|
||||
from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
|
||||
|
||||
return huggingface_chat_generate_stream
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
) -> Optional[str]:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if not tokenizer:
|
||||
raise ValueError("tokenizer is is None")
|
||||
tokenizer: AutoTokenizer = tokenizer
|
||||
|
||||
messages = ModelMessage.to_openai_messages(messages)
|
||||
str_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return str_prompt
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> "Conversation":
|
||||
"""Get a conversation template."""
|
||||
from fastchat.conversation import get_conv_template
|
||||
@@ -298,6 +404,11 @@ def get_llm_model_adapter(
|
||||
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
|
||||
return VLLMModelAdaperWrapper()
|
||||
|
||||
use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
|
||||
if use_new_hf_chat_models:
|
||||
logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
|
||||
return NewHFChatModelAdapter()
|
||||
|
||||
must_use_old = any(m in model_name for m in _OLD_MODELS)
|
||||
if use_fastchat and not must_use_old:
|
||||
logger.info("Use fastcat adapter")
|
||||
@@ -334,6 +445,7 @@ def _get_fastchat_model_adapter(
|
||||
if use_fastchat_monkey_patch:
|
||||
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
|
||||
thread_local.model_name = model_name
|
||||
_remove_black_list_model_of_fastchat()
|
||||
if caller:
|
||||
return caller(model_path)
|
||||
finally:
|
||||
@@ -377,6 +489,24 @@ def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def _remove_black_list_model_of_fastchat():
|
||||
from fastchat.model.model_adapter import model_adapters
|
||||
|
||||
black_list_models = []
|
||||
for adapter in model_adapters:
|
||||
try:
|
||||
if (
|
||||
adapter.get_default_conv_template("/data/not_exist_model_path").name
|
||||
in _BLACK_LIST_MODLE_PROMPT
|
||||
):
|
||||
black_list_models.append(adapter)
|
||||
except Exception:
|
||||
pass
|
||||
for adapter in black_list_models:
|
||||
model_adapters.remove(adapter)
|
||||
|
||||
|
||||
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
|
||||
from pilot.utils.parameter_utils import _SimpleArgParser
|
||||
from pilot.model.parameter import (
|
||||
|
@@ -113,6 +113,34 @@ class ModelMessage(BaseModel):
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
|
||||
"""Convert to OpenAI message format and
|
||||
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
|
||||
"""
|
||||
history = []
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "system", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
# Move the last user's information to the end
|
||||
temp_his = history[::-1]
|
||||
last_user_input = None
|
||||
for m in temp_his:
|
||||
if m["role"] == "user":
|
||||
last_user_input = m
|
||||
break
|
||||
if last_user_input:
|
||||
history.remove(last_user_input)
|
||||
history.append(last_user_input)
|
||||
return history
|
||||
|
||||
@staticmethod
|
||||
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
|
||||
return list(map(lambda m: m.dict(), messages))
|
||||
|
17
setup.py
17
setup.py
@@ -18,6 +18,10 @@ BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
|
||||
LLAMA_CPP_GPU_ACCELERATION = (
|
||||
os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
|
||||
)
|
||||
BUILD_FROM_SOURCE = os.getenv("BUILD_FROM_SOURCE", "false").lower() == "true"
|
||||
BUILD_FROM_SOURCE_URL_FAST_CHAT = os.getenv(
|
||||
"BUILD_FROM_SOURCE_URL_FAST_CHAT", "git+https://github.com/lm-sys/FastChat.git"
|
||||
)
|
||||
|
||||
|
||||
def parse_requirements(file_name: str) -> List[str]:
|
||||
@@ -298,7 +302,6 @@ def core_requires():
|
||||
]
|
||||
|
||||
setup_spec.extras["framework"] = [
|
||||
"fschat",
|
||||
"coloredlogs",
|
||||
"httpx",
|
||||
"sqlparse==0.4.4",
|
||||
@@ -315,7 +318,8 @@ def core_requires():
|
||||
"duckdb-engine",
|
||||
"jsonschema",
|
||||
# TODO move transformers to default
|
||||
"transformers>=4.31.0",
|
||||
# "transformers>=4.31.0",
|
||||
"transformers>=4.34.0",
|
||||
"alembic==1.12.0",
|
||||
# for excel
|
||||
"openpyxl==3.1.2",
|
||||
@@ -324,6 +328,12 @@ def core_requires():
|
||||
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
|
||||
"pympler",
|
||||
]
|
||||
if BUILD_FROM_SOURCE:
|
||||
setup_spec.extras["framework"].append(
|
||||
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
|
||||
)
|
||||
else:
|
||||
setup_spec.extras["framework"].append("fschat")
|
||||
|
||||
|
||||
def knowledge_requires():
|
||||
@@ -426,7 +436,8 @@ def default_requires():
|
||||
pip install "db-gpt[default]"
|
||||
"""
|
||||
setup_spec.extras["default"] = [
|
||||
"tokenizers==0.13.3",
|
||||
# "tokenizers==0.13.3",
|
||||
"tokenizers>=0.14",
|
||||
"accelerate>=0.20.3",
|
||||
"sentence-transformers",
|
||||
"protobuf==3.20.3",
|
||||
|
Reference in New Issue
Block a user