mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat:llm manage
This commit is contained in:
parent
86ad276244
commit
eca313a608
@ -1,11 +1,10 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from pilot.component import ComponentType
|
||||
from pilot.configs.config import Config
|
||||
from pilot.model.base import ModelInstance, WorkerApplyType
|
||||
|
||||
from pilot.model.cluster import WorkerStartupRequest
|
||||
from pilot.model.cluster import WorkerStartupRequest, WorkerManager
|
||||
from pilot.openapi.api_view_model import Result
|
||||
|
||||
from pilot.server.llm_manage.request.request import ModelResponse
|
||||
@ -14,15 +13,30 @@ CFG = Config()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/controller/list")
|
||||
async def controller_list(request: ModelInstance):
|
||||
print(f"/controller/list params:")
|
||||
# @router.post("/controller/list")
|
||||
# async def controller_list(request: ModelInstance):
|
||||
# print(f"/controller/list params:")
|
||||
# try:
|
||||
# CFG.LLM_MODEL = request.model_name
|
||||
# return Result.succ("success")
|
||||
#
|
||||
# except Exception as e:
|
||||
# return Result.faild(code="E000X", msg=f"space list error {e}")
|
||||
@router.get("/v1/worker/model/params")
|
||||
async def model_params():
|
||||
print(f"/worker/model/params")
|
||||
try:
|
||||
CFG.LLM_MODEL = request.model_name
|
||||
return Result.succ("success")
|
||||
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
return Result.succ(await worker_manager.supported_models())
|
||||
if not worker_instance:
|
||||
return Result.faild(code="E000X", msg=f"can not find worker manager")
|
||||
except Exception as e:
|
||||
return Result.faild(code="E000X", msg=f"space list error {e}")
|
||||
return Result.faild(code="E000X", msg=f"model stop failed {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/v1/worker/model/list")
|
||||
@ -73,13 +87,12 @@ async def model_start(request: WorkerStartupRequest):
|
||||
controller = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||
)
|
||||
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
||||
instances = await controller.get_all_instances(
|
||||
model_name="WorkerManager@service", healthy_only=True
|
||||
)
|
||||
worker_instance = None
|
||||
for instance in instances:
|
||||
if (
|
||||
instance.host == request.host
|
||||
and instance.port == request.port
|
||||
):
|
||||
if instance.host == request.host and instance.port == request.port:
|
||||
from pilot.model.cluster import ModelRegistryClient
|
||||
from pilot.model.cluster import RemoteWorkerManager
|
||||
|
||||
@ -101,13 +114,12 @@ async def model_start(request: WorkerStartupRequest):
|
||||
controller = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||
)
|
||||
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
||||
instances = await controller.get_all_instances(
|
||||
model_name="WorkerManager@service", healthy_only=True
|
||||
)
|
||||
worker_instance = None
|
||||
for instance in instances:
|
||||
if (
|
||||
instance.host == request.host
|
||||
and instance.port == request.port
|
||||
):
|
||||
if instance.host == request.host and instance.port == request.port:
|
||||
from pilot.model.cluster import ModelRegistryClient
|
||||
from pilot.model.cluster import RemoteWorkerManager
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user