chore(model): new required field for supported models

This commit is contained in:
FangYin Cheng 2023-09-19 09:58:48 +08:00
parent 78553477a9
commit ae34be23fd
8 changed files with 99 additions and 41 deletions

View File

@ -3,7 +3,7 @@
from enum import Enum from enum import Enum
from typing import TypedDict, Optional, Dict, List from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass from dataclasses import dataclass, asdict
from datetime import datetime from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription from pilot.utils.parameter_utils import ParameterDescription
@ -84,3 +84,25 @@ class WorkerSupportedModel:
] ]
worker_data["models"] = models worker_data["models"] = models
return cls(**worker_data) return cls(**worker_data)
@dataclass
class FlatSupportedModel(SupportedModel):
"""For web"""
host: str
port: int
@staticmethod
def from_supports(
supports: List[WorkerSupportedModel],
) -> List["FlatSupportedModel"]:
results = []
for s in supports:
host, port, models = s.host, s.port, s.models
for m in models:
kwargs = asdict(m)
kwargs["host"] = host
kwargs["port"] = port
results.append(FlatSupportedModel(**kwargs))
return results

View File

@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
WorkerParameterRequest, WorkerParameterRequest,
WorkerStartupRequest, WorkerStartupRequest,
) )
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker from pilot.model.cluster.worker.default_worker import DefaultModelWorker
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import ( from pilot.model.cluster.controller.controller import (
ModelRegistryClient, ModelRegistryClient,
run_model_controller, run_model_controller,
BaseModelController,
) )
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -28,6 +30,7 @@ __all__ = [
"WorkerApplyRequest", "WorkerApplyRequest",
"WorkerParameterRequest", "WorkerParameterRequest",
"WorkerStartupRequest", "WorkerStartupRequest",
"WorkerManagerFactory",
"ModelWorker", "ModelWorker",
"DefaultModelWorker", "DefaultModelWorker",
"worker_manager", "worker_manager",

View File

@ -172,7 +172,9 @@ def _list_supported_models(
llm_adapter = get_llm_model_adapter(model_name, model_path) llm_adapter = get_llm_model_adapter(model_name, model_path)
param_cls = llm_adapter.model_param_class() param_cls = llm_adapter.model_param_class()
model.enabled = True model.enabled = True
params = _get_parameter_descriptions(param_cls) params = _get_parameter_descriptions(
param_cls, model_name=model_name, model_path=model_path
)
model.params = params model.params = params
except Exception: except Exception:
pass pass

View File

@ -3,6 +3,7 @@ import uuid
import asyncio import asyncio
import os import os
import shutil import shutil
import logging
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Request, Request,
@ -11,6 +12,7 @@ from fastapi import (
Form, Form,
Body, Body,
BackgroundTasks, BackgroundTasks,
Depends,
) )
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
CHAT_FACTORY = ChatFactory() CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log") logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService() knowledge_service = KnowledgeService()
model_semaphore = None model_semaphore = None
@ -90,6 +95,20 @@ def knowledge_list():
return params return params
def get_model_controller() -> BaseModelController:
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
return controller
def get_worker_manager() -> WorkerManager:
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return worker_manager
@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) @router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list(): async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
@router.get("/v1/model/types") @router.get("/v1/model/types")
async def model_types(request: Request): async def model_types(controller: BaseModelController = Depends(get_model_controller)):
print(f"/controller/model/types") logger.info(f"/controller/model/types")
try: try:
types = set() types = set()
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
models = await controller.get_all_instances(healthy_only=True) models = await controller.get_all_instances(healthy_only=True)
for model in models: for model in models:
worker_name, worker_type = model.model_name.split("@") worker_name, worker_type = model.model_name.split("@")
@ -371,6 +385,16 @@ async def model_types(request: Request):
return Result.faild(code="E000X", msg=f"controller model types error {e}") return Result.faild(code="E000X", msg=f"controller model types error {e}")
@router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
logger.info(f"/controller/model/supports")
try:
models = await worker_manager.supported_models()
return Result.succ(FlatSupportedModel.from_supports(models))
except Exception as e:
return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}")
async def no_stream_generator(chat): async def no_stream_generator(chat):
msg = await chat.nostream_call() msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n") msg = msg.replace("\n", "\\n")

View File

@ -6,6 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.componet import ComponetType
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -142,8 +143,11 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
from pilot.model.cluster import worker_manager from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload): async for output in worker_manager.generate_stream(payload):
yield output yield output
except Exception as e: except Exception as e:
@ -160,7 +164,11 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
from pilot.model.cluster import worker_manager from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload) model_output = await worker_manager.generate(payload)

View File

@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp): def _create_model_start_listener(system_app: SystemApp):
from pilot.connections.manages.connection_manager import ConnectManager from pilot.connections.manages.connection_manager import ConnectManager
from pilot.model.cluster import worker_manager
cfg = Config() cfg = Config()

View File

@ -1,14 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
import logging import logging
from pilot.configs.model_config import get_device from typing import TYPE_CHECKING, Any, Type
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory, from pilot.componet import ComponetType, SystemApp
DefaultEmbeddingFactory, from pilot.embedding_engine.embedding_factory import EmbeddingFactory
)
from pilot.server.base import WebWerverParameters from pilot.server.base import WebWerverParameters
if TYPE_CHECKING: if TYPE_CHECKING:
@ -39,13 +35,9 @@ def _initialize_embedding_model(
embedding_model_name: str, embedding_model_name: str,
embedding_model_path: str, embedding_model_path: str,
): ):
from pilot.model.cluster import worker_manager
if param.remote_embedding: if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory") logger.info("Register remote RemoteEmbeddingFactory")
system_app.register( system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
else: else:
logger.info(f"Register local LocalEmbeddingFactory") logger.info(f"Register local LocalEmbeddingFactory")
system_app.register( system_app.register(
@ -56,26 +48,28 @@ def _initialize_embedding_model(
class RemoteEmbeddingFactory(EmbeddingFactory): class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__( def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app) super().__init__(system_app=system_app)
self._worker_manager = worker_manager
self._default_model_name = model_name self._default_model_name = model_name
self.kwargs = kwargs self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app): def init_app(self, system_app):
pass self.system_app = system_app
def create( def create(
self, model_name: str = None, embedding_cls: Type = None self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings": ) -> "Embeddings":
from pilot.model.cluster import WorkerManagerFactory
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls: if embedding_cls:
raise NotImplementedError raise NotImplementedError
worker_manager = self.system_app.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args # Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager) return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory): class LocalEmbeddingFactory(EmbeddingFactory):
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
return self._model return self._model
def _load_model(self) -> "Embeddings": def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters self._default_model_name, EmbeddingModelParameters

View File

@ -12,6 +12,7 @@ class ParameterDescription:
param_type: str param_type: str
default_value: Optional[Any] default_value: Optional[Any]
description: str description: str
required: Optional[bool]
valid_values: Optional[List[Any]] valid_values: Optional[List[Any]]
ext_metadata: Dict ext_metadata: Dict
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
return type_mapping.get(type_str, str) return type_mapping.get(type_str, str)
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: def _get_parameter_descriptions(
dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
descriptions = [] descriptions = []
for field in fields(dataclass_type): for field in fields(dataclass_type):
ext_metadata = { ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"] k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
} }
default_value = field.default if field.default != MISSING else None
if field.name in kwargs:
default_value = kwargs[field.name]
descriptions.append( descriptions.append(
ParameterDescription( ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}", param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name, param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type), param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None), description=field.metadata.get("help", None),
default_value=field.default if field.default != MISSING else None, required=field.default is MISSING,
default_value=default_value,
valid_values=field.metadata.get("valid_values", None), valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata, ext_metadata=ext_metadata,
) )