feat: Support baichuan-13B model

FangYin Cheng 2023-07-25 00:53:28 +08:00
parent 7cc86a8b54
commit 01074660bc
8 changed files with 69 additions and 2 deletions

View File

@@ -50,6 +50,7 @@ LLM_MODEL_CONFIG = {
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
}
# Load model config
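For context, the new entry only maps the model name to a local checkpoint directory; below is a minimal sketch of how that lookup resolves (the MODEL_PATH value is an assumption for illustration, not the project's real setting):

import os

MODEL_PATH = "/data/models"  # assumed root directory; the real value comes from model_config.py

LLM_MODEL_CONFIG = {
    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
}

# Resolve the checkpoint path for the configured model name and fail early if
# the weights have not been downloaded into MODEL_PATH yet.
model_path = LLM_MODEL_CONFIG["baichuan-13b"]
if not os.path.isdir(model_path):
    raise FileNotFoundError(f"Expected Baichuan-13B-Chat weights under {model_path}")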

View File

@@ -12,6 +12,8 @@ from transformers import (
LlamaTokenizer,
BitsAndBytesConfig,
)
from transformers.generation.utils import GenerationConfig
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
@@ -276,6 +278,24 @@ class Llama2Adapter(BaseLLMAdaper):
return model, tokenizer
class BaichuanAdapter(BaseLLMAdaper):
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
def match(self, model_path: str):
return "baichuan" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
# revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
@@ -283,6 +303,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaichuanAdapter)
# TODO Vicuna is supported by default; other models still need testing and evaluation
# just for test_py, remove this later
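For readers unfamiliar with the registry above, the sketch below shows the lookup pattern that match()/register_llm_model_adapters() implies. The registry list and the find_adapter() helper are illustrative names, not code from this commit:

llm_model_adapters = []

def register_llm_model_adapters(cls):
    """Instantiate an adapter class and add it to the registry."""
    llm_model_adapters.append(cls())

class BaichuanAdapter:
    def match(self, model_path: str):
        return "baichuan" in model_path.lower()

register_llm_model_adapters(BaichuanAdapter)

def find_adapter(model_path: str):
    """Return the first registered adapter whose match() accepts this path."""
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No adapter found for {model_path}")

# "baichuan" appears in the path (case-insensitively), so BaichuanAdapter is picked.
print(type(find_adapter("/data/models/Baichuan-13B-Chat")).__name__)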

View File

@@ -2,6 +2,8 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
"""
import dataclasses
@@ -305,4 +307,21 @@ register_conv_template(
)
)
# Baichuan-13B-Chat template
register_conv_template(
# source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
Conversation(
name="baichuan-chat",
system="",
roles=(" <reserved_102> ", " <reserved_103> "),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="",
sep2="</s>",
stop_token_ids=[2, 195],
)
)
# TODO Support other model conversation template
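To make the template concrete: with SeparatorStyle.NO_COLON_TWO, each role marker is concatenated directly to its message, and turns alternate sep/sep2 as terminators. A simplified, standalone rendering sketch (this mirrors fastchat's behavior but is not the actual get_prompt implementation):

def render_baichuan_chat(messages):
    """Render messages with the baichuan-chat template: empty system prompt,
    ' <reserved_102> ' marks the user turn, ' <reserved_103> ' the assistant
    turn; user turns end with '' and assistant turns with '</s>'."""
    seps = ["", "</s>"]
    prompt = ""  # system prompt is empty
    for i, (role, message) in enumerate(messages):
        if message:
            prompt += role + message + seps[i % 2]
        else:
            prompt += role  # open assistant turn: generation continues from here
    return prompt

# Example:
# render_baichuan_chat([(" <reserved_102> ", "Hello"), (" <reserved_103> ", None)])
# returns " <reserved_102> Hello <reserved_103> "

The stop_token_ids [2, 195] appear to correspond to Baichuan's eos token and its reserved user-turn token, so decoding halts before the model starts a new user turn.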

View File

@@ -54,17 +54,19 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
model_context = data.get("model_context")
has_echo = True
if model_context and "prompt_echo_len_char" in model_context:
prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
has_echo = bool(model_context.get("echo", True))
if prompt_echo_len_char != -1:
skip_echo_len = prompt_echo_len_char
if data.get("error_code", 0) == 0:
if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
if has_echo and ("vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL):
# TODO Judging from model_context
# output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip()
elif "guanaco" in CFG.LLM_MODEL:
elif has_echo and "guanaco" in CFG.LLM_MODEL:
# NO stream output
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
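The point of the new flag: when the worker was asked not to echo the prompt (echo=False), the parser must not cut skip_echo_len characters off the front of the completion. A condensed sketch of that decision, with names taken from the hunk above:

def trim_model_output(text, model_context, skip_echo_len):
    """Strip the echoed prompt only if the model actually echoed it."""
    has_echo = True
    if model_context and "prompt_echo_len_char" in model_context:
        prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
        has_echo = bool(model_context.get("echo", True))
        if prompt_echo_len_char != -1:
            skip_echo_len = prompt_echo_len_char
    return text[skip_echo_len:].strip() if has_echo else text.strip()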

View File

@@ -69,6 +69,7 @@ class BaseChat(ABC):
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
self.llm_echo = False
### configurable storage methods
self.memory = DuckdbHistoryMemory(chat_session_id)
@@ -128,6 +129,7 @@
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"stop": self.prompt_template.sep,
"echo": self.llm_echo,
}
return payload
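With llm_echo defaulting to False, the request BaseChat sends to the model worker ends up looking roughly like the dict below (the values are illustrative, not taken from the commit):

payload = {
    "model": "baichuan-13b",
    "prompt": " <reserved_102> What is DB-GPT? <reserved_103> ",
    "temperature": 0.7,
    "max_new_tokens": 1024,
    "stop": "</s>",
    "echo": False,  # ask the worker not to echo the prompt back
}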

View File

@@ -62,7 +62,12 @@ class BaseChatAdpter:
# TODO read the bos/eos tokens to strip from the model's tokenizer_config.json instead of hard-coding them
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
params["prompt"] = new_prompt
# Overwrite model params:
params["stop"] = conv.stop_str
return params, model_context
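prompt_echo_len_char is measured on the prompt with the bos/eos markers stripped, because that is the text an echoing model will actually reproduce. A small worked example (the prompt string is made up):

new_prompt = "<s> <reserved_102> Hello <reserved_103> </s>"
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
# Equals the length of " <reserved_102> Hello <reserved_103> ", i.e. the echoed
# characters the output parser should skip when echo is enabled.
assert prompt_echo_len_char == len(" <reserved_102> Hello <reserved_103> ")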
@@ -195,6 +200,19 @@ class Llama2ChatAdapter(BaseChatAdpter):
return generate_stream
class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "baichuan" in model_path.lower()
def get_conv_template(self) -> Conversation:
return get_conv_template("baichuan-chat")
def get_generate_stream_func(self):
from pilot.model.inference import generate_stream
return generate_stream
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
@@ -202,6 +220,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)

View File

@@ -136,6 +136,7 @@ class PromptRequest(BaseModel):
max_new_tokens: int
model: str
stop: str = None
echo: bool = True
class StreamRequest(BaseModel):
@@ -178,6 +179,7 @@ def generate(prompt_request: PromptRequest) -> str:
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop,
"echo": prompt_request.echo,
}
rsp_str = ""
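Because PromptRequest now carries echo, a client can ask the server for the completion without the prompt prepended. A hedged example call (the route and port are assumptions about how llmserver is exposed, not part of this diff):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",  # assumed llmserver address and route
    json={
        "prompt": " <reserved_102> Hello <reserved_103> ",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "model": "baichuan-13b",
        "stop": "</s>",
        "echo": False,
    },
)
print(resp.text)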

View File

@@ -28,6 +28,7 @@ pyyaml==6.0
tokenizers==0.13.2
tqdm==4.64.1
transformers==4.30.0
transformers_stream_generator
timm==0.6.13
spacy==3.5.3
webdataset==0.2.48
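The new transformers_stream_generator pin is presumably needed at load time by the model's trust_remote_code implementation; a quick import check (using the packages' standard import names) can confirm the environment before starting the server:

import importlib

# Fail fast if the streaming-generation dependency is missing, instead of
# hitting an ImportError deep inside AutoModelForCausalLM.from_pretrained(...,
# trust_remote_code=True).
for module in ("transformers", "transformers_stream_generator"):
    importlib.import_module(module)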