mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 10:29:36 +00:00
Merge remote-tracking branch 'origin/main' into feat_llm_manage
This commit is contained in:
commit
cb52486e32
Binary file not shown.
Before Width: | Height: | Size: 256 KiB After Width: | Height: | Size: 141 KiB |
@ -1,13 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Type, Dict, TypeVar, Optional, TYPE_CHECKING
|
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# Checking for type hints during runtime
|
# Checking for type hints during runtime
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LifeCycle:
|
class LifeCycle:
|
||||||
"""This class defines hooks for lifecycle events of a component."""
|
"""This class defines hooks for lifecycle events of a component."""
|
||||||
@ -37,6 +41,11 @@ class LifeCycle:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ComponetType(str, Enum):
|
||||||
|
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||||
|
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||||
|
|
||||||
|
|
||||||
class BaseComponet(LifeCycle, ABC):
|
class BaseComponet(LifeCycle, ABC):
|
||||||
"""Abstract Base Component class. All custom components should extend this."""
|
"""Abstract Base Component class. All custom components should extend this."""
|
||||||
|
|
||||||
@ -80,11 +89,21 @@ class SystemApp(LifeCycle):
|
|||||||
|
|
||||||
def register_instance(self, instance: T):
|
def register_instance(self, instance: T):
|
||||||
"""Register an already initialized component."""
|
"""Register an already initialized component."""
|
||||||
self.componets[instance.name] = instance
|
name = instance.name
|
||||||
|
if isinstance(name, ComponetType):
|
||||||
|
name = name.value
|
||||||
|
if name in self.componets:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Componse name {name} already exists: {self.componets[name]}"
|
||||||
|
)
|
||||||
|
logger.info(f"Register componet with name {name} and instance: {instance}")
|
||||||
|
self.componets[name] = instance
|
||||||
instance.init_app(self)
|
instance.init_app(self)
|
||||||
|
|
||||||
def get_componet(self, name: str, componet_type: Type[T]) -> T:
|
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
|
||||||
"""Retrieve a registered component by its name and type."""
|
"""Retrieve a registered component by its name and type."""
|
||||||
|
if isinstance(name, ComponetType):
|
||||||
|
name = name.value
|
||||||
component = self.componets.get(name)
|
component = self.componets.get(name)
|
||||||
if not component:
|
if not component:
|
||||||
raise ValueError(f"No component found with name {name}")
|
raise ValueError(f"No component found with name {name}")
|
||||||
|
@ -69,6 +69,9 @@ LLM_MODEL_CONFIG = {
|
|||||||
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
||||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||||
|
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
|
||||||
|
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
|
||||||
|
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBEDDING_MODEL_CONFIG = {
|
EMBEDDING_MODEL_CONFIG = {
|
||||||
|
@ -411,6 +411,29 @@ class LlamaCppAdapater(BaseLLMAdaper):
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMAdapter(BaseLLMAdaper):
|
||||||
|
"""The model adapter for internlm/internlm-chat-7b"""
|
||||||
|
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "internlm" in model_path.lower()
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
|
revision = from_pretrained_kwargs.get("revision", "main")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
**from_pretrained_kwargs,
|
||||||
|
)
|
||||||
|
model = model.eval()
|
||||||
|
if "8k" in model_path.lower():
|
||||||
|
model.config.max_sequence_length = 8192
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_path, use_fast=False, trust_remote_code=True, revision=revision
|
||||||
|
)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
register_llm_model_adapters(VicunaLLMAdapater)
|
register_llm_model_adapters(VicunaLLMAdapater)
|
||||||
register_llm_model_adapters(ChatGLMAdapater)
|
register_llm_model_adapters(ChatGLMAdapater)
|
||||||
register_llm_model_adapters(GuanacoAdapter)
|
register_llm_model_adapters(GuanacoAdapter)
|
||||||
@ -421,6 +444,7 @@ register_llm_model_adapters(Llama2Adapter)
|
|||||||
register_llm_model_adapters(BaichuanAdapter)
|
register_llm_model_adapters(BaichuanAdapter)
|
||||||
register_llm_model_adapters(WizardLMAdapter)
|
register_llm_model_adapters(WizardLMAdapter)
|
||||||
register_llm_model_adapters(LlamaCppAdapater)
|
register_llm_model_adapters(LlamaCppAdapater)
|
||||||
|
register_llm_model_adapters(InternLMAdapter)
|
||||||
# TODO Default support vicuna, other model need to tests and Evaluate
|
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||||
|
|
||||||
# just for test_py, remove this later
|
# just for test_py, remove this later
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
|
from pilot.componet import BaseComponet, ComponetType, SystemApp
|
||||||
from pilot.model.base import ModelInstance
|
from pilot.model.base import ModelInstance
|
||||||
from pilot.model.parameter import ModelControllerParameters
|
from pilot.model.parameter import ModelControllerParameters
|
||||||
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||||
@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseModelController(ABC):
|
class BaseModelController(BaseComponet, ABC):
|
||||||
|
name = ComponetType.MODEL_CONTROLLER
|
||||||
|
|
||||||
|
def init_app(self, system_app: SystemApp):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||||
"""Register a given model instance"""
|
"""Register a given model instance"""
|
||||||
@ -25,7 +31,7 @@ class BaseModelController(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_all_instances(
|
async def get_all_instances(
|
||||||
self, model_name: str, healthy_only: bool = False
|
self, model_name: str = None, healthy_only: bool = False
|
||||||
) -> List[ModelInstance]:
|
) -> List[ModelInstance]:
|
||||||
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
||||||
|
|
||||||
@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
|
|||||||
return await self.registry.deregister_instance(instance)
|
return await self.registry.deregister_instance(instance)
|
||||||
|
|
||||||
async def get_all_instances(
|
async def get_all_instances(
|
||||||
self, model_name: str, healthy_only: bool = False
|
self, model_name: str = None, healthy_only: bool = False
|
||||||
) -> List[ModelInstance]:
|
) -> List[ModelInstance]:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||||
@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
|
|||||||
|
|
||||||
@sync_api_remote(path="/api/controller/models")
|
@sync_api_remote(path="/api/controller/models")
|
||||||
def sync_get_all_instances(
|
def sync_get_all_instances(
|
||||||
self, model_name: str, healthy_only: bool = False
|
self, model_name: str = None, healthy_only: bool = False
|
||||||
) -> List[ModelInstance]:
|
) -> List[ModelInstance]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
|
|||||||
return await self.backend.deregister_instance(instance)
|
return await self.backend.deregister_instance(instance)
|
||||||
|
|
||||||
async def get_all_instances(
|
async def get_all_instances(
|
||||||
self, model_name: str, healthy_only: bool = False
|
self, model_name: str = None, healthy_only: bool = False
|
||||||
) -> List[ModelInstance]:
|
) -> List[ModelInstance]:
|
||||||
return await self.backend.get_all_instances(model_name, healthy_only)
|
return await self.backend.get_all_instances(model_name, healthy_only)
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
|
|
||||||
Conversation prompt templates.
|
Conversation prompt templates.
|
||||||
|
|
||||||
|
TODO Using fastchat core package
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
@ -366,4 +368,21 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Internlm-chat template
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="internlm-chat",
|
||||||
|
system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
|
||||||
|
roles=("<|User|>", "<|Bot|>"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.CHATINTERN,
|
||||||
|
sep="<eoh>",
|
||||||
|
sep2="<eoa>",
|
||||||
|
stop_token_ids=[1, 103028],
|
||||||
|
stop_str="<eoa>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO Support other model conversation template
|
# TODO Support other model conversation template
|
||||||
|
@ -18,6 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
|||||||
from typing import List
|
from typing import List
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
from pilot.componet import ComponetType
|
||||||
from pilot.openapi.api_view_model import (
|
from pilot.openapi.api_view_model import (
|
||||||
Result,
|
Result,
|
||||||
ConversationVo,
|
ConversationVo,
|
||||||
@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
async def model_types(request: Request):
|
async def model_types(request: Request):
|
||||||
print(f"/controller/model/types")
|
print(f"/controller/model/types")
|
||||||
try:
|
try:
|
||||||
import httpx
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
base_url = request.base_url
|
|
||||||
response = await client.get(
|
|
||||||
f"{base_url}api/controller/models?healthy_only=true",
|
|
||||||
)
|
|
||||||
types = set()
|
types = set()
|
||||||
if response.status_code == 200:
|
from pilot.model.cluster.controller.controller import BaseModelController
|
||||||
models = json.loads(response.text)
|
|
||||||
for model in models:
|
controller = CFG.SYSTEM_APP.get_componet(
|
||||||
worker_type = model["model_name"].split("@")[1]
|
ComponetType.MODEL_CONTROLLER, BaseModelController
|
||||||
if worker_type == "llm":
|
)
|
||||||
types.add(model["model_name"].split("@")[0])
|
models = await controller.get_all_instances(healthy_only=True)
|
||||||
|
for model in models:
|
||||||
|
worker_name, worker_type = model.model_name.split("@")
|
||||||
|
if worker_type == "llm":
|
||||||
|
types.add(worker_name)
|
||||||
return Result.succ(list(types))
|
return Result.succ(list(types))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -247,6 +247,16 @@ class LlamaCppChatAdapter(BaseChatAdpter):
|
|||||||
return generate_stream
|
return generate_stream
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMChatAdapter(BaseChatAdpter):
|
||||||
|
"""The model adapter for internlm/internlm-chat-7b"""
|
||||||
|
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "internlm" in model_path.lower()
|
||||||
|
|
||||||
|
def get_conv_template(self, model_path: str) -> Conversation:
|
||||||
|
return get_conv_template("internlm-chat")
|
||||||
|
|
||||||
|
|
||||||
register_llm_model_chat_adapter(VicunaChatAdapter)
|
register_llm_model_chat_adapter(VicunaChatAdapter)
|
||||||
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
register_llm_model_chat_adapter(ChatGLMChatAdapter)
|
||||||
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
register_llm_model_chat_adapter(GuanacoChatAdapter)
|
||||||
@ -257,6 +267,7 @@ register_llm_model_chat_adapter(Llama2ChatAdapter)
|
|||||||
register_llm_model_chat_adapter(BaichuanChatAdapter)
|
register_llm_model_chat_adapter(BaichuanChatAdapter)
|
||||||
register_llm_model_chat_adapter(WizardLMChatAdapter)
|
register_llm_model_chat_adapter(WizardLMChatAdapter)
|
||||||
register_llm_model_chat_adapter(LlamaCppChatAdapter)
|
register_llm_model_chat_adapter(LlamaCppChatAdapter)
|
||||||
|
register_llm_model_chat_adapter(InternLMChatAdapter)
|
||||||
|
|
||||||
# Proxy model for test and develop, it's cheap for us now.
|
# Proxy model for test and develop, it's cheap for us now.
|
||||||
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
register_llm_model_chat_adapter(ProxyllmChatAdapter)
|
||||||
|
@ -9,10 +9,12 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
|
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
|
||||||
from pilot.model.cluster import worker_manager
|
from pilot.model.cluster import worker_manager
|
||||||
|
from pilot.model.cluster.controller.controller import controller
|
||||||
|
|
||||||
system_app.register(
|
system_app.register(
|
||||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
||||||
)
|
)
|
||||||
|
system_app.register_instance(controller)
|
||||||
|
|
||||||
|
|
||||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||||
|
Loading…
Reference in New Issue
Block a user