diff --git a/.env.template b/.env.template
index e03650033..272ee2922 100644
--- a/.env.template
+++ b/.env.template
@@ -22,7 +22,8 @@ WEB_SERVER_PORT=7860
#** LLM MODELS **#
#*******************************************************************#
# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
-LLM_MODEL=vicuna-13b-v1.5
+# LLM_MODEL=vicuna-13b-v1.5
+LLM_MODEL=codellama-13b-sql-sft
## LLM model path, by default, DB-GPT will read the model path from LLM_MODEL_CONFIG based on the LLM_MODEL.
## Of course you can specify your model path according to LLM_MODEL_PATH
## In DB-GPT, the priority from high to low to read model path:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e1575ea03..16deee50a 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -78,6 +78,8 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+ "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
+
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
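
Taken together, the two hunks above wire the new checkpoint into DB-GPT's name-to-path lookup: `LLM_MODEL` in `.env` selects the key, and `LLM_MODEL_CONFIG` maps that key to a directory under `MODEL_PATH`. A minimal sketch of the resolution order described in the `.env` comments (the `resolve_model_path` helper is hypothetical, not DB-GPT's actual code):

```python
import os

MODEL_PATH = os.getenv("MODEL_PATH", "./models")
LLM_MODEL_CONFIG = {
    "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
}


def resolve_model_path(llm_model: str) -> str:
    # An explicit LLM_MODEL_PATH wins; otherwise fall back to the
    # LLM_MODEL_CONFIG lookup keyed by LLM_MODEL.
    return os.getenv("LLM_MODEL_PATH") or LLM_MODEL_CONFIG[llm_model]


print(resolve_model_path(os.getenv("LLM_MODEL", "codellama-13b-sql-sft")))
# ./models/codellama-13b-sql-sft
```

Note that the dict key has to match the `LLM_MODEL` value exactly, which is why both files use the same `codellama-13b-sql-sft` string.
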
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 69b159a13..02fbe8aa9 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -319,6 +319,18 @@ class Llama2Adapter(BaseLLMAdaper):
         model.config.pad_token_id = tokenizer.pad_token_id
         return model, tokenizer
 
+class CodeLlamaAdapter(BaseLLMAdaper):
+    """The model adapter for CodeLlama (e.g., codellama-13b-sql-sft)"""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
 
 class BaichuanAdapter(BaseLLMAdaper):
     """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
@@ -420,6 +432,7 @@ register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
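
`CodeLlamaAdapter` reuses the inherited loader and only re-syncs the special-token ids, mirroring `Llama2Adapter` directly above it. As a rough sketch of what that load path amounts to (the Hugging Face calls below are standard usage assumed to be equivalent to `BaseLLMAdaper.loader`, not copied from it):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_codellama(model_path: str, from_pretrained_kwargs: dict):
    # Assumed equivalent of BaseLLMAdaper.loader: plain Hugging Face loading.
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, **from_pretrained_kwargs)
    # Same post-fix as the adapter: keep generation in sync with the
    # tokenizer so decoding stops on </s> and padding works even if the
    # checkpoint's config.json disagrees with the tokenizer files.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer
```

Since `match()` is a plain substring test on the lower-cased path, the checkpoint directory under `MODEL_PATH` must contain `codellama` for this adapter to be selected.
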
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index b3674e946..98dfc720d 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -339,6 +339,28 @@ register_conv_template(
     )
 )
+
+# CodeLlama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+    Conversation(
+        name="codellama",
+        system="[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, "
+        "you need only to return the sql command to me. Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request. "
+        "If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_token_ids=[2],
+        system_formatter=lambda msg: f"[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+    )
+)
+
 
 # Alpaca default template
 register_conv_template(
     Conversation(
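
Per the Llama-2 convention in the referenced `generation.py`, the system block is folded into the first `[INST]` turn rather than sent as a separate message. A hand-assembled illustration of the single-turn prompt this template is intended to produce (the question is made up; DB-GPT actually builds the string from `sep_style` inside `Conversation.get_prompt`):

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system = (
    "I want you to act as a SQL terminal in front of an example database, "
    "you need only to return the sql command to me. ..."
)
question = "Show the name of every employee in the sales department."

# [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{question} [/INST]
prompt = f"{B_INST} {B_SYS}{system}{E_SYS}{question} {E_INST}"
print(prompt)
```

`stop_token_ids=[2]` is the Llama tokenizer's `</s>` id, so generation halts at the end-of-sequence token.
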
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 1580e8863..112fb468a 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -45,6 +45,7 @@ _OLD_MODELS = [
     "llama-cpp",
     "proxyllm",
     "gptj-6b",
+    "codellama-13b-sql-sft",
 ]
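
Listing the model in `_OLD_MODELS` keeps it on the legacy adapter stack in `pilot/model/adapter.py` instead of the newer `model_adapter` machinery. A sketch of the kind of check this list presumably feeds (the helper name is hypothetical; see `pilot/model/model_adapter.py` for the real dispatch):

```python
_OLD_MODELS = [
    "llama-cpp",
    "proxyllm",
    "gptj-6b",
    "codellama-13b-sql-sft",
]


def _use_legacy_adapter(model_name: str) -> bool:
    # Hypothetical helper: substring match against the legacy list, in the
    # same spirit as the adapters' own match() methods.
    return any(m in model_name.lower() for m in _OLD_MODELS)
```
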
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index cb486021b..509305247 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -213,6 +213,15 @@ class Llama2ChatAdapter(BaseChatAdpter):
     def get_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("llama-2")
+
+
+class CodeLlamaChatAdapter(BaseChatAdpter):
+    """The model chat adapter for CodeLlama (e.g., codellama-13b-sql-sft)"""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("codellama")
 
 
 class BaichuanChatAdapter(BaseChatAdpter):
@@ -268,6 +277,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
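
Both `register_llm_model_adapters` and `register_llm_model_chat_adapter` feed simple first-match registries: the first registered adapter whose `match()` accepts the model path is used. A simplified sketch of that pattern (condensed from the shape of `pilot/server/chat_adapter.py`, not a verbatim copy):

```python
llm_model_chat_adapters = []


def register_llm_model_chat_adapter(cls):
    llm_model_chat_adapters.append(cls())


def get_llm_chat_adapter(model_path: str):
    # First match wins, so CodeLlamaChatAdapter must be registered before
    # any broader adapter that would also match the path.
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model for chat adapter: {model_path}")
```

With the registrations above, a model path containing `codellama` resolves to `CodeLlamaChatAdapter`, which in turn serves the `codellama` conversation template.
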