From 1356759f48aceba7771911fb03b62566b707760f Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 18 Sep 2023 19:29:37 +0800
Subject: [PATCH] feat(model): llm manage

---
 pilot/server/dbgpt_server.py               |   2 +
 pilot/server/llm_manage/api.py             | 120 +++++++++++++++++++++
 pilot/server/llm_manage/request/request.py |  28 +++++
 3 files changed, 150 insertions(+)
 create mode 100644 pilot/server/llm_manage/api.py
 create mode 100644 pilot/server/llm_manage/request/request.py

diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index d2307f06a..c2a510eaf 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from pilot.server.knowledge.api import router as knowledge_router
+from pilot.server.llm_manage.api import router as llm_manage_api
 
 from pilot.openapi.api_v1.api_v1 import router as api_v1
 
@@ -71,6 +72,7 @@ app.add_middleware(
 app.include_router(api_v1, prefix="/api")
 app.include_router(knowledge_router, prefix="/api")
 app.include_router(api_editor_route_v1, prefix="/api")
+app.include_router(llm_manage_api, prefix="/api")
 
 # app.include_router(api_v1)
 app.include_router(knowledge_router)
diff --git a/pilot/server/llm_manage/api.py b/pilot/server/llm_manage/api.py
new file mode 100644
index 000000000..d68083940
--- /dev/null
+++ b/pilot/server/llm_manage/api.py
@@ -0,0 +1,120 @@
+
+from fastapi import APIRouter
+
+from pilot.componet import ComponetType
+from pilot.configs.config import Config
+from pilot.model.base import ModelInstance, WorkerApplyType
+
+from pilot.model.cluster import WorkerStartupRequest
+from pilot.openapi.api_view_model import Result
+
+from pilot.server.llm_manage.request.request import ModelResponse
+
+CFG = Config()
+router = APIRouter()
+
+
+@router.post("/controller/list")
+async def controller_list(request: ModelInstance):
+    print(f"/controller/list params: {request}")
+    try:
+        CFG.LLM_MODEL = request.model_name
+        return Result.succ("success")
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"controller list error {e}")
+
+
+@router.get("/v1/worker/model/list")
+async def model_list():
+    print("/worker/model/list")
+    try:
+        from pilot.model.cluster.controller.controller import BaseModelController
+
+        controller = CFG.SYSTEM_APP.get_componet(
+            ComponetType.MODEL_CONTROLLER, BaseModelController
+        )
+        responses = []
+        managers = await controller.get_all_instances(
+            model_name="WorkerManager@service", healthy_only=True
+        )
+        manager_map = dict(map(lambda manager: (manager.host, manager), managers))
+        models = await controller.get_all_instances()
+        for model in models:
+            worker_name, worker_type = model.model_name.split("@")
+            if worker_type == "llm" or worker_type == "text2vec":
+                response = ModelResponse(
+                    model_name=worker_name,
+                    model_type=worker_type,
+                    host=model.host,
+                    port=model.port,
+                    healthy=model.healthy,
+                    check_healthy=model.check_healthy,
+                    last_heartbeat=model.last_heartbeat,
+                    prompt_template=model.prompt_template,
+                )
+                manager = manager_map.get(model.host)
+                if manager:
+                    response.manager_host = manager.host
+                    response.manager_port = manager.port
+                responses.append(response)
+        return Result.succ(responses)
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"model list error {e}")
+
+
+@router.post("/v1/worker/model/stop")
+async def model_stop(request: WorkerStartupRequest):
+    print("/v1/worker/model/stop")
+    try:
+        from pilot.model.cluster.controller.controller import BaseModelController
+
+        controller = CFG.SYSTEM_APP.get_componet(
+            ComponetType.MODEL_CONTROLLER, BaseModelController
+        )
+        instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
+        worker_instance = None
+        for instance in instances:
+            if (
+                instance.host == request.host
+                and instance.port == request.port
+            ):
+                from pilot.model.cluster import ModelRegistryClient
+                from pilot.model.cluster import RemoteWorkerManager
+
+                registry = ModelRegistryClient(f"http://{request.host}:{request.port}")
+                worker_manager = RemoteWorkerManager(registry)
+                return Result.succ(await worker_manager.model_shutdown(request))
+        if not worker_instance:
+            return Result.faild(
+                code="E000X", msg=f"cannot find worker manager at {request.host}:{request.port}"
+            )
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"model stop failed {e}")
+
+
+@router.post("/v1/worker/model/start")
+async def model_start(request: WorkerStartupRequest):
+    print("/v1/worker/model/start")
+    try:
+        from pilot.model.cluster.controller.controller import BaseModelController
+
+        controller = CFG.SYSTEM_APP.get_componet(
+            ComponetType.MODEL_CONTROLLER, BaseModelController
+        )
+        instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
+        worker_instance = None
+        for instance in instances:
+            if (
+                instance.host == request.host
+                and instance.port == request.port
+            ):
+                from pilot.model.cluster import ModelRegistryClient
+                from pilot.model.cluster import RemoteWorkerManager
+
+                registry = ModelRegistryClient(f"http://{request.host}:{request.port}")
+                worker_manager = RemoteWorkerManager(registry)
+                return Result.succ(await worker_manager.model_startup(request))
+        if not worker_instance:
+            return Result.faild(
+                code="E000X", msg=f"cannot find worker manager at {request.host}:{request.port}"
+            )
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"model start failed {e}")
diff --git a/pilot/server/llm_manage/request/request.py b/pilot/server/llm_manage/request/request.py
new file mode 100644
index 000000000..e74f6768a
--- /dev/null
+++ b/pilot/server/llm_manage/request/request.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class ModelResponse:
+    """ModelResponse"""
+
+    """model_name: model_name"""
+    model_name: str = None
+    """model_type: model_type"""
+    model_type: str = None
+    """host: host"""
+    host: str = None
+    """port: port"""
+    port: int = None
+    """manager_host: manager_host"""
+    manager_host: str = None
+    """manager_port: manager_port"""
+    manager_port: int = None
+    """healthy: healthy"""
+    healthy: bool = True
+
+    """check_healthy: check_healthy"""
+    check_healthy: bool = True
+    prompt_template: str = None
+    last_heartbeat: str = None
+    stream_api: str = None
+    nostream_api: str = None
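
Usage note: a minimal client sketch for the routes this patch adds, not part of the
patch itself. The routers are mounted under the "/api" prefix in dbgpt_server.py, so
the full paths are /api/v1/worker/model/list, /api/v1/worker/model/stop and
/api/v1/worker/model/start. The base URL and the payload fields beyond host/port
("model", "worker_type") are assumptions about the WorkerStartupRequest schema;
adjust them to the real definition in pilot.model.cluster.

    import requests

    BASE_URL = "http://127.0.0.1:5000"  # assumed dbgpt_server address

    # List the llm/text2vec workers known to the model controller.
    resp = requests.get(f"{BASE_URL}/api/v1/worker/model/list")
    print(resp.json())  # Result wrapper around a list of ModelResponse dicts

    # Stop a model on a specific worker manager; only host/port are read by
    # api.py, the remaining fields are assumed WorkerStartupRequest fields.
    payload = {
        "host": "127.0.0.1",
        "port": 8001,
        "model": "vicuna-13b-v1.5",
        "worker_type": "llm",
    }
    resp = requests.post(f"{BASE_URL}/api/v1/worker/model/stop", json=payload)
    print(resp.json())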