diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index aea5c6731..d719ad3dc 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -50,6 +50,7 @@ LLM_MODEL_CONFIG = {
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
+    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
 }
 
 # Load model config
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index cea73c602..ebe7b82d5 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -12,6 +12,8 @@ from transformers import (
     LlamaTokenizer,
     BitsAndBytesConfig,
 )
+from transformers.generation.utils import GenerationConfig
+
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config
 
@@ -276,6 +278,24 @@ class Llama2Adapter(BaseLLMAdaper):
         return model, tokenizer
 
 
+class BaichuanAdapter(BaseLLMAdaper):
+    """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
+
+    def match(self, model_path: str):
+        return "baichuan" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        # revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
@@ -283,6 +303,7 @@ register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
 register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(BaichuanAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test_py, remove this later
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index 5443992df..fa57b2af5 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -2,6 +2,8 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 
 Conversation prompt templates.
+
+
 """
 
 import dataclasses
@@ -305,4 +307,21 @@ register_conv_template(
     )
 )
 
+# Baichuan-13B-Chat template
+register_conv_template(
+    # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
+    # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
+    Conversation(
+        name="baichuan-chat",
+        system="",
+        roles=(" <reserved_102> ", " <reserved_103> "),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="",
+        sep2="</s>",
+        stop_token_ids=[2, 195],
+    )
+)
+
 # TODO Support other model conversation template
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 000d92709..013f15b1e 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -54,17 +54,19 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
""" model_context = data.get("model_context") + has_echo = True if model_context and "prompt_echo_len_char" in model_context: prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1)) + has_echo = bool(model_context.get("echo", True)) if prompt_echo_len_char != -1: skip_echo_len = prompt_echo_len_char if data.get("error_code", 0) == 0: - if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL: + if has_echo and ("vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL): # TODO Judging from model_context # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() - elif "guanaco" in CFG.LLM_MODEL: + elif has_echo and "guanaco" in CFG.LLM_MODEL: # NO stream output # output = data["text"][skip_echo_len + 2:].replace("", "").strip() diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index ccb70da75..e4520931f 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -69,6 +69,7 @@ class BaseChat(ABC): self.chat_mode = chat_mode self.current_user_input: str = current_user_input self.llm_model = CFG.LLM_MODEL + self.llm_echo = False ### can configurable storage methods self.memory = DuckdbHistoryMemory(chat_session_id) @@ -128,6 +129,7 @@ class BaseChat(ABC): "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), "stop": self.prompt_template.sep, + "echo": self.llm_echo, } return payload diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 1c0d73bac..422fc1117 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -62,7 +62,12 @@ class BaseChatAdpter: # TODO remote bos token and eos token from tokenizer_config.json of model prompt_echo_len_char = len(new_prompt.replace("", "").replace("", "")) model_context["prompt_echo_len_char"] = prompt_echo_len_char + model_context["echo"] = params.get("echo", True) params["prompt"] = new_prompt + + # Overwrite model params: + params["stop"] = conv.stop_str + return params, model_context @@ -195,6 +200,19 @@ class Llama2ChatAdapter(BaseChatAdpter): return generate_stream +class BaichuanChatAdapter(BaseChatAdpter): + def match(self, model_path: str): + return "baichuan" in model_path.lower() + + def get_conv_template(self) -> Conversation: + return get_conv_template("baichuan-chat") + + def get_generate_stream_func(self): + from pilot.model.inference import generate_stream + + return generate_stream + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) @@ -202,6 +220,7 @@ register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter) register_llm_model_chat_adapter(GPT4AllChatAdapter) register_llm_model_chat_adapter(Llama2ChatAdapter) +register_llm_model_chat_adapter(BaichuanChatAdapter) # Proxy model for test and develop, it's cheap for us now. 
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 43f5a1f94..3f97e3f86 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -136,6 +136,7 @@ class PromptRequest(BaseModel):
     max_new_tokens: int
     model: str
    stop: str = None
+    echo: bool = True
 
 
 class StreamRequest(BaseModel):
@@ -178,6 +179,7 @@ def generate(prompt_request: PromptRequest) -> str:
         "temperature": prompt_request.temperature,
         "max_new_tokens": prompt_request.max_new_tokens,
         "stop": prompt_request.stop,
+        "echo": prompt_request.echo,
     }
     rsp_str = ""
 
diff --git a/requirements.txt b/requirements.txt
index 3bbdabb4e..8fbe1b9c2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,6 +28,7 @@ pyyaml==6.0
 tokenizers==0.13.2
 tqdm==4.64.1
 transformers==4.30.0
+transformers_stream_generator
 timm==0.6.13
 spacy==3.5.3
 webdataset==0.2.48
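
For review context, the subtle part of this patch is how the new "baichuan-chat" template and the echo / prompt_echo_len_char bookkeeping fit together: the server records how many characters of the prompt will be echoed back, and the client only strips that prefix when the worker actually echoed it. The sketch below is a minimal, standalone approximation, not the project's API: the role strings, separator values, and helper names are assumptions made for illustration, and the prompt assembly mirrors the NO_COLON_TWO logic of the forked fastchat Conversation class used in pilot/model/conversation.py.

# Illustrative sketch (assumed names, not part of this patch).
from typing import List, Tuple

# Baichuan-13B-Chat turn markers per its generation_config.json
# (user_token_id=195, assistant_token_id=196); the exact strings are assumed here.
USER_ROLE = " <reserved_102> "
ASSISTANT_ROLE = " <reserved_103> "
SEP, SEP2 = "", "</s>"  # NO_COLON_TWO alternates sep/sep2 after each turn


def build_prompt(messages: List[Tuple[str, str]], system: str = "") -> str:
    """Mirror of the NO_COLON_TWO assembly: role + message + alternating separator."""
    seps = [SEP, SEP2]
    ret = system
    for i, (role, message) in enumerate(messages):
        if message:
            ret += role + message + seps[i % 2]
        else:
            ret += role  # empty assistant slot: prompt ends at the assistant role marker
    return ret


def prompt_echo_len_char(prompt: str) -> int:
    # Same measurement as chat_adapter.py: ignore bos/eos markers when counting.
    return len(prompt.replace("<s>", "").replace("</s>", ""))


def strip_echo(model_text: str, skip_echo_len: int, has_echo: bool) -> str:
    # Client side (out_parser/base.py): only skip the prefix if the worker echoed it.
    return model_text[skip_echo_len:].strip() if has_echo else model_text.strip()


if __name__ == "__main__":
    prompt = build_prompt(
        [(USER_ROLE, "Which tables grew fastest last week?"), (ASSISTANT_ROLE, "")]
    )
    echoed_output = prompt + "Here is a SQL query you could run ..."  # worker ran with echo=True
    print(strip_echo(echoed_output, prompt_echo_len_char(prompt), has_echo=True))

The reason echo is carried in model_context is that a worker which echoes the prompt (the default) and one that does not can then share the same client-side trimming path, which is what the has_echo guard added to out_parser/base.py does.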