feat: Support baichuan-13B model

Commit 01074660bc (parent 7cc86a8b54) in https://github.com/csunny/DB-GPT.git

This commit adds Baichuan-13B support end to end: a model path entry, a model adapter and chat adapter, a "baichuan-chat" conversation template, and an `echo` flag threaded from the chat layer through the LLM server so the prompt is no longer assumed to be echoed back in the output.
```diff
@@ -50,6 +50,7 @@ LLM_MODEL_CONFIG = {
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
+    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
 }
 
 # Load model config
```
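For illustration, a minimal sketch of how such an entry is consumed: the configured model name is looked up in `LLM_MODEL_CONFIG` to find the local checkout. `resolve_model_path` is a hypothetical helper, not part of this commit.

```python
import os

# Assumed layout, mirroring the hunk above; in DB-GPT these live in the
# model config module.
MODEL_PATH = "./models"
LLM_MODEL_CONFIG = {
    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
}


def resolve_model_path(model_name: str) -> str:
    """Hypothetical helper: map a configured model name to its local path."""
    try:
        return LLM_MODEL_CONFIG[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name!r}") from None


print(resolve_model_path("baichuan-13b"))  # ./models/Baichuan-13B-Chat
```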
```diff
@@ -12,6 +12,8 @@ from transformers import (
     LlamaTokenizer,
     BitsAndBytesConfig,
 )
+from transformers.generation.utils import GenerationConfig
+
 
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config
```
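The `GenerationConfig` import matters for Baichuan because the chat checkpoint publishes its sampling defaults in `generation_config.json`. The hunk only adds the import; attaching the config to the model, as sketched below, is the usual pattern and an assumption on my part, not something shown in this diff.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Local checkout or hub id; running this requires the 13B weights.
model_path = "baichuan-inc/Baichuan-13B-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, low_cpu_mem_usage=True
)
# Use the sampling defaults shipped with the checkpoint rather than the
# generic transformers defaults.
model.generation_config = GenerationConfig.from_pretrained(model_path)
```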
```diff
@@ -276,6 +278,24 @@ class Llama2Adapter(BaseLLMAdaper):
         return model, tokenizer
 
 
+class BaichuanAdapter(BaseLLMAdaper):
+    """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
+
+    def match(self, model_path: str):
+        return "baichuan" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        # revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
```
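The `match()` method is how this adapter gets picked: adapters are tried in registration order, and the first one whose `match()` accepts the model path wins. Below is a simplified stand-in for the registry (the real `register_llm_model_adapters` lives in DB-GPT; this sketch only illustrates the dispatch):

```python
llm_model_adapters = []


def register_llm_model_adapters(cls) -> None:
    """Simplified stand-in: instantiate and remember each adapter class."""
    llm_model_adapters.append(cls())


def get_llm_model_adapter(model_path: str):
    """Return the first registered adapter whose match() accepts the path."""
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")
```

With this scheme, a path like `/data/models/Baichuan-13B-Chat` routes to `BaichuanAdapter`, since "baichuan" appears in the lowercased path.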
```diff
@@ -283,6 +303,7 @@ register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
 register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(BaichuanAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 # just for test_py, remove this later
```
```diff
@@ -2,6 +2,8 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 
 Conversation prompt templates.
+
+
 """
 
 import dataclasses
```
```diff
@@ -305,4 +307,21 @@ register_conv_template(
     )
 )
 
+# Baichuan-13B-Chat template
+register_conv_template(
+    # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
+    # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
+    Conversation(
+        name="baichuan-chat",
+        system="",
+        roles=(" <reserved_102> ", " <reserved_103> "),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="",
+        sep2="</s>",
+        stop_token_ids=[2, 195],
+    )
+)
+
 # TODO Support other model conversation template
```
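The reserved tokens are Baichuan's role markers: `<reserved_102>` opens a user turn and `<reserved_103>` an assistant turn, with finished assistant turns terminated by `</s>` (token id 2; 195 is, per the linked generation config, the user-role token, so generation also stops if the model starts a new user turn). A rough sketch of how this template renders, assuming fastchat's `NO_COLON_TWO` semantics of concatenating role and message with no colon and alternating `sep`/`sep2`:

```python
def render_baichuan_prompt(messages) -> str:
    """messages: alternating user/assistant texts; None means "model, continue".
    Mirrors SeparatorStyle.NO_COLON_TWO with system="", sep="", sep2="</s>"."""
    roles = (" <reserved_102> ", " <reserved_103> ")
    seps = ("", "</s>")
    ret = ""  # empty system prompt for this template
    for i, text in enumerate(messages):
        if text is not None:
            ret += roles[i % 2] + text + seps[i % 2]
        else:
            ret += roles[i % 2]  # trailing role marker: generation starts here
    return ret


print(render_baichuan_prompt(["Hi", "Hello!", "Who are you?", None]))
# " <reserved_102> Hi <reserved_103> Hello!</s> <reserved_102> Who are you? <reserved_103> "
```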
```diff
@@ -54,17 +54,19 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
         model_context = data.get("model_context")
+        has_echo = True
         if model_context and "prompt_echo_len_char" in model_context:
             prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
+            has_echo = bool(model_context.get("echo", True))
             if prompt_echo_len_char != -1:
                 skip_echo_len = prompt_echo_len_char
 
         if data.get("error_code", 0) == 0:
-            if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
+            if has_echo and ("vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL):
                 # TODO Judging from model_context
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
-            elif "guanaco" in CFG.LLM_MODEL:
+            elif has_echo and "guanaco" in CFG.LLM_MODEL:
                 # NO stream output
                 # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
 
```
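The new `has_echo` flag guards the prompt-stripping logic: backends that echo the prompt prepend it to the generated text, and the parser drops `prompt_echo_len_char` characters to recover just the answer. A toy run of that logic (sample strings invented for illustration):

```python
prompt = "Hello, who are you?"
model_context = {"prompt_echo_len_char": len(prompt), "echo": True}

# A backend with echo enabled returns prompt + completion in one string.
raw_text = prompt + " I am an AI assistant."

has_echo = bool(model_context.get("echo", True))
skip_echo_len = int(model_context.get("prompt_echo_len_char", -1))

if has_echo and skip_echo_len != -1:
    output = raw_text[skip_echo_len:].strip()
else:
    output = raw_text.strip()

print(output)  # I am an AI assistant.
```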
```diff
@@ -69,6 +69,7 @@ class BaseChat(ABC):
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
+        self.llm_echo = False
         ### can configurable storage methods
         self.memory = DuckdbHistoryMemory(chat_session_id)
 
```
```diff
@@ -128,6 +129,7 @@ class BaseChat(ABC):
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "stop": self.prompt_template.sep,
+            "echo": self.llm_echo,
         }
         return payload
 
```
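Since `self.llm_echo` defaults to `False`, the generation payload now tells the server not to echo the prompt. An illustrative payload (values invented; only the keys shown in the hunk are taken from the code):

```python
payload = {
    "temperature": 0.7,     # float(self.prompt_template.temperature)
    "max_new_tokens": 512,  # int(self.prompt_template.max_new_tokens)
    "stop": "</s>",         # self.prompt_template.sep
    "echo": False,          # self.llm_echo, newly threaded through
}
```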
```diff
@@ -62,7 +62,12 @@ class BaseChatAdpter:
         # TODO remote bos token and eos token from tokenizer_config.json of model
         prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
         model_context["prompt_echo_len_char"] = prompt_echo_len_char
+        model_context["echo"] = params.get("echo", True)
         params["prompt"] = new_prompt
+
+        # Overwrite model params:
+        params["stop"] = conv.stop_str
+
 
         return params, model_context
```
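`prompt_echo_len_char` measures the prompt as the client will see it echoed back, with the `<s>`/`</s>` markers removed because they are tokens, not text. A quick check of that arithmetic on an illustrative prompt string:

```python
new_prompt = "<s> <reserved_102> Hi <reserved_103> Hello</s>"

# Strip the special markers before measuring, exactly as the hunk does.
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))

print(len(new_prompt), prompt_echo_len_char)  # 46 39
```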
```diff
@@ -195,6 +200,19 @@ class Llama2ChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+class BaichuanChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "baichuan" in model_path.lower()
+
+    def get_conv_template(self) -> Conversation:
+        return get_conv_template("baichuan-chat")
+
+    def get_generate_stream_func(self):
+        from pilot.model.inference import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
```
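To make the wiring concrete, here is a self-contained sketch of how a chat adapter resolves its conversation template through the fastchat-style registry; the `Conversation` class is reduced to the fields needed for the demo and is not DB-GPT's real implementation:

```python
conv_templates = {}


class Conversation:
    """Reduced stand-in: the real class also carries roles, separators, etc."""

    def __init__(self, name: str, stop_str: str = None):
        self.name = name
        self.stop_str = stop_str


def register_conv_template(conv: Conversation) -> None:
    conv_templates[conv.name] = conv


def get_conv_template(name: str) -> Conversation:
    return conv_templates[name]


register_conv_template(Conversation(name="baichuan-chat"))


class BaichuanChatAdapter:
    def match(self, model_path: str) -> bool:
        return "baichuan" in model_path.lower()

    def get_conv_template(self) -> Conversation:
        return get_conv_template("baichuan-chat")


adapter = BaichuanChatAdapter()
print(adapter.get_conv_template().name)  # baichuan-chat
```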
```diff
@@ -202,6 +220,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(BaichuanChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
```
```diff
@@ -136,6 +136,7 @@ class PromptRequest(BaseModel):
     max_new_tokens: int
     model: str
     stop: str = None
+    echo: bool = True
 
 
 class StreamRequest(BaseModel):
```
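With `echo: bool = True` on the request model, existing clients keep the old echoing behavior while new ones can opt out per request. A sketch of a client call; the `/generate` endpoint path and port are assumptions based on the handler name:

```python
import requests

body = {
    "prompt": " <reserved_102> Hi <reserved_103> ",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "model": "baichuan-13b",
    "stop": "</s>",
    "echo": False,  # new field: do not return the prompt with the answer
}
resp = requests.post("http://localhost:8000/generate", json=body)
print(resp.text)
```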
```diff
@@ -178,6 +179,7 @@ def generate(prompt_request: PromptRequest) -> str:
         "temperature": prompt_request.temperature,
         "max_new_tokens": prompt_request.max_new_tokens,
         "stop": prompt_request.stop,
+        "echo": prompt_request.echo,
     }
 
     rsp_str = ""
```
```diff
@@ -28,6 +28,7 @@ pyyaml==6.0
 tokenizers==0.13.2
 tqdm==4.64.1
 transformers==4.30.0
+transformers_stream_generator
 timm==0.6.13
 spacy==3.5.3
 webdataset==0.2.48
```
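`transformers_stream_generator` is, as far as I can tell, required by Baichuan's `trust_remote_code` modeling files for streaming generation; the diff itself only adds the dependency. A quick import check after installation:

```python
# Installed as: pip install transformers-stream-generator
try:
    import transformers_stream_generator  # noqa: F401
    print("transformers_stream_generator available")
except ImportError:
    print("missing: pip install transformers-stream-generator")
```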