diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 803d0fae9..0e1fb3d40 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -82,11 +82,6 @@ LLM_MODEL_CONFIG = {
     "codellama-7b-sql-sft": os.path.join(MODEL_PATH, "codellama-7b-sql-sft"),
     "codellama-13b": os.path.join(MODEL_PATH, "CodeLlama-13b-Instruct-hf"),
     "codellama-13b-sql-sft": os.path.join(MODEL_PATH, "codellama-13b-sql-sft"),
-
-
-
-
-
     # For test now
     "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
 }
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index cb9885d2a..5ce5b2173 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -319,8 +319,9 @@ class Llama2Adapter(BaseLLMAdaper):
         model.config.pad_token_id = tokenizer.pad_token_id
         return model, tokenizer
 
+
 class CodeLlamaAdapter(BaseLLMAdaper):
-    """The model adapter for codellama """
+    """The model adapter for codellama"""
 
     def match(self, model_path: str):
         return "codellama" in model_path.lower()
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index 98dfc720d..5d4309d9f 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -360,7 +360,6 @@ register_conv_template(
 )
 
 
-
 # Alpaca default template
 register_conv_template(
     Conversation(
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index cadb1cebd..e09b868e7 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -48,7 +48,7 @@ _OLD_MODELS = [
     "codellama-13b-sql-sft",
     "codellama-7b",
     "codellama-7b-sql-sft",
-    "codellama-13b"
+    "codellama-13b",
 ]
 
 
@@ -152,8 +152,12 @@ class LLMModelAdaper:
                 conv.append_message(conv.roles[1], content)
             else:
                 raise ValueError(f"Unknown role: {role}")
+
         if system_messages:
-            conv.set_system_message("".join(system_messages))
+            if isinstance(conv, Conversation):
+                conv.set_system_message("".join(system_messages))
+            else:
+                conv.update_system_message("".join(system_messages))
 
         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 4b6dd0eed..64b72739b 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -213,10 +213,11 @@ class Llama2ChatAdapter(BaseChatAdpter):
    def get_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("llama-2")
 
-    
+
 class CodeLlamaChatAdapter(BaseChatAdpter):
     """The model ChatAdapter for codellama ."""
+
     def match(self, model_path: str):
         return "codellama" in model_path.lower()
 
 