mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 07:34:07 +00:00
feat: Support baichuan-13B model
This commit is contained in:
parent
7cc86a8b54
commit
01074660bc
@ -50,6 +50,7 @@ LLM_MODEL_CONFIG = {
|
||||
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
|
||||
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
|
||||
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
|
||||
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
||||
}
|
||||
|
||||
# Load model config
|
||||
|
@ -12,6 +12,8 @@ from transformers import (
|
||||
LlamaTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
from transformers.generation.utils import GenerationConfig
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.config import Config
|
||||
|
||||
@ -276,6 +278,24 @@ class Llama2Adapter(BaseLLMAdaper):
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class BaichuanAdapter(BaseLLMAdaper):
|
||||
"""The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
|
||||
|
||||
def match(self, model_path: str):
|
||||
return "baichuan" in model_path.lower()
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
# revision = from_pretrained_kwargs.get("revision", "main")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
low_cpu_mem_usage=True,
|
||||
**from_pretrained_kwargs,
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
register_llm_model_adapters(VicunaLLMAdapater)
|
||||
register_llm_model_adapters(ChatGLMAdapater)
|
||||
register_llm_model_adapters(GuanacoAdapter)
|
||||
@ -283,6 +303,7 @@ register_llm_model_adapters(FalconAdapater)
|
||||
register_llm_model_adapters(GorillaAdapter)
|
||||
register_llm_model_adapters(GPT4AllAdapter)
|
||||
register_llm_model_adapters(Llama2Adapter)
|
||||
register_llm_model_adapters(BaichuanAdapter)
|
||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||
|
||||
# just for test_py, remove this later
|
||||
|
@ -2,6 +2,8 @@
|
||||
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
|
||||
Conversation prompt templates.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
@ -305,4 +307,21 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Baichuan-13B-Chat template
|
||||
register_conv_template(
|
||||
# source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
|
||||
Conversation(
|
||||
name="baichuan-chat",
|
||||
system="",
|
||||
roles=(" <reserved_102> ", " <reserved_103> "),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.NO_COLON_TWO,
|
||||
sep="",
|
||||
sep2="</s>",
|
||||
stop_token_ids=[2, 195],
|
||||
)
|
||||
)
|
||||
|
||||
# TODO Support other model conversation template
|
||||
|
@ -54,17 +54,19 @@ class BaseOutputParser(ABC):
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
model_context = data.get("model_context")
|
||||
has_echo = True
|
||||
if model_context and "prompt_echo_len_char" in model_context:
|
||||
prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
|
||||
has_echo = bool(model_context.get("echo", True))
|
||||
if prompt_echo_len_char != -1:
|
||||
skip_echo_len = prompt_echo_len_char
|
||||
|
||||
if data.get("error_code", 0) == 0:
|
||||
if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
|
||||
if has_echo and ("vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL):
|
||||
# TODO Judging from model_context
|
||||
# output = data["text"][skip_echo_len + 11:].strip()
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
elif "guanaco" in CFG.LLM_MODEL:
|
||||
elif has_echo and "guanaco" in CFG.LLM_MODEL:
|
||||
# NO stream output
|
||||
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
||||
|
||||
|
@ -69,6 +69,7 @@ class BaseChat(ABC):
|
||||
self.chat_mode = chat_mode
|
||||
self.current_user_input: str = current_user_input
|
||||
self.llm_model = CFG.LLM_MODEL
|
||||
self.llm_echo = False
|
||||
### can configurable storage methods
|
||||
self.memory = DuckdbHistoryMemory(chat_session_id)
|
||||
|
||||
@ -128,6 +129,7 @@ class BaseChat(ABC):
|
||||
"temperature": float(self.prompt_template.temperature),
|
||||
"max_new_tokens": int(self.prompt_template.max_new_tokens),
|
||||
"stop": self.prompt_template.sep,
|
||||
"echo": self.llm_echo,
|
||||
}
|
||||
return payload
|
||||
|
||||
|
@ -62,7 +62,12 @@ class BaseChatAdpter:
|
||||
# TODO remote bos token and eos token from tokenizer_config.json of model
|
||||
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
|
||||
model_context["prompt_echo_len_char"] = prompt_echo_len_char
|
||||
model_context["echo"] = params.get("echo", True)
|
||||
params["prompt"] = new_prompt
|
||||
|
||||
# Overwrite model params:
|
||||
params["stop"] = conv.stop_str
|
||||
|
||||
return params, model_context
|
||||
|
||||
|
||||
@ -195,6 +200,19 @@ class Llama2ChatAdapter(BaseChatAdpter):
|
||||
return generate_stream
|
||||
|
||||
|
||||
class BaichuanChatAdapter(BaseChatAdpter):
|
||||
def match(self, model_path: str):
|
||||
return "baichuan" in model_path.lower()
|
||||
|
||||
def get_conv_template(self) -> Conversation:
|
||||
return get_conv_template("baichuan-chat")
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.inference import generate_stream
|
||||
|
||||
return generate_stream
|
||||
|
||||
|
||||
register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
||||
@ -202,6 +220,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
|
||||
register_llm_model_chat_adapter(GorillaChatAdapter)
|
||||
register_llm_model_chat_adapter(GPT4AllChatAdapter)
|
||||
register_llm_model_chat_adapter(Llama2ChatAdapter)
|
||||
register_llm_model_chat_adapter(BaichuanChatAdapter)
|
||||
|
||||
# Proxy model for test and develop, it's cheap for us now.
|
||||
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
||||
|
@ -136,6 +136,7 @@ class PromptRequest(BaseModel):
|
||||
max_new_tokens: int
|
||||
model: str
|
||||
stop: str = None
|
||||
echo: bool = True
|
||||
|
||||
|
||||
class StreamRequest(BaseModel):
|
||||
@ -178,6 +179,7 @@ def generate(prompt_request: PromptRequest) -> str:
|
||||
"temperature": prompt_request.temperature,
|
||||
"max_new_tokens": prompt_request.max_new_tokens,
|
||||
"stop": prompt_request.stop,
|
||||
"echo": prompt_request.echo,
|
||||
}
|
||||
|
||||
rsp_str = ""
|
||||
|
@ -28,6 +28,7 @@ pyyaml==6.0
|
||||
tokenizers==0.13.2
|
||||
tqdm==4.64.1
|
||||
transformers==4.30.0
|
||||
transformers_stream_generator
|
||||
timm==0.6.13
|
||||
spacy==3.5.3
|
||||
webdataset==0.2.48
|
||||
|
Loading…
Reference in New Issue
Block a user