diff --git a/pilot/server/llm_manage/api.py b/pilot/server/llm_manage/api.py
index be1957c14..037a711f0 100644
--- a/pilot/server/llm_manage/api.py
+++ b/pilot/server/llm_manage/api.py
@@ -1,11 +1,10 @@
-
 from fastapi import APIRouter
 
 from pilot.component import ComponentType
 from pilot.configs.config import Config
 
 from pilot.model.base import ModelInstance, WorkerApplyType
-from pilot.model.cluster import WorkerStartupRequest
+from pilot.model.cluster import WorkerStartupRequest, WorkerManager
 from pilot.openapi.api_view_model import Result
 
 from pilot.server.llm_manage.request.request import ModelResponse
@@ -14,15 +13,32 @@ CFG = Config()
 router = APIRouter()
 
 
-@router.post("/controller/list")
-async def controller_list(request: ModelInstance):
-    print(f"/controller/list params:")
+# @router.post("/controller/list")
+# async def controller_list(request: ModelInstance):
+#     print(f"/controller/list params:")
+#     try:
+#         CFG.LLM_MODEL = request.model_name
+#         return Result.succ("success")
+#
+#     except Exception as e:
+#         return Result.faild(code="E000X", msg=f"space list error {e}")
+@router.get("/v1/worker/model/params")
+async def model_params():
+    """Return the models reported by the worker manager's ``supported_models()``.
+
+    Any exception raised along the way is reported as a failed ``Result``.
+    """
+    print(f"/worker/model/params")
     try:
-        CFG.LLM_MODEL = request.model_name
-        return Result.succ("success")
-
+        from pilot.model.cluster import WorkerManagerFactory
+        worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        return Result.succ(await worker_manager.supported_models())
     except Exception as e:
-        return Result.faild(code="E000X", msg=f"space list error {e}")
+        return Result.faild(code="E000X", msg=f"model params failed {e}")
+
+
 
 
 @router.get("/v1/worker/model/list")
@@ -73,13 +89,12 @@ async def model_start(request: WorkerStartupRequest):
         controller = CFG.SYSTEM_APP.get_component(
             ComponentType.MODEL_CONTROLLER, BaseModelController
         )
-        instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
+        instances = await controller.get_all_instances(
+            model_name="WorkerManager@service", healthy_only=True
+        )
         worker_instance = None
         for instance in instances:
-            if (
-                instance.host == request.host
-                and instance.port == request.port
-            ):
+            if instance.host == request.host and instance.port == request.port:
                 from pilot.model.cluster import ModelRegistryClient
                 from pilot.model.cluster import RemoteWorkerManager
 
@@ -101,13 +116,12 @@ async def model_start(request: WorkerStartupRequest):
         controller = CFG.SYSTEM_APP.get_component(
             ComponentType.MODEL_CONTROLLER, BaseModelController
         )
-        instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
+        instances = await controller.get_all_instances(
+            model_name="WorkerManager@service", healthy_only=True
+        )
         worker_instance = None
         for instance in instances:
-            if (
-                instance.host == request.host
-                and instance.port == request.port
-            ):
+            if instance.host == request.host and instance.port == request.port:
                 from pilot.model.cluster import ModelRegistryClient
                 from pilot.model.cluster import RemoteWorkerManager
 