feat(model): Support Yi-34B-Chat (#837)

This commit is contained in:
FangYin Cheng
2023-11-24 20:05:09 +08:00
committed by GitHub
parent 507566825f
commit a92f34081c
8 changed files with 253 additions and 22 deletions

View File

@@ -135,6 +135,7 @@ At present, we have introduced several key features to showcase our current capa
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
- Support API Proxy LLMs
- [x] [ChatGPT](https://api.openai.com/)

View File

@@ -133,6 +133,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
- [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
- [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
- [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
- 支持在线代理模型
- [x] [ChatGPT](https://api.openai.com/)

View File

@@ -126,6 +126,9 @@ LLM_MODEL_CONFIG = {
"xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
# https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}
EMBEDDING_MODEL_CONFIG = {

View File

@@ -253,6 +253,7 @@ class DefaultModelWorker(ModelWorker):
params,
self.model_name,
self.model_path,
self.tokenizer,
prompt_template=self.ml.prompt_template,
)
stream_type = ""
@@ -269,7 +270,9 @@ class DefaultModelWorker(ModelWorker):
self.model, self.model_path
)
str_prompt = params.get("prompt")
print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
print(
f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
)
generate_stream_func_str_name = "{}.{}".format(
generate_stream_func.__module__, generate_stream_func.__name__

View File

@@ -0,0 +1,54 @@
import logging
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
logger = logging.getLogger(__name__)
@torch.inference_mode()
def huggingface_chat_generate_stream(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
params,
device,
context_len=4096,
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 0.7))
top_p = float(params.get("top_p", 1.0))
echo = params.get("echo", False)
max_new_tokens = int(params.get("max_new_tokens", 2048))
input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
if model.config.is_encoder_decoder:
max_src_len = context_len
else: # truncate
max_src_len = context_len - max_new_tokens - 1
input_ids = input_ids[-max_src_len:]
input_echo_len = len(input_ids)
input_ids = torch.as_tensor([input_ids], device=device)
# messages = params["messages"]
# messages = ModelMessage.to_openai_messages(messages)
# input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
# input_ids = input_ids.to(device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=not echo, skip_special_tokens=True
)
generate_kwargs = {
"input_ids": input_ids,
"max_length": context_len,
"temperature": temperature,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
out = ""
for new_text in streamer:
out += new_text
yield out

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
import dataclasses
import logging
import threading
@@ -41,6 +41,7 @@ logger = logging.getLogger(__name__)
thread_local = threading.local()
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
@@ -51,6 +52,14 @@ _OLD_MODELS = [
"codellama-13b",
]
_NEW_HF_CHAT_MODELS = [
"yi-34b",
"yi-6b",
]
# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
@@ -99,26 +108,25 @@ class LLMModelAdaper:
"""Get the default conv template"""
raise NotImplementedError
def model_adaptation(
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
return None
def get_prompt_with_template(
self,
params: Dict,
messages: List[ModelMessage],
model_name: str,
model_path: str,
model_context: Dict,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
):
conv = self.get_default_conv_template(model_name, model_path)
messages = params.get("messages")
# Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
@@ -128,7 +136,7 @@ class LLMModelAdaper:
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
return None, None, None
conv = conv.copy()
system_messages = []
@@ -180,6 +188,41 @@ class LLMModelAdaper:
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
return new_prompt, conv.stop_str, conv.stop_token_ids
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
tokenizer: Any,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
messages = params.get("messages")
# Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
conv_stop_str, conv_stop_token_ids = None, None
if not new_prompt:
(
new_prompt,
conv_stop_str,
conv_stop_token_ids,
) = self.get_prompt_with_template(
params, messages, model_name, model_path, model_context, prompt_template
)
if not new_prompt:
return params, model_context
# Overwrite the original prompt
# TODO remote bos token and eos token from tokenizer_config.json of model
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
@@ -192,8 +235,8 @@ class LLMModelAdaper:
custom_stop_token_ids = params.get("stop_token_ids")
# Prefer the value passed in from the input parameter
params["stop"] = custom_stop or conv.stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
params["stop"] = custom_stop or conv_stop_str
params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
return params, model_context
@@ -270,6 +313,69 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
)
class NewHFChatModelAdapter(LLMModelAdaper):
def load(self, model_path: str, from_pretrained_kwargs: dict):
try:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
except ImportError as exc:
raise ValueError(
"Could not import depend python package "
"Please install it with `pip install transformers`."
) from exc
if not transformers.__version__ >= "4.34.0":
raise ValueError(
"Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
)
revision = from_pretrained_kwargs.get("revision", "main")
try:
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=self.use_fast_tokenizer,
revision=revision,
trust_remote_code=True,
)
except TypeError:
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, revision=revision, trust_remote_code=True
)
try:
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
except NameError:
model = AutoModel.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
# tokenizer.use_default_system_prompt = False
return model, tokenizer
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
return huggingface_chat_generate_stream
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
) -> Optional[str]:
from transformers import AutoTokenizer
if not tokenizer:
raise ValueError("tokenizer is is None")
tokenizer: AutoTokenizer = tokenizer
messages = ModelMessage.to_openai_messages(messages)
str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return str_prompt
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
@@ -298,6 +404,11 @@ def get_llm_model_adapter(
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()
use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
if use_new_hf_chat_models:
logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
return NewHFChatModelAdapter()
must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
@@ -334,6 +445,7 @@ def _get_fastchat_model_adapter(
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
_remove_black_list_model_of_fastchat()
if caller:
return caller(model_path)
finally:
@@ -377,6 +489,24 @@ def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
)
@cache
def _remove_black_list_model_of_fastchat():
from fastchat.model.model_adapter import model_adapters
black_list_models = []
for adapter in model_adapters:
try:
if (
adapter.get_default_conv_template("/data/not_exist_model_path").name
in _BLACK_LIST_MODLE_PROMPT
):
black_list_models.append(adapter)
except Exception:
pass
for adapter in black_list_models:
model_adapters.remove(adapter)
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (

View File

@@ -113,6 +113,34 @@ class ModelMessage(BaseModel):
raise ValueError(f"Unknown role: {msg_role}")
return result
@staticmethod
def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
"""Convert to OpenAI message format and
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
"""
history = []
# Add history conversation
for message in messages:
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
else:
pass
# Move the last user's information to the end
temp_his = history[::-1]
last_user_input = None
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
return history
@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))

View File

@@ -18,6 +18,10 @@ BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
)
BUILD_FROM_SOURCE = os.getenv("BUILD_FROM_SOURCE", "false").lower() == "true"
BUILD_FROM_SOURCE_URL_FAST_CHAT = os.getenv(
"BUILD_FROM_SOURCE_URL_FAST_CHAT", "git+https://github.com/lm-sys/FastChat.git"
)
def parse_requirements(file_name: str) -> List[str]:
@@ -298,7 +302,6 @@ def core_requires():
]
setup_spec.extras["framework"] = [
"fschat",
"coloredlogs",
"httpx",
"sqlparse==0.4.4",
@@ -315,7 +318,8 @@ def core_requires():
"duckdb-engine",
"jsonschema",
# TODO move transformers to default
"transformers>=4.31.0",
# "transformers>=4.31.0",
"transformers>=4.34.0",
"alembic==1.12.0",
# for excel
"openpyxl==3.1.2",
@@ -324,6 +328,12 @@ def core_requires():
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
"pympler",
]
if BUILD_FROM_SOURCE:
setup_spec.extras["framework"].append(
f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
)
else:
setup_spec.extras["framework"].append("fschat")
def knowledge_requires():
@@ -426,7 +436,8 @@ def default_requires():
pip install "db-gpt[default]"
"""
setup_spec.extras["default"] = [
"tokenizers==0.13.3",
# "tokenizers==0.13.3",
"tokenizers>=0.14",
"accelerate>=0.20.3",
"sentence-transformers",
"protobuf==3.20.3",