mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 02:46:40 +00:00
chore(model): new required field for supported models
This commit is contained in:
parent
78553477a9
commit
ae34be23fd
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypedDict, Optional, Dict, List
|
from typing import TypedDict, Optional, Dict, List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pilot.utils.parameter_utils import ParameterDescription
|
from pilot.utils.parameter_utils import ParameterDescription
|
||||||
|
|
||||||
@ -84,3 +84,25 @@ class WorkerSupportedModel:
|
|||||||
]
|
]
|
||||||
worker_data["models"] = models
|
worker_data["models"] = models
|
||||||
return cls(**worker_data)
|
return cls(**worker_data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlatSupportedModel(SupportedModel):
|
||||||
|
"""For web"""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_supports(
|
||||||
|
supports: List[WorkerSupportedModel],
|
||||||
|
) -> List["FlatSupportedModel"]:
|
||||||
|
results = []
|
||||||
|
for s in supports:
|
||||||
|
host, port, models = s.host, s.port, s.models
|
||||||
|
for m in models:
|
||||||
|
kwargs = asdict(m)
|
||||||
|
kwargs["host"] = host
|
||||||
|
kwargs["port"] = port
|
||||||
|
results.append(FlatSupportedModel(**kwargs))
|
||||||
|
return results
|
||||||
|
@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
|
|||||||
WorkerParameterRequest,
|
WorkerParameterRequest,
|
||||||
WorkerStartupRequest,
|
WorkerStartupRequest,
|
||||||
)
|
)
|
||||||
|
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
|
|||||||
from pilot.model.cluster.controller.controller import (
|
from pilot.model.cluster.controller.controller import (
|
||||||
ModelRegistryClient,
|
ModelRegistryClient,
|
||||||
run_model_controller,
|
run_model_controller,
|
||||||
|
BaseModelController,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
||||||
@ -28,6 +30,7 @@ __all__ = [
|
|||||||
"WorkerApplyRequest",
|
"WorkerApplyRequest",
|
||||||
"WorkerParameterRequest",
|
"WorkerParameterRequest",
|
||||||
"WorkerStartupRequest",
|
"WorkerStartupRequest",
|
||||||
|
"WorkerManagerFactory",
|
||||||
"ModelWorker",
|
"ModelWorker",
|
||||||
"DefaultModelWorker",
|
"DefaultModelWorker",
|
||||||
"worker_manager",
|
"worker_manager",
|
||||||
|
@ -172,7 +172,9 @@ def _list_supported_models(
|
|||||||
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
||||||
param_cls = llm_adapter.model_param_class()
|
param_cls = llm_adapter.model_param_class()
|
||||||
model.enabled = True
|
model.enabled = True
|
||||||
params = _get_parameter_descriptions(param_cls)
|
params = _get_parameter_descriptions(
|
||||||
|
param_cls, model_name=model_name, model_path=model_path
|
||||||
|
)
|
||||||
model.params = params
|
model.params = params
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -3,6 +3,7 @@ import uuid
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import logging
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Request,
|
Request,
|
||||||
@ -11,6 +12,7 @@ from fastapi import (
|
|||||||
Form,
|
Form,
|
||||||
Body,
|
Body,
|
||||||
BackgroundTasks,
|
BackgroundTasks,
|
||||||
|
Depends,
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
|
|||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
|
|
||||||
|
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||||
|
from pilot.model.base import FlatSupportedModel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
CHAT_FACTORY = ChatFactory()
|
CHAT_FACTORY = ChatFactory()
|
||||||
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
|
logger = logging.getLogger(__name__)
|
||||||
knowledge_service = KnowledgeService()
|
knowledge_service = KnowledgeService()
|
||||||
|
|
||||||
model_semaphore = None
|
model_semaphore = None
|
||||||
@ -90,6 +95,20 @@ def knowledge_list():
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_controller() -> BaseModelController:
|
||||||
|
controller = CFG.SYSTEM_APP.get_componet(
|
||||||
|
ComponetType.MODEL_CONTROLLER, BaseModelController
|
||||||
|
)
|
||||||
|
return controller
|
||||||
|
|
||||||
|
|
||||||
|
def get_worker_manager() -> WorkerManager:
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_componet(
|
||||||
|
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
|
return worker_manager
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||||
async def db_connect_list():
|
async def db_connect_list():
|
||||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||||
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/v1/model/types")
|
@router.get("/v1/model/types")
|
||||||
async def model_types(request: Request):
|
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
|
||||||
print(f"/controller/model/types")
|
logger.info(f"/controller/model/types")
|
||||||
try:
|
try:
|
||||||
types = set()
|
types = set()
|
||||||
from pilot.model.cluster.controller.controller import BaseModelController
|
|
||||||
|
|
||||||
controller = CFG.SYSTEM_APP.get_componet(
|
|
||||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
|
||||||
)
|
|
||||||
models = await controller.get_all_instances(healthy_only=True)
|
models = await controller.get_all_instances(healthy_only=True)
|
||||||
for model in models:
|
for model in models:
|
||||||
worker_name, worker_type = model.model_name.split("@")
|
worker_name, worker_type = model.model_name.split("@")
|
||||||
@ -371,6 +385,16 @@ async def model_types(request: Request):
|
|||||||
return Result.faild(code="E000X", msg=f"controller model types error {e}")
|
return Result.faild(code="E000X", msg=f"controller model types error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/model/supports")
|
||||||
|
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
|
||||||
|
logger.info(f"/controller/model/supports")
|
||||||
|
try:
|
||||||
|
models = await worker_manager.supported_models()
|
||||||
|
return Result.succ(FlatSupportedModel.from_supports(models))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}")
|
||||||
|
|
||||||
|
|
||||||
async def no_stream_generator(chat):
|
async def no_stream_generator(chat):
|
||||||
msg = await chat.nostream_call()
|
msg = await chat.nostream_call()
|
||||||
msg = msg.replace("\n", "\\n")
|
msg = msg.replace("\n", "\\n")
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any, List, Dict
|
|||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.componet import ComponetType
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
||||||
@ -142,8 +143,11 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import worker_manager
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_componet(
|
||||||
|
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
async for output in worker_manager.generate_stream(payload):
|
async for output in worker_manager.generate_stream(payload):
|
||||||
yield output
|
yield output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -160,7 +164,11 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import worker_manager
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_componet(
|
||||||
|
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
|
|
||||||
model_output = await worker_manager.generate(payload)
|
model_output = await worker_manager.generate(payload)
|
||||||
|
|
||||||
|
@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
|
|||||||
|
|
||||||
def _create_model_start_listener(system_app: SystemApp):
|
def _create_model_start_listener(system_app: SystemApp):
|
||||||
from pilot.connections.manages.connection_manager import ConnectManager
|
from pilot.connections.manages.connection_manager import ConnectManager
|
||||||
from pilot.model.cluster import worker_manager
|
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Type, TYPE_CHECKING
|
|
||||||
|
|
||||||
from pilot.componet import SystemApp
|
|
||||||
import logging
|
import logging
|
||||||
from pilot.configs.model_config import get_device
|
from typing import TYPE_CHECKING, Any, Type
|
||||||
from pilot.embedding_engine.embedding_factory import (
|
|
||||||
EmbeddingFactory,
|
from pilot.componet import ComponetType, SystemApp
|
||||||
DefaultEmbeddingFactory,
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
)
|
|
||||||
from pilot.server.base import WebWerverParameters
|
from pilot.server.base import WebWerverParameters
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -39,13 +35,9 @@ def _initialize_embedding_model(
|
|||||||
embedding_model_name: str,
|
embedding_model_name: str,
|
||||||
embedding_model_path: str,
|
embedding_model_path: str,
|
||||||
):
|
):
|
||||||
from pilot.model.cluster import worker_manager
|
|
||||||
|
|
||||||
if param.remote_embedding:
|
if param.remote_embedding:
|
||||||
logger.info("Register remote RemoteEmbeddingFactory")
|
logger.info("Register remote RemoteEmbeddingFactory")
|
||||||
system_app.register(
|
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
|
||||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Register local LocalEmbeddingFactory")
|
logger.info(f"Register local LocalEmbeddingFactory")
|
||||||
system_app.register(
|
system_app.register(
|
||||||
@ -56,26 +48,28 @@ def _initialize_embedding_model(
|
|||||||
|
|
||||||
|
|
||||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||||
def __init__(
|
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||||
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
super().__init__(system_app=system_app)
|
super().__init__(system_app=system_app)
|
||||||
self._worker_manager = worker_manager
|
|
||||||
self._default_model_name = model_name
|
self._default_model_name = model_name
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.system_app = system_app
|
||||||
|
|
||||||
def init_app(self, system_app):
|
def init_app(self, system_app):
|
||||||
pass
|
self.system_app = system_app
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self, model_name: str = None, embedding_cls: Type = None
|
self, model_name: str = None, embedding_cls: Type = None
|
||||||
) -> "Embeddings":
|
) -> "Embeddings":
|
||||||
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
||||||
|
|
||||||
if embedding_cls:
|
if embedding_cls:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
worker_manager = self.system_app.get_componet(
|
||||||
|
ComponetType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
# Ignore model_name args
|
# Ignore model_name args
|
||||||
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
|
return RemoteEmbeddings(self._default_model_name, worker_manager)
|
||||||
|
|
||||||
|
|
||||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||||
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
|
|||||||
return self._model
|
return self._model
|
||||||
|
|
||||||
def _load_model(self) -> "Embeddings":
|
def _load_model(self) -> "Embeddings":
|
||||||
from pilot.model.parameter import (
|
|
||||||
EmbeddingModelParameters,
|
|
||||||
BaseEmbeddingModelParameters,
|
|
||||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
|
||||||
)
|
|
||||||
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
|
||||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||||
|
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||||
|
from pilot.model.parameter import (
|
||||||
|
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||||
|
BaseEmbeddingModelParameters,
|
||||||
|
EmbeddingModelParameters,
|
||||||
|
)
|
||||||
|
|
||||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||||
self._default_model_name, EmbeddingModelParameters
|
self._default_model_name, EmbeddingModelParameters
|
||||||
|
@ -12,6 +12,7 @@ class ParameterDescription:
|
|||||||
param_type: str
|
param_type: str
|
||||||
default_value: Optional[Any]
|
default_value: Optional[Any]
|
||||||
description: str
|
description: str
|
||||||
|
required: Optional[bool]
|
||||||
valid_values: Optional[List[Any]]
|
valid_values: Optional[List[Any]]
|
||||||
ext_metadata: Dict
|
ext_metadata: Dict
|
||||||
|
|
||||||
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
|
|||||||
return type_mapping.get(type_str, str)
|
return type_mapping.get(type_str, str)
|
||||||
|
|
||||||
|
|
||||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
def _get_parameter_descriptions(
|
||||||
|
dataclass_type: Type, **kwargs
|
||||||
|
) -> List[ParameterDescription]:
|
||||||
descriptions = []
|
descriptions = []
|
||||||
for field in fields(dataclass_type):
|
for field in fields(dataclass_type):
|
||||||
ext_metadata = {
|
ext_metadata = {
|
||||||
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
|
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
|
||||||
}
|
}
|
||||||
|
default_value = field.default if field.default != MISSING else None
|
||||||
|
if field.name in kwargs:
|
||||||
|
default_value = kwargs[field.name]
|
||||||
descriptions.append(
|
descriptions.append(
|
||||||
ParameterDescription(
|
ParameterDescription(
|
||||||
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
|
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
|
||||||
param_name=field.name,
|
param_name=field.name,
|
||||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||||
description=field.metadata.get("help", None),
|
description=field.metadata.get("help", None),
|
||||||
default_value=field.default if field.default != MISSING else None,
|
required=field.default is MISSING,
|
||||||
|
default_value=default_value,
|
||||||
valid_values=field.metadata.get("valid_values", None),
|
valid_values=field.metadata.get("valid_values", None),
|
||||||
ext_metadata=ext_metadata,
|
ext_metadata=ext_metadata,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user