diff --git a/pilot/componet.py b/pilot/componet.py
index 705eb1193..d88ad384f 100644
--- a/pilot/componet.py
+++ b/pilot/componet.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
+from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
+from enum import Enum
 import asyncio
 
 # Checking for type hints during runtime
@@ -37,6 +38,10 @@ class LifeCycle:
         pass
 
 
+class ComponetType(str, Enum):
+    WORKER_MANAGER = "dbgpt_worker_manager"
+
+
 class BaseComponet(LifeCycle, ABC):
     """Abstract Base Component class. All custom components should extend this."""
 
@@ -80,11 +85,20 @@ class SystemApp(LifeCycle):
 
     def register_instance(self, instance: T):
        """Register an already initialized component."""
-        self.componets[instance.name] = instance
+        name = instance.name
+        if isinstance(name, ComponetType):
+            name = name.value
+        if name in self.componets:
+            raise RuntimeError(
+                f"Componse name {name} already exists: {self.componets[name]}"
+            )
+        self.componets[name] = instance
         instance.init_app(self)
 
-    def get_componet(self, name: str, componet_type: Type[T]) -> T:
+    def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
         """Retrieve a registered component by its name and type."""
+        if isinstance(name, ComponetType):
+            name = name.value
         component = self.componets.get(name)
         if not component:
             raise ValueError(f"No component found with name {name}")
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index b4632003e..e80f9c8e8 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -69,6 +69,9 @@ LLM_MODEL_CONFIG = {
     # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
     "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
     "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
+    # https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
+    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
+    "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
 }
 
 EMBEDDING_MODEL_CONFIG = {
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 0b93faa40..8fe4a9057 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -411,6 +411,29 @@ class LlamaCppAdapater(BaseLLMAdaper):
         return model, tokenizer
 
 
+class InternLMAdapter(BaseLLMAdaper):
+    """The model adapter for internlm/internlm-chat-7b"""
+
+    def match(self, model_path: str):
+        return "internlm" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        model = model.eval()
+        if "8k" in model_path.lower():
+            model.config.max_sequence_length = 8192
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True, revision=revision
+        )
+        return model, tokenizer
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
@@ -421,6 +444,7 @@ register_llm_model_adapters(Llama2Adapter)
 register_llm_model_adapters(BaichuanAdapter)
 register_llm_model_adapters(WizardLMAdapter)
 register_llm_model_adapters(LlamaCppAdapater)
+register_llm_model_adapters(InternLMAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test_py, remove this later
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index f3b85842f..c8cb2a74b 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -2,6 +2,8 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 
 Conversation prompt templates.
+
+TODO Using fastchat core package
 """
 
 import dataclasses
@@ -366,4 +368,21 @@ register_conv_template(
     )
 )
 
+# Internlm-chat template
+register_conv_template(
+    Conversation(
+        name="internlm-chat",
+        system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
+        roles=("<|User|>", "<|Bot|>"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.CHATINTERN,
+        sep="<eoh>",
+        sep2="<eoa>",
+        stop_token_ids=[1, 103028],
+        stop_str="<eoa>",
+    )
+)
+
+
 # TODO Support other model conversation template
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 466a96c0f..80f22effe 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -247,6 +247,16 @@ class LlamaCppChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+class InternLMChatAdapter(BaseChatAdpter):
+    """The model adapter for internlm/internlm-chat-7b"""
+
+    def match(self, model_path: str):
+        return "internlm" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("internlm-chat")
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
@@ -257,6 +267,7 @@ register_llm_model_chat_adapter(Llama2ChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
 register_llm_model_chat_adapter(LlamaCppChatAdapter)
+register_llm_model_chat_adapter(InternLMChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
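
For context, a minimal sketch of how the ComponetType-aware registry added in pilot/componet.py might be exercised. Only register_instance, get_componet, and the WORKER_MANAGER enum member are shown in the diff above; DummyWorkerManager, its class-level name attribute, the no-argument SystemApp() constructor, and init_app being the only override required are assumptions for illustration.

from pilot.componet import BaseComponet, ComponetType, SystemApp


class DummyWorkerManager(BaseComponet):
    # Hypothetical component for illustration only; its name is the enum member,
    # which register_instance/get_componet normalize to "dbgpt_worker_manager".
    name = ComponetType.WORKER_MANAGER

    def init_app(self, system_app: SystemApp):
        # Called back by SystemApp.register_instance(); nothing to wire up here.
        pass


system_app = SystemApp()  # assumes a no-argument constructor
system_app.register_instance(DummyWorkerManager())

# Lookup accepts either the enum member or its plain string value.
manager = system_app.get_componet(ComponetType.WORKER_MANAGER, DummyWorkerManager)
same = system_app.get_componet("dbgpt_worker_manager", DummyWorkerManager)

# Registering a second instance under the same name now raises RuntimeError:
# system_app.register_instance(DummyWorkerManager())  # would raise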