fix(model): Fix the bug that the webserver cannot return model instances

This commit is contained in:
FangYin Cheng 2023-09-14 12:36:18 +08:00
parent 7b64c03d58
commit f304f9709d
5 changed files with 29 additions and 18 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 256 KiB

After

Width:  |  Height:  |  Size: 141 KiB

View File

@ -3,12 +3,15 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum from enum import Enum
import logging
import asyncio import asyncio
# Checking for type hints during runtime # Checking for type hints during runtime
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
logger = logging.getLogger(__name__)
class LifeCycle: class LifeCycle:
"""This class defines hooks for lifecycle events of a component.""" """This class defines hooks for lifecycle events of a component."""
@ -40,6 +43,7 @@ class LifeCycle:
class ComponetType(str, Enum): class ComponetType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER = "dbgpt_worker_manager"
MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponet(LifeCycle, ABC): class BaseComponet(LifeCycle, ABC):
@ -92,6 +96,7 @@ class SystemApp(LifeCycle):
raise RuntimeError( raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}" f"Componse name {name} already exists: {self.componets[name]}"
) )
logger.info(f"Register componet with name {name} and instance: {instance}")
self.componets[name] = instance self.componets[name] = instance
instance.init_app(self) instance.init_app(self)

View File

@ -4,6 +4,7 @@ import logging
from typing import List from typing import List
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.model.base import ModelInstance from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@ -14,7 +15,12 @@ from pilot.utils.api_utils import (
) )
class BaseModelController(ABC): class BaseModelController(BaseComponet, ABC):
name = ComponetType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp):
pass
@abstractmethod @abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool: async def register_instance(self, instance: ModelInstance) -> bool:
"""Register a given model instance""" """Register a given model instance"""
@ -25,7 +31,7 @@ class BaseModelController(ABC):
@abstractmethod @abstractmethod
async def get_all_instances( async def get_all_instances(
self, model_name: str, healthy_only: bool = False self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]: ) -> List[ModelInstance]:
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@ -51,7 +57,7 @@ class LocalModelController(BaseModelController):
return await self.registry.deregister_instance(instance) return await self.registry.deregister_instance(instance)
async def get_all_instances( async def get_all_instances(
self, model_name: str, healthy_only: bool = False self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]: ) -> List[ModelInstance]:
logging.info( logging.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}" f"Get all instances with {model_name}, healthy_only: {healthy_only}"
@ -94,7 +100,7 @@ class ModelRegistryClient(_RemoteModelController, ModelRegistry):
@sync_api_remote(path="/api/controller/models") @sync_api_remote(path="/api/controller/models")
def sync_get_all_instances( def sync_get_all_instances(
self, model_name: str, healthy_only: bool = False self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]: ) -> List[ModelInstance]:
pass pass
@ -110,7 +116,7 @@ class ModelControllerAdapter(BaseModelController):
return await self.backend.deregister_instance(instance) return await self.backend.deregister_instance(instance)
async def get_all_instances( async def get_all_instances(
self, model_name: str, healthy_only: bool = False self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]: ) -> List[ModelInstance]:
return await self.backend.get_all_instances(model_name, healthy_only) return await self.backend.get_all_instances(model_name, healthy_only)

View File

@ -18,6 +18,7 @@ from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
import tempfile import tempfile
from pilot.componet import ComponetType
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
Result, Result,
ConversationVo, ConversationVo,
@ -352,20 +353,17 @@ async def chat_completions(dialogue: ConversationVo = Body()):
async def model_types(request: Request): async def model_types(request: Request):
print(f"/controller/model/types") print(f"/controller/model/types")
try: try:
import httpx
async with httpx.AsyncClient() as client:
base_url = request.base_url
response = await client.get(
f"{base_url}api/controller/models?healthy_only=true",
)
types = set() types = set()
if response.status_code == 200: from pilot.model.cluster.controller.controller import BaseModelController
models = json.loads(response.text)
for model in models: controller = CFG.SYSTEM_APP.get_componet(
worker_type = model["model_name"].split("@")[1] ComponetType.MODEL_CONTROLLER, BaseModelController
if worker_type == "llm": )
types.add(model["model_name"].split("@")[0]) models = await controller.get_all_instances(healthy_only=True)
for model in models:
worker_name, worker_type = model.model_name.split("@")
if worker_type == "llm":
types.add(worker_name)
return Result.succ(list(types)) return Result.succ(list(types))
except Exception as e: except Exception as e:

View File

@ -9,10 +9,12 @@ if TYPE_CHECKING:
def initialize_componets(system_app: SystemApp, embedding_model_name: str): def initialize_componets(system_app: SystemApp, embedding_model_name: str):
from pilot.model.cluster import worker_manager from pilot.model.cluster import worker_manager
from pilot.model.cluster.controller.controller import controller
system_app.register( system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
) )
system_app.register_instance(controller)
class RemoteEmbeddingFactory(EmbeddingFactory): class RemoteEmbeddingFactory(EmbeddingFactory):