mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 19:04:24 +00:00
fix(model): Fix the bug that the webserver cannot return model instances
This commit is contained in:
parent
7b64c03d58
commit
f304f9709d
Binary file not shown.
Before Width: | Height: | Size: 256 KiB After Width: | Height: | Size: 141 KiB |
@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
|
||||
from enum import Enum
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
# Checking for type hints during runtime
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LifeCycle:
|
||||
"""This class defines hooks for lifecycle events of a component."""
|
||||
@ -40,6 +43,7 @@ class LifeCycle:
|
||||
|
||||
class ComponetType(str, Enum):
|
||||
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||
|
||||
|
||||
class BaseComponet(LifeCycle, ABC):
|
||||
@ -92,6 +96,7 @@ class SystemApp(LifeCycle):
|
||||
raise RuntimeError(
|
||||
f"Componse name {name} already exists: {self.componets[name]}"
|
||||
)
|
||||
logger.info(f"Register componet with name {name} and instance: {instance}")
|
||||
self.componets[name] = instance
|
||||
instance.init_app(self)
|
||||
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from pilot.componet import BaseComponet, ComponetType, SystemApp
|
||||
from pilot.model.base import ModelInstance
|
||||
from pilot.model.parameter import ModelControllerParameters
|
||||
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||
@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
|
||||
)
|
||||
|
||||
|
||||
class BaseModelController(ABC):
|
||||
class BaseModelController(BaseComponet, ABC):
|
||||
name = ComponetType.MODEL_CONTROLLER
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def register_instance(self, instance: ModelInstance) -> bool:
|
||||
"""Register a given model instance"""
|
||||
@ -25,7 +31,7 @@ class BaseModelController(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
|
||||
|
||||
@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
|
||||
return await self.registry.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
logging.info(
|
||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||
@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
|
||||
|
||||
@sync_api_remote(path="/api/controller/models")
|
||||
def sync_get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
pass
|
||||
|
||||
@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
|
||||
return await self.backend.deregister_instance(instance)
|
||||
|
||||
async def get_all_instances(
|
||||
self, model_name: str, healthy_only: bool = False
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
return await self.backend.get_all_instances(model_name, healthy_only)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from typing import List
|
||||
import tempfile
|
||||
|
||||
from pilot.componet import ComponetType
|
||||
from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
async def model_types(request: Request):
|
||||
print(f"/controller/model/types")
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
base_url = request.base_url
|
||||
response = await client.get(
|
||||
f"{base_url}api/controller/models?healthy_only=true",
|
||||
)
|
||||
types = set()
|
||||
if response.status_code == 200:
|
||||
models = json.loads(response.text)
|
||||
for model in models:
|
||||
worker_type = model["model_name"].split("@")[1]
|
||||
if worker_type == "llm":
|
||||
types.add(model["model_name"].split("@")[0])
|
||||
from pilot.model.cluster.controller.controller import BaseModelController
|
||||
|
||||
controller = CFG.SYSTEM_APP.get_componet(
|
||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
||||
)
|
||||
models = await controller.get_all_instances(healthy_only=True)
|
||||
for model in models:
|
||||
worker_name, worker_type = model.model_name.split("@")
|
||||
if worker_type == "llm":
|
||||
types.add(worker_name)
|
||||
return Result.succ(list(types))
|
||||
|
||||
except Exception as e:
|
||||
|
@ -9,10 +9,12 @@ if TYPE_CHECKING:
|
||||
|
||||
def initialize_componets(system_app: SystemApp, embedding_model_name: str):
|
||||
from pilot.model.cluster import worker_manager
|
||||
from pilot.model.cluster.controller.controller import controller
|
||||
|
||||
system_app.register(
|
||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
||||
)
|
||||
system_app.register_instance(controller)
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
|
Loading…
Reference in New Issue
Block a user