chore(model): new required field for supported models

This commit is contained in:
FangYin Cheng 2023-09-19 09:58:48 +08:00
parent 78553477a9
commit ae34be23fd
8 changed files with 99 additions and 41 deletions

View File

@ -3,7 +3,7 @@
from enum import Enum
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@ -84,3 +84,25 @@ class WorkerSupportedModel:
]
worker_data["models"] = models
return cls(**worker_data)
@dataclass
class FlatSupportedModel(SupportedModel):
    """Supported-model record combined with its worker's address (for web).

    Extends ``SupportedModel`` with the ``host`` and ``port`` taken from the
    ``WorkerSupportedModel`` entry the model was listed under.
    """

    host: str
    port: int

    @staticmethod
    def from_supports(
        supports: List[WorkerSupportedModel],
    ) -> List["FlatSupportedModel"]:
        """Flatten per-worker model lists into one list of records.

        Args:
            supports: Worker entries, each exposing ``host``, ``port`` and
                ``models``.

        Returns:
            One ``FlatSupportedModel`` per model of every worker, in input
            order, built from the model's fields plus the worker's address.
        """
        flattened = []
        for worker in supports:
            address = {"host": worker.host, "port": worker.port}
            worker_models = worker.models
            # ``asdict`` also converts nested dataclass values to plain dicts.
            # The address is merged last so it wins over same-named keys.
            flattened.extend(
                FlatSupportedModel(**{**asdict(model), **address})
                for model in worker_models
            )
        return flattened

View File

@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
WorkerParameterRequest,
WorkerStartupRequest,
)
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
run_model_controller,
BaseModelController,
)
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -28,6 +30,7 @@ __all__ = [
"WorkerApplyRequest",
"WorkerParameterRequest",
"WorkerStartupRequest",
"WorkerManagerFactory",
"ModelWorker",
"DefaultModelWorker",
"worker_manager",

View File

@ -172,7 +172,9 @@ def _list_supported_models(
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_cls = llm_adapter.model_param_class()
model.enabled = True
params = _get_parameter_descriptions(param_cls)
params = _get_parameter_descriptions(
param_cls, model_name=model_name, model_path=model_path
)
model.params = params
except Exception:
pass

View File

@ -3,6 +3,7 @@ import uuid
import asyncio
import os
import shutil
import logging
from fastapi import (
APIRouter,
Request,
@ -11,6 +12,7 @@ from fastapi import (
Form,
Body,
BackgroundTasks,
Depends,
)
from fastapi.responses import StreamingResponse
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService()
model_semaphore = None
@ -90,6 +95,20 @@ def knowledge_list():
return params
def get_model_controller() -> BaseModelController:
    """Fetch the model controller component from the configured system app.

    Looks the component up under ``ComponetType.MODEL_CONTROLLER``; written as
    a zero-argument provider so it can be passed to ``Depends``.
    """
    return CFG.SYSTEM_APP.get_componet(
        ComponetType.MODEL_CONTROLLER, BaseModelController
    )
def get_worker_manager() -> WorkerManager:
    """Create a worker manager via the factory registered on the system app.

    Looks the factory up under ``ComponetType.WORKER_MANAGER_FACTORY`` and
    returns the result of its ``create()`` call.
    """
    factory = CFG.SYSTEM_APP.get_componet(
        ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    return factory.create()
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
@router.get("/v1/model/types")
async def model_types(request: Request):
print(f"/controller/model/types")
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
logger.info(f"/controller/model/types")
try:
types = set()
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
models = await controller.get_all_instances(healthy_only=True)
for model in models:
worker_name, worker_type = model.model_name.split("@")
@ -371,6 +385,16 @@ async def model_types(request: Request):
return Result.faild(code="E000X", msg=f"controller model types error {e}")
@router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
    """Return the supported models of all workers as a flat list.

    Args:
        worker_manager: Injected through ``Depends(get_worker_manager)``.

    Returns:
        ``Result.succ`` wrapping ``FlatSupportedModel`` records on success,
        or ``Result.faild`` with code ``E000X`` if anything raises.
    """
    # NOTE(review): this handler reuses the name ``model_types`` from the
    # ``/v1/model/types`` handler, which looks like a copy-paste leftover.
    # The name is kept so the module-level binding does not change; consider
    # renaming it (e.g. ``model_supports``) after checking for references.
    logger.info("/controller/model/supports")
    try:
        models = await worker_manager.supported_models()
        return Result.succ(FlatSupportedModel.from_supports(models))
    except Exception as e:
        # Keep the traceback in the log; the response only carries str(e).
        logger.exception("Fetch supported models error")
        return Result.faild(code="E000X", msg=f"Fetch supported models error {e}")
async def no_stream_generator(chat):
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")

View File

@ -6,6 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.componet import ComponetType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -142,8 +143,11 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import worker_manager
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload):
yield output
except Exception as e:
@ -160,7 +164,11 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import worker_manager
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload)

View File

@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp):
from pilot.connections.manages.connection_manager import ConnectManager
from pilot.model.cluster import worker_manager
cfg = Config()

View File

@ -1,14 +1,10 @@
from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
import logging
from pilot.configs.model_config import get_device
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from typing import TYPE_CHECKING, Any, Type
from pilot.componet import ComponetType, SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
if TYPE_CHECKING:
@ -39,13 +35,9 @@ def _initialize_embedding_model(
embedding_model_name: str,
embedding_model_path: str,
):
from pilot.model.cluster import worker_manager
if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
else:
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
@ -56,26 +48,28 @@ def _initialize_embedding_model(
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
) -> None:
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._worker_manager = worker_manager
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app):
pass
self.system_app = system_app
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from pilot.model.cluster import WorkerManagerFactory
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_componet(
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
return self._model
def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters

View File

@ -12,6 +12,7 @@ class ParameterDescription:
param_type: str
default_value: Optional[Any]
description: str
required: Optional[bool]
valid_values: Optional[List[Any]]
ext_metadata: Dict
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
return type_mapping.get(type_str, str)
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
def _get_parameter_descriptions(
dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
}
default_value = field.default if field.default != MISSING else None
if field.name in kwargs:
default_value = kwargs[field.name]
descriptions.append(
ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default if field.default != MISSING else None,
required=field.default is MISSING,
default_value=default_value,
valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata,
)