feat(model): Support the DB-GPT-Hub trained model (#760)

Closes #745. Adds support for models trained with DB-GPT-Hub (the codellama series). The change registers the codellama checkpoints in the model config, adds a model adapter and a chat adapter for them, and registers a codellama conversation template.

Repository: https://github.com/csunny/DB-GPT.git
Commit: 852cf673bb
```diff
@@ -78,6 +78,10 @@ LLM_MODEL_CONFIG = {
     "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
     "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
     "internlm-20b": os.path.join(MODEL_PATH, "internlm-chat-20b"),
+    "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
+    "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
+    "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
+    "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
     # For test now
     "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
 }
```
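The registry is a plain dict from a short model name to a local checkpoint directory. A minimal sketch of how such a mapping is typically consumed (`get_model_path` is a hypothetical helper for illustration, not DB-GPT's API):

```python
import os

MODEL_PATH = os.getenv("MODEL_PATH", "./models")

# Same shape as the registry above: short name -> local checkpoint directory.
LLM_MODEL_CONFIG = {
    "codellama-7b": os.path.join(MODEL_PATH, "CodeLlama-7b-Instruct-hf"),
    "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
}


def get_model_path(name: str) -> str:
    """Hypothetical lookup helper: resolve a configured model name to its path."""
    try:
        return LLM_MODEL_CONFIG[name]
    except KeyError:
        raise ValueError(f"Unknown LLM_MODEL name: {name!r}") from None
```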
```diff
@@ -320,6 +320,19 @@ class Llama2Adapter(BaseLLMAdaper):
         return model, tokenizer
 
 
+class CodeLlamaAdapter(BaseLLMAdaper):
+    """The model adapter for codellama"""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+
 class BaichuanAdapter(BaseLLMAdaper):
     """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-13B-Chat)"""
```
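The only CodeLlama-specific work here is syncing the tokenizer's eos/pad token ids onto `model.config`, so generation stops and pads correctly even when the checkpoint's `config.json` disagrees with the tokenizer. A minimal sketch of what the inherited `loader` plausibly does (a simplified stand-in, not DB-GPT's actual base class; the `transformers` calls themselves are standard):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


class BaseLLMAdaper:
    """Simplified stand-in for DB-GPT's base adapter (assumed shape)."""

    def match(self, model_path: str) -> bool:
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # Load tokenizer and weights from the configured local checkpoint dir.
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(model_path, **from_pretrained_kwargs)
        return model, tokenizer
```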
```diff
@@ -420,6 +433,7 @@ register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
 register_llm_model_adapters(Llama2Adapter)
+register_llm_model_adapters(CodeLlamaAdapter)
 register_llm_model_adapters(BaichuanAdapter)
 register_llm_model_adapters(WizardLMAdapter)
 register_llm_model_adapters(LlamaCppAdapater)
```
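Registration order can matter when adapters are tried in sequence and the first `match()` wins, which is why the new adapter slots into a specific position in the list. A hedged sketch of such a registry (the real `register_llm_model_adapters` lives in DB-GPT; this implementation is an assumption):

```python
llm_model_adapters = []


def register_llm_model_adapters(cls) -> None:
    """Assumed implementation: instantiate and append; earlier registrations are tried first."""
    llm_model_adapters.append(cls())


def get_llm_model_adapter(model_path: str):
    # The first adapter whose match() accepts the path is used for loading.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for model path: {model_path!r}")
```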
```diff
@@ -339,6 +339,27 @@ register_conv_template(
     )
 )
 
 
+# codellama template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+# reference2: https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
+register_conv_template(
+    Conversation(
+        name="codellama",
+        system="<s>[INST] <<SYS>>\nI want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
+        "If you don't know the answer to the request, please don't share false information.\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_token_ids=[2],
+        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+    )
+)
+
+
 # Alpaca default template
 register_conv_template(
     Conversation(
```
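For intuition, with `SeparatorStyle.LLAMA2` the system block is emitted once and each user turn is closed by `[/INST]`, so a single-turn SQL request renders roughly as below. This is an approximation of the fastchat-style rendering, not the exact DB-GPT output; the system text is abbreviated and whitespace details are an assumption:

```python
# Rough single-turn rendering of the "codellama" template above.
system = (
    "<s>[INST] <<SYS>>\n"
    "I want you to act as a SQL terminal in front of an example database, "
    "you need only to return the sql command to me...\n"
    "<</SYS>>\n\n"
)
user_msg = "List the names of all students older than 20."

prompt = system + user_msg + " [/INST]"
print(prompt)
```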
```diff
@@ -45,6 +45,10 @@ _OLD_MODELS = [
     "llama-cpp",
     "proxyllm",
     "gptj-6b",
+    "codellama-13b-sql-sft",
+    "codellama-7b",
+    "codellama-7b-sql-sft",
+    "codellama-13b",
 ]
```
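Adding the codellama names to `_OLD_MODELS` keeps them on the pre-existing adapter path. A hedged sketch of how such a deny-list is typically consulted (the gating function is an assumption for illustration, not DB-GPT's actual code):

```python
_OLD_MODELS = [
    "llama-cpp",
    "proxyllm",
    "gptj-6b",
    "codellama-13b-sql-sft",
    "codellama-7b",
    "codellama-7b-sql-sft",
    "codellama-13b",
]


def use_new_model_adapter(model_name: str) -> bool:
    """Hypothetical gate: names in _OLD_MODELS skip the newer adapter machinery."""
    return model_name not in _OLD_MODELS
```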
```diff
@@ -148,8 +152,12 @@ class LLMModelAdaper:
                 conv.append_message(conv.roles[1], content)
             else:
                 raise ValueError(f"Unknown role: {role}")
 
         if system_messages:
-            conv.set_system_message("".join(system_messages))
+            if isinstance(conv, Conversation):
+                conv.set_system_message("".join(system_messages))
+            else:
+                conv.update_system_message("".join(system_messages))
+
         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
```
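The `isinstance` branch exists because two conversation implementations coexist and spell their system-prompt setter differently; the diff does not name the second type. A minimal sketch of the dispatch in isolation (class shapes are assumptions, method names come from the diff):

```python
class Conversation:
    """Stand-in for the Conversation class checked in the diff."""

    system: str = ""

    def set_system_message(self, msg: str) -> None:
        self.system = msg


class OtherTemplate:
    """Stand-in for the second template flavor (assumed; the diff doesn't name it)."""

    system: str = ""

    def update_system_message(self, msg: str) -> None:
        self.system = msg


def apply_system_messages(conv, system_messages: list) -> None:
    text = "".join(system_messages)
    if isinstance(conv, Conversation):
        conv.set_system_message(text)     # DB-GPT's own Conversation
    else:
        conv.update_system_message(text)  # any other template flavor
```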
```diff
@@ -215,6 +215,16 @@ class Llama2ChatAdapter(BaseChatAdpter):
         return get_conv_template("llama-2")
 
 
+class CodeLlamaChatAdapter(BaseChatAdpter):
+    """The model ChatAdapter for codellama."""
+
+    def match(self, model_path: str):
+        return "codellama" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("codellama")
+
+
 class BaichuanChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "baichuan" in model_path.lower()
```
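End to end, the chat adapter only routes any codellama model path to the `codellama` conversation template registered earlier. An illustrative usage sketch; it assumes the template has been registered as in the earlier hunk and exposes the fastchat-style `append_message()`/`get_prompt()` API (an assumption):

```python
adapter = CodeLlamaChatAdapter()
assert adapter.match("/data/models/codellama-7b-sql-sft")

conv = adapter.get_conv_template("/data/models/codellama-7b-sql-sft")
conv.append_message(conv.roles[0], "Show all orders placed in 2023")
conv.append_message(conv.roles[1], None)  # blank slot for the model's answer
prompt = conv.get_prompt()
```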
```diff
@@ -268,6 +278,7 @@ register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
+register_llm_model_chat_adapter(CodeLlamaChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
 register_llm_model_chat_adapter(LlamaCppChatAdapter)
```

The chat-adapter registry follows the same first-match-wins pattern sketched after the model-adapter registrations above.