Merge remote-tracking branch 'origin/main' into feat_llm_manage

aries_ckt 2023-09-14 14:40:42 +08:00
commit cb52486e32
9 changed files with 103 additions and 21 deletions

Binary image file changed (not shown): 256 KiB before, 141 KiB after.

View File

@@ -1,13 +1,17 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
+from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
+from enum import Enum
+import logging
 import asyncio

 # Checking for type hints during runtime
 if TYPE_CHECKING:
     from fastapi import FastAPI

+logger = logging.getLogger(__name__)
+

 class LifeCycle:
     """This class defines hooks for lifecycle events of a component."""
@@ -37,6 +41,11 @@ class LifeCycle:
         pass


+class ComponetType(str, Enum):
+    WORKER_MANAGER = "dbgpt_worker_manager"
+    MODEL_CONTROLLER = "dbgpt_model_controller"
+
+
 class BaseComponet(LifeCycle, ABC):
     """Abstract Base Component class. All custom components should extend this."""
@@ -80,11 +89,21 @@ class SystemApp(LifeCycle):

     def register_instance(self, instance: T):
         """Register an already initialized component."""
-        self.componets[instance.name] = instance
+        name = instance.name
+        if isinstance(name, ComponetType):
+            name = name.value
+        if name in self.componets:
+            raise RuntimeError(
+                f"Componse name {name} already exists: {self.componets[name]}"
+            )
+        logger.info(f"Register componet with name {name} and instance: {instance}")
+        self.componets[name] = instance
         instance.init_app(self)

-    def get_componet(self, name: str, componet_type: Type[T]) -> T:
+    def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
         """Retrieve a registered component by its name and type."""
+        if isinstance(name, ComponetType):
+            name = name.value
         component = self.componets.get(name)
         if not component:
             raise ValueError(f"No component found with name {name}")
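For reference, a minimal sketch of how the enum-keyed registry can be used once this lands. It assumes an already constructed SystemApp instance named system_app and a BaseComponet subclass that can be built with no arguments; the names below are illustrative, not part of this diff:

    # Sketch only: illustrative component keyed by the new ComponetType enum.
    class MyComponet(BaseComponet):
        name = ComponetType.MODEL_CONTROLLER

        def init_app(self, system_app: SystemApp):
            pass

    system_app.register_instance(MyComponet())
    # get_componet accepts either the plain string or the ComponetType member,
    # and register_instance now rejects duplicate names with a RuntimeError.
    componet = system_app.get_componet(ComponetType.MODEL_CONTROLLER, MyComponet)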

View File

@@ -69,6 +69,9 @@ LLM_MODEL_CONFIG = {
     # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
     "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
     "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
+    # https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
+    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
+    "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
 }

 EMBEDDING_MODEL_CONFIG = {
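A quick illustration of what the new keys resolve to; the actual MODEL_PATH value comes from the project configuration, so the paths below are only indicative:

    # Sketch: local weight directories for the new entries.
    LLM_MODEL_CONFIG["internlm-7b"]     # -> <MODEL_PATH>/internlm-chat-7b-v1_1
    LLM_MODEL_CONFIG["internlm-7b-8k"]  # -> <MODEL_PATH>/internlm-chat-7b-8k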

View File

@@ -411,6 +411,29 @@ class LlamaCppAdapater(BaseLLMAdaper):
         return model, tokenizer


+class InternLMAdapter(BaseLLMAdaper):
+    """The model adapter for internlm/internlm-chat-7b"""
+
+    def match(self, model_path: str):
+        return "internlm" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        model = model.eval()
+        if "8k" in model_path.lower():
+            model.config.max_sequence_length = 8192
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True, revision=revision
+        )
+        return model, tokenizer
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)

@@ -421,6 +444,7 @@ register_llm_model_adapters(Llama2Adapter)
 register_llm_model_adapters(BaichuanAdapter)
 register_llm_model_adapters(WizardLMAdapter)
 register_llm_model_adapters(LlamaCppAdapater)
+register_llm_model_adapters(InternLMAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test_py, remove this later
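A rough usage sketch for the new adapter. The model path is hypothetical, and in the real flow the adapter is chosen from the registered list via match() rather than instantiated directly:

    # Sketch only: exercising InternLMAdapter with a hypothetical local path.
    adapter = InternLMAdapter()
    assert adapter.match("/data/models/internlm-chat-7b-8k")
    # loader() reads an optional "revision" and, for the 8k variant,
    # bumps model.config.max_sequence_length to 8192.
    model, tokenizer = adapter.loader(
        "/data/models/internlm-chat-7b-8k", {"revision": "main"}
    )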

View File

@@ -4,6 +4,7 @@ import logging
 from typing import List

 from fastapi import APIRouter, FastAPI
+from pilot.componet import BaseComponet, ComponetType, SystemApp
 from pilot.model.base import ModelInstance
 from pilot.model.parameter import ModelControllerParameters
 from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry

@@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
 )


-class BaseModelController(ABC):
+class BaseModelController(BaseComponet, ABC):
+    name = ComponetType.MODEL_CONTROLLER
+
+    def init_app(self, system_app: SystemApp):
+        pass
+
     @abstractmethod
     async def register_instance(self, instance: ModelInstance) -> bool:
         """Register a given model instance"""

@@ -25,7 +31,7 @@ class BaseModelController(ABC):
     @abstractmethod
     async def get_all_instances(
-        self, model_name: str, healthy_only: bool = False
+        self, model_name: str = None, healthy_only: bool = False
     ) -> List[ModelInstance]:
         """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""

@@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
         return await self.registry.deregister_instance(instance)

     async def get_all_instances(
-        self, model_name: str, healthy_only: bool = False
+        self, model_name: str = None, healthy_only: bool = False
     ) -> List[ModelInstance]:
         logging.info(
             f"Get all instances with {model_name}, healthy_only: {healthy_only}"

@@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
     @sync_api_remote(path="/api/controller/models")
     def sync_get_all_instances(
-        self, model_name: str, healthy_only: bool = False
+        self, model_name: str = None, healthy_only: bool = False
     ) -> List[ModelInstance]:
         pass

@@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
         return await self.backend.deregister_instance(instance)

     async def get_all_instances(
-        self, model_name: str, healthy_only: bool = False
+        self, model_name: str = None, healthy_only: bool = False
     ) -> List[ModelInstance]:
         return await self.backend.get_all_instances(model_name, healthy_only)
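With the new default, callers can list every registered instance without naming a model. A short sketch, to be read inside an async function; the controller variable stands for any BaseModelController implementation, and the model name shown is illustrative:

    # Sketch: model_name is now optional on get_all_instances.
    instances = await controller.get_all_instances("vicuna-13b@llm", healthy_only=True)
    all_instances = await controller.get_all_instances(healthy_only=True)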

View File

@@ -2,6 +2,8 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

 Conversation prompt templates.
+
+TODO Using fastchat core package
 """

 import dataclasses

@@ -366,4 +368,21 @@ register_conv_template(
     )
 )

+# Internlm-chat template
+register_conv_template(
+    Conversation(
+        name="internlm-chat",
+        system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
+        roles=("<|User|>", "<|Bot|>"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.CHATINTERN,
+        sep="<eoh>",
+        sep2="<eoa>",
+        stop_token_ids=[1, 103028],
+        stop_str="<eoa>",
+    )
+)
+
 # TODO Support other model conversation template
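For reference, a minimal prompt-rendering sketch built on this template. It assumes the fastchat-style Conversation API (append_message / get_prompt) that this module is forked from; the user message is illustrative:

    # Sketch, assuming the fastchat-style Conversation API kept by this fork.
    conv = get_conv_template("internlm-chat")
    conv.append_message(conv.roles[0], "Introduce DB-GPT in one sentence.")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()  # rendered with CHATINTERN separators <eoh> / <eoa>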

View File

@@ -18,6 +18,7 @@ from fastapi.exceptions import RequestValidationError
 from typing import List
 import tempfile

+from pilot.componet import ComponetType
 from pilot.openapi.api_view_model import (
     Result,
     ConversationVo,

@@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()):
 async def model_types(request: Request):
     print(f"/controller/model/types")
     try:
-        import httpx
-
-        async with httpx.AsyncClient() as client:
-            base_url = request.base_url
-            response = await client.get(
-                f"{base_url}api/controller/models?healthy_only=true",
-            )
         types = set()
-        if response.status_code == 200:
-            models = json.loads(response.text)
-            for model in models:
-                worker_type = model["model_name"].split("@")[1]
-                if worker_type == "llm":
-                    types.add(model["model_name"].split("@")[0])
+        from pilot.model.cluster.controller.controller import BaseModelController
+
+        controller = CFG.SYSTEM_APP.get_componet(
+            ComponetType.MODEL_CONTROLLER, BaseModelController
+        )
+        models = await controller.get_all_instances(healthy_only=True)
+        for model in models:
+            worker_name, worker_type = model.model_name.split("@")
+            if worker_type == "llm":
+                types.add(worker_name)
         return Result.succ(list(types))
     except Exception as e:
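The endpoint now reads instance names from the in-process controller instead of calling the controller's HTTP API back through httpx. The naming convention it relies on is sketched below; the model name is illustrative:

    # Sketch: how a registered instance name is split by the new code.
    worker_name, worker_type = "vicuna-13b@llm".split("@")
    # worker_name == "vicuna-13b", worker_type == "llm" -> listed under /model/types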

View File

@@ -247,6 +247,16 @@ class LlamaCppChatAdapter(BaseChatAdpter):
         return generate_stream


+class InternLMChatAdapter(BaseChatAdpter):
+    """The model adapter for internlm/internlm-chat-7b"""
+
+    def match(self, model_path: str):
+        return "internlm" in model_path.lower()
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("internlm-chat")
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)

@@ -257,6 +267,7 @@ register_llm_model_chat_adapter(Llama2ChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
 register_llm_model_chat_adapter(LlamaCppChatAdapter)
+register_llm_model_chat_adapter(InternLMChatAdapter)
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)

View File

@@ -9,10 +9,12 @@ if TYPE_CHECKING:

 def initialize_componets(system_app: SystemApp, embedding_model_name: str):
     from pilot.model.cluster import worker_manager
+    from pilot.model.cluster.controller.controller import controller

     system_app.register(
         RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
     )
+    system_app.register_instance(controller)


 class RemoteEmbeddingFactory(EmbeddingFactory):