diff --git a/.env.template b/.env.template
index e03650033..272ee2922 100644
--- a/.env.template
+++ b/.env.template
@@ -22,7 +22,8 @@ WEB_SERVER_PORT=7860
 #** LLM MODELS **#
 #*******************************************************************#
 # LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
-LLM_MODEL=vicuna-13b-v1.5
+# LLM_MODEL=vicuna-13b-v1.5
+LLM_MODEL=codellama-13b-sql-sft
 ## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
 ## Of course you can specify your model path according to LLM_MODEL_PATH
 ## In DB-GPT, the priority from high to low to read model path:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e1575ea03..16deee50a 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,6 +78,8 @@ LLM_MODEL_CONFIG = {
     "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
     "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
     "internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+    "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
+    # For test now
     "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 69b159a13..02fbe8aa9 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -319,6 +319,18 @@ class Llama2Adapter(BaseLLMAdaper):
         model.config.pad_token_id = tokenizer.pad_token_id
         return model, tokenizer
 
+class CodeLlamaAdapter(BaseLLMAdaper):
+    """The model adapter for CodeLlama."""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
 
 class BaichuanAdapter(BaseLLMAdaper):
     """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
@@ -420,6 +432,7 @@ register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
 register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
 register_llm_model_adapters(BaichuanAdapter)
 register_llm_model_adapters(WizardLMAdapter)
 register_llm_model_adapters(LlamaCppAdapater)
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index b3674e946..98dfc720d 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -339,6 +339,28 @@ register_conv_template(
     )
 )
 
+
+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference 2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+    Conversation(
+        name="codellama",
+        system="[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, "
+        "you need only to return the sql command to me. Below is an instruction that describes a task, "
+        "Write a response that appropriately completes the request. "
+        "If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_token_ids=[2],
+        system_formatter=lambda msg: f"[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+    )
+)
+
 
 # Alpaca default template
 register_conv_template(
     Conversation(
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 1580e8863..112fb468a 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -45,6 +45,7 @@ _OLD_MODELS = [
     "llama-cpp",
     "proxyllm",
     "gptj-6b",
+    "codellama-13b-sql-sft",
 ]
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index cb486021b..509305247 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -213,6 +213,15 @@ class Llama2ChatAdapter(BaseChatAdpter):
     def get_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("llama-2")
 
+
+class CodeLlamaChatAdapter(BaseChatAdpter):
+    """The model chat adapter for CodeLlama."""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("codellama")
 
 
 class BaichuanChatAdapter(BaseChatAdpter):
@@ -268,6 +277,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
 register_llm_model_chat_adapter(LlamaCppChatAdapter)