feat(model): support InternLM

This commit is contained in:
FangYin Cheng 2023-09-14 11:32:29 +08:00
parent 6555b6701b
commit 7b64c03d58
5 changed files with 74 additions and 3 deletions

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import asyncio
# Checking for type hints during runtime
@ -37,6 +38,10 @@ class LifeCycle:
pass
class ComponetType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
class BaseComponet(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
@ -80,11 +85,20 @@ class SystemApp(LifeCycle):
def register_instance(self, instance: T):
"""Register an already initialized component."""
self.componets[instance.name] = instance
name = instance.name
if isinstance(name, ComponetType):
name = name.value
if name in self.componets:
raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}"
)
self.componets[name] = instance
instance.init_app(self)
def get_componet(self, name: str, componet_type: Type[T]) -> T:
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
"""Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType):
name = name.value
component = self.componets.get(name)
if not component:
raise ValueError(f"No component found with name {name}")

View File

@ -69,6 +69,9 @@ LLM_MODEL_CONFIG = {
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
}
EMBEDDING_MODEL_CONFIG = {

View File

@ -411,6 +411,29 @@ class LlamaCppAdapater(BaseLLMAdaper):
return model, tokenizer
class InternLMAdapter(BaseLLMAdaper):
"""The model adapter for internlm/internlm-chat-7b"""
def match(self, model_path: str):
return "internlm" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
model = model.eval()
if "8k" in model_path.lower():
model.config.max_sequence_length = 8192
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, trust_remote_code=True, revision=revision
)
return model, tokenizer
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
@ -421,6 +444,7 @@ register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(LlamaCppAdapater)
register_llm_model_adapters(InternLMAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate
# just for test_py, remove this later

View File

@ -2,6 +2,8 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
TODO Using fastchat core package
"""
import dataclasses
@ -366,4 +368,21 @@ register_conv_template(
)
)
# Internlm-chat template
register_conv_template(
Conversation(
name="internlm-chat",
system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
stop_str="<eoa>",
)
)
# TODO Support other model conversation template

View File

@ -247,6 +247,16 @@ class LlamaCppChatAdapter(BaseChatAdpter):
return generate_stream
class InternLMChatAdapter(BaseChatAdpter):
"""The model adapter for internlm/internlm-chat-7b"""
def match(self, model_path: str):
return "internlm" in model_path.lower()
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("internlm-chat")
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
@ -257,6 +267,7 @@ register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
register_llm_model_chat_adapter(LlamaCppChatAdapter)
register_llm_model_chat_adapter(InternLMChatAdapter)
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)