diff --git a/pilot/model/base.py b/pilot/model/base.py index b6eb9da25..1d46b3161 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -3,7 +3,7 @@ from enum import Enum from typing import TypedDict, Optional, Dict, List -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from pilot.utils.parameter_utils import ParameterDescription @@ -84,3 +84,25 @@ class WorkerSupportedModel: ] worker_data["models"] = models return cls(**worker_data) + + +@dataclass +class FlatSupportedModel(SupportedModel): + """For web""" + + host: str + port: int + + @staticmethod + def from_supports( + supports: List[WorkerSupportedModel], + ) -> List["FlatSupportedModel"]: + results = [] + for s in supports: + host, port, models = s.host, s.port, s.models + for m in models: + kwargs = asdict(m) + kwargs["host"] = host + kwargs["port"] = port + results.append(FlatSupportedModel(**kwargs)) + return results diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py index b73fd7873..9937ffa0b 100644 --- a/pilot/model/cluster/__init__.py +++ b/pilot/model/cluster/__init__.py @@ -5,6 +5,7 @@ from pilot.model.cluster.base import ( WorkerParameterRequest, WorkerStartupRequest, ) +from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker.default_worker import DefaultModelWorker @@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry from pilot.model.cluster.controller.controller import ( ModelRegistryClient, run_model_controller, + BaseModelController, ) from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager @@ -28,6 +30,7 @@ __all__ = [ "WorkerApplyRequest", "WorkerParameterRequest", "WorkerStartupRequest", + "WorkerManagerFactory", "ModelWorker", "DefaultModelWorker", "worker_manager", diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index 9131490f5..690a6afbf 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -172,7 +172,9 @@ def _list_supported_models( llm_adapter = get_llm_model_adapter(model_name, model_path) param_cls = llm_adapter.model_param_class() model.enabled = True - params = _get_parameter_descriptions(param_cls) + params = _get_parameter_descriptions( + param_cls, model_name=model_name, model_path=model_path + ) model.params = params except Exception: pass diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 45d00c13c..8af1528cc 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -3,6 +3,7 @@ import uuid import asyncio import os import shutil +import logging from fastapi import ( APIRouter, Request, @@ -11,6 +12,7 @@ from fastapi import ( Form, Body, BackgroundTasks, + Depends, ) from fastapi.responses import StreamingResponse @@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.summary.db_summary_client import DBSummaryClient +from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory +from pilot.model.base import FlatSupportedModel + router = APIRouter() CFG = Config() CHAT_FACTORY = ChatFactory() -logger = build_logger("api_v1", LOGDIR + "api_v1.log") +logger = logging.getLogger(__name__) knowledge_service = KnowledgeService() model_semaphore = None @@ -90,6 +95,20 @@ def knowledge_list(): return params +def get_model_controller() -> BaseModelController: + controller = CFG.SYSTEM_APP.get_componet( + ComponetType.MODEL_CONTROLLER, BaseModelController + ) + return controller + + +def get_worker_manager() -> WorkerManager: + worker_manager = CFG.SYSTEM_APP.get_componet( + ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + return worker_manager + + @router.get("/v1/chat/db/list", response_model=Result[DBConfig]) async def db_connect_list(): return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) @@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()): @router.get("/v1/model/types") -async def model_types(request: Request): - print(f"/controller/model/types") +async def model_types(controller: BaseModelController = Depends(get_model_controller)): + logger.info(f"/controller/model/types") try: types = set() - from pilot.model.cluster.controller.controller import BaseModelController - - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController - ) models = await controller.get_all_instances(healthy_only=True) for model in models: worker_name, worker_type = model.model_name.split("@") @@ -371,6 +385,16 @@ async def model_types(request: Request): return Result.faild(code="E000X", msg=f"controller model types error {e}") +@router.get("/v1/model/supports") +async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)): + logger.info(f"/controller/model/supports") + try: + models = await worker_manager.supported_models() + return Result.succ(FlatSupportedModel.from_supports(models)) + except Exception as e: + return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}") + + async def no_stream_generator(chat): msg = await chat.nostream_call() msg = msg.replace("\n", "\\n") diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 9e7a22373..a11d5086f 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -6,6 +6,7 @@ from typing import Any, List, Dict from pilot.configs.config import Config from pilot.configs.model_config import LOGDIR +from pilot.componet import ComponetType from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory @@ -142,8 +143,11 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - from pilot.model.cluster import worker_manager + from pilot.model.cluster import WorkerManagerFactory + worker_manager = CFG.SYSTEM_APP.get_componet( + ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() async for output in worker_manager.generate_stream(payload): yield output except Exception as e: @@ -160,7 +164,11 @@ class BaseChat(ABC): logger.info(f"Request: \n{payload}") ai_response_text = "" try: - from pilot.model.cluster import worker_manager + from pilot.model.cluster import WorkerManagerFactory + + worker_manager = CFG.SYSTEM_APP.get_componet( + ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() model_output = await worker_manager.generate(payload) diff --git a/pilot/server/base.py b/pilot/server/base.py index 888ebbf3d..34b48f599 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp): def _create_model_start_listener(system_app: SystemApp): from pilot.connections.manages.connection_manager import ConnectManager - from pilot.model.cluster import worker_manager cfg = Config() diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py index d46b626ca..41b5d2ddb 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/componet_configs.py @@ -1,14 +1,10 @@ from __future__ import annotations -from typing import Any, Type, TYPE_CHECKING - -from pilot.componet import SystemApp import logging -from pilot.configs.model_config import get_device -from pilot.embedding_engine.embedding_factory import ( - EmbeddingFactory, - DefaultEmbeddingFactory, -) +from typing import TYPE_CHECKING, Any, Type + +from pilot.componet import ComponetType, SystemApp +from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.server.base import WebWerverParameters if TYPE_CHECKING: @@ -39,13 +35,9 @@ def _initialize_embedding_model( embedding_model_name: str, embedding_model_path: str, ): - from pilot.model.cluster import worker_manager - if param.remote_embedding: logger.info("Register remote RemoteEmbeddingFactory") - system_app.register( - RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name - ) + system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name) else: logger.info(f"Register local LocalEmbeddingFactory") system_app.register( @@ -56,26 +48,28 @@ def _initialize_embedding_model( class RemoteEmbeddingFactory(EmbeddingFactory): - def __init__( - self, system_app, worker_manager, model_name: str = None, **kwargs: Any - ) -> None: + def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None: super().__init__(system_app=system_app) - self._worker_manager = worker_manager self._default_model_name = model_name self.kwargs = kwargs + self.system_app = system_app def init_app(self, system_app): - pass + self.system_app = system_app def create( self, model_name: str = None, embedding_cls: Type = None ) -> "Embeddings": + from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings if embedding_cls: raise NotImplementedError + worker_manager = self.system_app.get_componet( + ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() # Ignore model_name args - return RemoteEmbeddings(self._default_model_name, self._worker_manager) + return RemoteEmbeddings(self._default_model_name, worker_manager) class LocalEmbeddingFactory(EmbeddingFactory): @@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory): return self._model def _load_model(self) -> "Embeddings": - from pilot.model.parameter import ( - EmbeddingModelParameters, - BaseEmbeddingModelParameters, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, - ) - from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params from pilot.model.cluster.embedding.loader import EmbeddingLoader + from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params + from pilot.model.parameter import ( + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + BaseEmbeddingModelParameters, + EmbeddingModelParameters, + ) param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( self._default_model_name, EmbeddingModelParameters diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py index b67282442..5be747e23 100644 --- a/pilot/utils/parameter_utils.py +++ b/pilot/utils/parameter_utils.py @@ -12,6 +12,7 @@ class ParameterDescription: param_type: str default_value: Optional[Any] description: str + required: Optional[bool] valid_values: Optional[List[Any]] ext_metadata: Dict @@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type: return type_mapping.get(type_str, str) -def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: +def _get_parameter_descriptions( + dataclass_type: Type, **kwargs +) -> List[ParameterDescription]: descriptions = [] for field in fields(dataclass_type): ext_metadata = { k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"] } - + default_value = field.default if field.default != MISSING else None + if field.name in kwargs: + default_value = kwargs[field.name] descriptions.append( ParameterDescription( param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}", param_name=field.name, param_type=EnvArgumentParser._get_argparse_type_str(field.type), description=field.metadata.get("help", None), - default_value=field.default if field.default != MISSING else None, + required=field.default is MISSING, + default_value=default_value, valid_values=field.metadata.get("valid_values", None), ext_metadata=ext_metadata, )