diff --git a/assets/schema/prompt_management.sql b/assets/schema/prompt_management.sql new file mode 100644 index 000000000..b2ed6de23 --- /dev/null +++ b/assets/schema/prompt_management.sql @@ -0,0 +1,16 @@ +CREATE DATABASE prompt_management; +use prompt_management; +CREATE TABLE `prompt_manage` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private', + `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名', + `gmt_created` datetime DEFAULT NULL, + `gmt_modified` datetime DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表'; \ No newline at end of file diff --git a/pilot/componet.py b/pilot/component.py similarity index 79% rename from pilot/componet.py rename to pilot/component.py index 8697d9560..7c865479d 100644 --- a/pilot/componet.py +++ b/pilot/component.py @@ -42,15 +42,16 @@ class LifeCycle: pass -class ComponetType(str, Enum): +class ComponentType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" + WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" MODEL_CONTROLLER = "dbgpt_model_controller" -class BaseComponet(LifeCycle, ABC): +class BaseComponent(LifeCycle, ABC): """Abstract Base Component class. All custom components should extend this.""" - name = "base_dbgpt_componet" + name = "base_dbgpt_component" def __init__(self, system_app: Optional[SystemApp] = None): if system_app is not None: @@ -66,15 +67,15 @@ class BaseComponet(LifeCycle, ABC): pass -T = TypeVar("T", bound=BaseComponet) +T = TypeVar("T", bound=BaseComponent) class SystemApp(LifeCycle): """Main System Application class that manages the lifecycle and registration of components.""" def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: - self.componets: Dict[ - str, BaseComponet + self.components: Dict[ + str, BaseComponent ] = {} # Dictionary to store registered components. self._asgi_app = asgi_app @@ -83,58 +84,60 @@ class SystemApp(LifeCycle): """Returns the internal ASGI app.""" return self._asgi_app - def register(self, componet: Type[BaseComponet], *args, **kwargs): + def register(self, component: Type[BaseComponent], *args, **kwargs): """Register a new component by its type.""" - instance = componet(self, *args, **kwargs) + instance = component(self, *args, **kwargs) self.register_instance(instance) def register_instance(self, instance: T): """Register an already initialized component.""" name = instance.name - if isinstance(name, ComponetType): + if isinstance(name, ComponentType): name = name.value - if name in self.componets: + if name in self.components: raise RuntimeError( - f"Componse name {name} already exists: {self.componets[name]}" + f"Componse name {name} already exists: {self.components[name]}" ) - logger.info(f"Register componet with name {name} and instance: {instance}") - self.componets[name] = instance + logger.info(f"Register component with name {name} and instance: {instance}") + self.components[name] = instance instance.init_app(self) - def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T: + def get_component( + self, name: Union[str, ComponentType], component_type: Type[T] + ) -> T: """Retrieve a registered component by its name and type.""" - if isinstance(name, ComponetType): + if isinstance(name, ComponentType): name = name.value - component = self.componets.get(name) + component = self.components.get(name) if not component: raise ValueError(f"No component found with name {name}") - if not isinstance(component, componet_type): - raise TypeError(f"Component {name} is not of type {componet_type}") + if not isinstance(component, component_type): + raise TypeError(f"Component {name} is not of type {component_type}") return component def before_start(self): """Invoke the before_start hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): v.before_start() async def async_before_start(self): """Asynchronously invoke the before_start hooks for all registered components.""" - tasks = [v.async_before_start() for _, v in self.componets.items()] + tasks = [v.async_before_start() for _, v in self.components.items()] await asyncio.gather(*tasks) def after_start(self): """Invoke the after_start hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): v.after_start() async def async_after_start(self): """Asynchronously invoke the after_start hooks for all registered components.""" - tasks = [v.async_after_start() for _, v in self.componets.items()] + tasks = [v.async_after_start() for _, v in self.components.items()] await asyncio.gather(*tasks) def before_stop(self): """Invoke the before_stop hooks for all registered components.""" - for _, v in self.componets.items(): + for _, v in self.components.items(): try: v.before_stop() except Exception as e: @@ -142,7 +145,7 @@ class SystemApp(LifeCycle): async def async_before_stop(self): """Asynchronously invoke the before_stop hooks for all registered components.""" - tasks = [v.async_before_stop() for _, v in self.componets.items()] + tasks = [v.async_before_stop() for _, v in self.components.items()] await asyncio.gather(*tasks) def _build(self): diff --git a/pilot/configs/config.py b/pilot/configs/config.py index e068613c1..5d256a8e5 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -189,7 +189,7 @@ class Config(metaclass=Singleton): ### Log level self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO") - from pilot.componet import SystemApp + from pilot.component import SystemApp self.SYSTEM_APP: SystemApp = None diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index b5cfbbfdd..534cd36f0 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -4,7 +4,7 @@ import asyncio from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.common.schema import DBType -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.base import BaseConnect diff --git a/pilot/connections/rdbms/conn_spark.py b/pilot/connections/rdbms/conn_spark.py index 4ea976f96..aca5f9f4a 100644 --- a/pilot/connections/rdbms/conn_spark.py +++ b/pilot/connections/rdbms/conn_spark.py @@ -1,13 +1,10 @@ import re from typing import Optional, Any -from pyspark import SQLContext from sqlalchemy import text -from pilot.connections.rdbms.base import RDBMSDatabase from pyspark.sql import SparkSession, DataFrame -from sqlalchemy import create_engine class SparkConnect: diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py index c6f51f4c9..5b49df767 100644 --- a/pilot/embedding_engine/embedding_factory.py +++ b/pilot/embedding_engine/embedding_factory.py @@ -2,13 +2,13 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, Type, TYPE_CHECKING -from pilot.componet import BaseComponet +from pilot.component import BaseComponent if TYPE_CHECKING: from langchain.embeddings.base import Embeddings -class EmbeddingFactory(BaseComponet, ABC): +class EmbeddingFactory(BaseComponent, ABC): name = "embedding_factory" @abstractmethod diff --git a/pilot/model/base.py b/pilot/model/base.py index b6eb9da25..1d46b3161 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -3,7 +3,7 @@ from enum import Enum from typing import TypedDict, Optional, Dict, List -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from pilot.utils.parameter_utils import ParameterDescription @@ -84,3 +84,25 @@ class WorkerSupportedModel: ] worker_data["models"] = models return cls(**worker_data) + + +@dataclass +class FlatSupportedModel(SupportedModel): + """For web""" + + host: str + port: int + + @staticmethod + def from_supports( + supports: List[WorkerSupportedModel], + ) -> List["FlatSupportedModel"]: + results = [] + for s in supports: + host, port, models = s.host, s.port, s.models + for m in models: + kwargs = asdict(m) + kwargs["host"] = host + kwargs["port"] = port + results.append(FlatSupportedModel(**kwargs)) + return results diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py index b73fd7873..9937ffa0b 100644 --- a/pilot/model/cluster/__init__.py +++ b/pilot/model/cluster/__init__.py @@ -5,6 +5,7 @@ from pilot.model.cluster.base import ( WorkerParameterRequest, WorkerStartupRequest, ) +from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker.default_worker import DefaultModelWorker @@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry from pilot.model.cluster.controller.controller import ( ModelRegistryClient, run_model_controller, + BaseModelController, ) from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager @@ -28,6 +30,7 @@ __all__ = [ "WorkerApplyRequest", "WorkerParameterRequest", "WorkerStartupRequest", + "WorkerManagerFactory", "ModelWorker", "DefaultModelWorker", "worker_manager", diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 54360e477..e93216929 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -4,7 +4,7 @@ import logging from typing import List from fastapi import APIRouter, FastAPI -from pilot.componet import BaseComponet, ComponetType, SystemApp +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import ModelInstance from pilot.model.parameter import ModelControllerParameters from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry @@ -15,8 +15,8 @@ from pilot.utils.api_utils import ( ) -class BaseModelController(BaseComponet, ABC): - name = ComponetType.MODEL_CONTROLLER +class BaseModelController(BaseComponent, ABC): + name = ComponentType.MODEL_CONTROLLER def init_app(self, system_app: SystemApp): pass diff --git a/pilot/model/cluster/manager_base.py b/pilot/model/cluster/manager_base.py index d66991cae..ce37755f6 100644 --- a/pilot/model/cluster/manager_base.py +++ b/pilot/model/cluster/manager_base.py @@ -4,6 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable from abc import ABC, abstractmethod from datetime import datetime from concurrent.futures import Future +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest @@ -104,3 +105,14 @@ class WorkerManager(ABC): self, worker_type: str, model_name: str ) -> List[ParameterDescription]: """Get parameter descriptions of model""" + + +class WorkerManagerFactory(BaseComponent, ABC): + name = ComponentType.WORKER_MANAGER_FACTORY.value + + def init_app(self, system_app: SystemApp): + pass + + @abstractmethod + def create(self) -> WorkerManager: + """Create worker manager""" diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index ff8fd7767..2576c6a21 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -7,10 +7,11 @@ import random import time from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict -from typing import Awaitable, Callable, Dict, Iterator, List +from typing import Awaitable, Callable, Dict, Iterator, List, Optional from fastapi import APIRouter, FastAPI from fastapi.responses import StreamingResponse +from pilot.component import SystemApp from pilot.configs.model_config import LOGDIR from pilot.model.base import ( ModelInstance, @@ -23,7 +24,11 @@ from pilot.model.cluster.registry import ModelRegistry from pilot.model.llm_utils import list_supported_models from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType from pilot.model.cluster.worker_base import ModelWorker -from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData +from pilot.model.cluster.manager_base import ( + WorkerManager, + WorkerRunData, + WorkerManagerFactory, +) from pilot.model.cluster.base import * from pilot.utils import build_logger from pilot.utils.parameter_utils import ( @@ -548,6 +553,17 @@ class WorkerManagerAdapter(WorkerManager): return await self.worker_manager.parameter_descriptions(worker_type, model_name) +class _DefaultWorkerManagerFactory(WorkerManagerFactory): + def __init__( + self, system_app: SystemApp | None = None, worker_manager: WorkerManager = None + ): + super().__init__(system_app) + self.worker_manager = worker_manager + + def create(self) -> WorkerManager: + return self.worker_manager + + worker_manager = WorkerManagerAdapter() router = APIRouter() @@ -787,6 +803,7 @@ def initialize_worker_manager_in_client( embedding_model_name: str = None, embedding_model_path: str = None, start_listener: Callable[["WorkerManager"], None] = None, + system_app: SystemApp = None, ): """Initialize WorkerManager in client. If run_locally is True: @@ -845,6 +862,8 @@ def initialize_worker_manager_in_client( if include_router and app: # mount WorkerManager router app.include_router(router, prefix="/api") + if system_app: + system_app.register(_DefaultWorkerManagerFactory, worker_manager) def run_worker_manager( diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index 9131490f5..690a6afbf 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -172,7 +172,9 @@ def _list_supported_models( llm_adapter = get_llm_model_adapter(model_name, model_path) param_cls = llm_adapter.model_param_class() model.enabled = True - params = _get_parameter_descriptions(param_cls) + params = _get_parameter_descriptions( + param_cls, model_name=model_name, model_path=model_path + ) model.params = params except Exception: pass diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 45d00c13c..111487f00 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -3,6 +3,7 @@ import uuid import asyncio import os import shutil +import logging from fastapi import ( APIRouter, Request, @@ -11,6 +12,7 @@ from fastapi import ( Form, Body, BackgroundTasks, + Depends, ) from fastapi.responses import StreamingResponse @@ -18,7 +20,7 @@ from fastapi.exceptions import RequestValidationError from typing import List import tempfile -from pilot.componet import ComponetType +from pilot.component import ComponentType from pilot.openapi.api_view_model import ( Result, ConversationVo, @@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.summary.db_summary_client import DBSummaryClient +from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory +from pilot.model.base import FlatSupportedModel + router = APIRouter() CFG = Config() CHAT_FACTORY = ChatFactory() -logger = build_logger("api_v1", LOGDIR + "api_v1.log") +logger = logging.getLogger(__name__) knowledge_service = KnowledgeService() model_semaphore = None @@ -90,6 +95,20 @@ def knowledge_list(): return params +def get_model_controller() -> BaseModelController: + controller = CFG.SYSTEM_APP.get_component( + ComponentType.MODEL_CONTROLLER, BaseModelController + ) + return controller + + +def get_worker_manager() -> WorkerManager: + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + return worker_manager + + @router.get("/v1/chat/db/list", response_model=Result[DBConfig]) async def db_connect_list(): return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) @@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()): @router.get("/v1/model/types") -async def model_types(request: Request): - print(f"/controller/model/types") +async def model_types(controller: BaseModelController = Depends(get_model_controller)): + logger.info(f"/controller/model/types") try: types = set() - from pilot.model.cluster.controller.controller import BaseModelController - - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController - ) models = await controller.get_all_instances(healthy_only=True) for model in models: worker_name, worker_type = model.model_name.split("@") @@ -371,6 +385,16 @@ async def model_types(request: Request): return Result.faild(code="E000X", msg=f"controller model types error {e}") +@router.get("/v1/model/supports") +async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)): + logger.info(f"/controller/model/supports") + try: + models = await worker_manager.supported_models() + return Result.succ(FlatSupportedModel.from_supports(models)) + except Exception as e: + return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}") + + async def no_stream_generator(chat): msg = await chat.nostream_call() msg = msg.replace("\n", "\\n") diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 9e7a22373..70805ca80 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -6,6 +6,7 @@ from typing import Any, List, Dict from pilot.configs.config import Config from pilot.configs.model_config import LOGDIR +from pilot.component import ComponentType from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory @@ -142,8 +143,11 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - from pilot.model.cluster import worker_manager + from pilot.model.cluster import WorkerManagerFactory + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() async for output in worker_manager.generate_stream(payload): yield output except Exception as e: @@ -160,7 +164,11 @@ class BaseChat(ABC): logger.info(f"Request: \n{payload}") ai_response_text = "" try: - from pilot.model.cluster import worker_manager + from pilot.model.cluster import WorkerManagerFactory + + worker_manager = CFG.SYSTEM_APP.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() model_output = await worker_manager.generate(payload) diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 3856cfb05..3b9bacb99 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat): "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) self.knowledge_embedding_client = EmbeddingEngine( diff --git a/pilot/server/base.py b/pilot/server/base.py index 888ebbf3d..5ebbc0003 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -6,7 +6,7 @@ from typing import Optional, Any from dataclasses import dataclass, field from pilot.configs.config import Config -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.utils.parameter_utils import BaseParameters @@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp): def _create_model_start_listener(system_app: SystemApp): from pilot.connections.manages.connection_manager import ConnectManager - from pilot.model.cluster import worker_manager cfg = Config() diff --git a/pilot/server/componet_configs.py b/pilot/server/component_configs.py similarity index 81% rename from pilot/server/componet_configs.py rename to pilot/server/component_configs.py index d46b626ca..d2dded0a1 100644 --- a/pilot/server/componet_configs.py +++ b/pilot/server/component_configs.py @@ -1,14 +1,10 @@ from __future__ import annotations -from typing import Any, Type, TYPE_CHECKING - -from pilot.componet import SystemApp import logging -from pilot.configs.model_config import get_device -from pilot.embedding_engine.embedding_factory import ( - EmbeddingFactory, - DefaultEmbeddingFactory, -) +from typing import TYPE_CHECKING, Any, Type + +from pilot.component import ComponentType, SystemApp +from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.server.base import WebWerverParameters if TYPE_CHECKING: @@ -18,7 +14,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def initialize_componets( +def initialize_components( param: WebWerverParameters, system_app: SystemApp, embedding_model_name: str, @@ -39,13 +35,9 @@ def _initialize_embedding_model( embedding_model_name: str, embedding_model_path: str, ): - from pilot.model.cluster import worker_manager - if param.remote_embedding: logger.info("Register remote RemoteEmbeddingFactory") - system_app.register( - RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name - ) + system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name) else: logger.info(f"Register local LocalEmbeddingFactory") system_app.register( @@ -56,26 +48,28 @@ def _initialize_embedding_model( class RemoteEmbeddingFactory(EmbeddingFactory): - def __init__( - self, system_app, worker_manager, model_name: str = None, **kwargs: Any - ) -> None: + def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None: super().__init__(system_app=system_app) - self._worker_manager = worker_manager self._default_model_name = model_name self.kwargs = kwargs + self.system_app = system_app def init_app(self, system_app): - pass + self.system_app = system_app def create( self, model_name: str = None, embedding_cls: Type = None ) -> "Embeddings": + from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings if embedding_cls: raise NotImplementedError + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() # Ignore model_name args - return RemoteEmbeddings(self._default_model_name, self._worker_manager) + return RemoteEmbeddings(self._default_model_name, worker_manager) class LocalEmbeddingFactory(EmbeddingFactory): @@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory): return self._model def _load_model(self) -> "Embeddings": - from pilot.model.parameter import ( - EmbeddingModelParameters, - BaseEmbeddingModelParameters, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, - ) - from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params from pilot.model.cluster.embedding.loader import EmbeddingLoader + from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params + from pilot.model.parameter import ( + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + BaseEmbeddingModelParameters, + EmbeddingModelParameters, + ) param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( self._default_model_name, EmbeddingModelParameters diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index c2a510eaf..5cc56fbec 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH) import signal from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.server.base import ( server_init, WebWerverParameters, _create_model_start_listener, ) -from pilot.server.componet_configs import initialize_componets +from pilot.server.component_configs import initialize_components from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications @@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from pilot.server.knowledge.api import router as knowledge_router +from pilot.server.prompt.api import router as prompt_router from pilot.server.llm_manage.api import router as llm_manage_api @@ -76,6 +77,7 @@ app.include_router(llm_manage_api, prefix="/api") # app.include_router(api_v1) app.include_router(knowledge_router) +app.include_router(prompt_router) # app.include_router(api_editor_route_v1) @@ -118,7 +120,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): server_init(param, system_app) model_start_listener = _create_model_start_listener(system_app) - initialize_componets(param, system_app, embedding_model_name, embedding_model_path) + initialize_components(param, system_app, embedding_model_name, embedding_model_path) model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] if not param.light: @@ -133,6 +135,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): embedding_model_name=embedding_model_name, embedding_model_path=embedding_model_path, start_listener=model_start_listener, + system_app=system_app, ) CFG.NEW_SERVER_MODE = True @@ -146,6 +149,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): controller_addr=CFG.MODEL_SERVER, local_port=param.port, start_listener=model_start_listener, + system_app=system_app, ) CFG.SERVER_LIGHT_MODE = True diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index f43333aa1..71b939924 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest): @router.post("/knowledge/{vector_name}/query") def similar_query(space_name: str, query_request: KnowledgeQueryRequest): print(f"Received params: {space_name}, {query_request}") - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) client = EmbeddingEngine( diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 0c04dee3a..fc07040c7 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -205,7 +205,7 @@ class KnowledgeService: chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) client = EmbeddingEngine( diff --git a/pilot/server/llm_manage/api.py b/pilot/server/llm_manage/api.py index d68083940..be1957c14 100644 --- a/pilot/server/llm_manage/api.py +++ b/pilot/server/llm_manage/api.py @@ -1,7 +1,7 @@ from fastapi import APIRouter -from pilot.componet import ComponetType +from pilot.component import ComponentType from pilot.configs.config import Config from pilot.model.base import ModelInstance, WorkerApplyType @@ -31,8 +31,8 @@ async def model_list(): try: from pilot.model.cluster.controller.controller import BaseModelController - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController + controller = CFG.SYSTEM_APP.get_component( + ComponentType.MODEL_CONTROLLER, BaseModelController ) responses = [] managers = await controller.get_all_instances( @@ -70,8 +70,8 @@ async def model_start(request: WorkerStartupRequest): try: from pilot.model.cluster.controller.controller import BaseModelController - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController + controller = CFG.SYSTEM_APP.get_component( + ComponentType.MODEL_CONTROLLER, BaseModelController ) instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True) worker_instance = None @@ -98,8 +98,8 @@ async def model_start(request: WorkerStartupRequest): try: from pilot.model.cluster.controller.controller import BaseModelController - controller = CFG.SYSTEM_APP.get_componet( - ComponetType.MODEL_CONTROLLER, BaseModelController + controller = CFG.SYSTEM_APP.get_component( + ComponentType.MODEL_CONTROLLER, BaseModelController ) instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True) worker_instance = None diff --git a/pilot/server/prompt/__init__.py b/pilot/server/prompt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/prompt/api.py b/pilot/server/prompt/api.py new file mode 100644 index 000000000..b94546891 --- /dev/null +++ b/pilot/server/prompt/api.py @@ -0,0 +1,46 @@ +from fastapi import APIRouter, File, UploadFile, Form + +from pilot.openapi.api_view_model import Result +from pilot.server.prompt.service import PromptManageService +from pilot.server.prompt.request.request import PromptManageRequest + +router = APIRouter() + +prompt_manage_service = PromptManageService() + + +@router.post("/prompt/add") +def prompt_add(request: PromptManageRequest): + print(f"/space/add params: {request}") + try: + prompt_manage_service.create_prompt(request) + return Result.succ([]) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt add error {e}") + + +@router.post("/prompt/list") +def prompt_list(request: PromptManageRequest): + print(f"/prompt/list params: {request}") + try: + return Result.succ(prompt_manage_service.get_prompts(request)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt list error {e}") + + +@router.post("/prompt/update") +def prompt_update(request: PromptManageRequest): + print(f"/prompt/update params: {request}") + try: + return Result.succ(prompt_manage_service.update_prompt(request)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt update error {e}") + + +@router.post("/prompt/delete") +def prompt_delete(request: PromptManageRequest): + print(f"/prompt/delete params: {request}") + try: + return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt delete error {e}") diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py new file mode 100644 index 000000000..6a02e6b5c --- /dev/null +++ b/pilot/server/prompt/prompt_manage_db.py @@ -0,0 +1,91 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer, Text, String, DateTime +from sqlalchemy.ext.declarative import declarative_base + +from pilot.configs.config import Config +from pilot.connections.rdbms.base_dao import BaseDao + +from pilot.server.prompt.request.request import PromptManageRequest + +CFG = Config() +Base = declarative_base() + + +class PromptManageEntity(Base): + __tablename__ = "prompt_manage" + id = Column(Integer, primary_key=True) + chat_scene = Column(String(100)) + sub_chat_scene = Column(String(100)) + prompt_type = Column(String(100)) + prompt_name = Column(String(512)) + content = Column(Text) + user_name = Column(String(128)) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class PromptManageDao(BaseDao): + def __init__(self): + super().__init__( + database="prompt_management", orm_base=Base, create_not_exist_table=True + ) + + def create_prompt(self, prompt: PromptManageRequest): + session = self.Session() + prompt_manage = PromptManageEntity( + chat_scene=prompt.chat_scene, + sub_chat_scene=prompt.sub_chat_scene, + prompt_type=prompt.prompt_type, + prompt_name=prompt.prompt_name, + content=prompt.content, + user_name=prompt.user_name, + gmt_created=datetime.now(), + gmt_modified=datetime.now(), + ) + session.add(prompt_manage) + session.commit() + session.close() + + def get_prompts(self, query: PromptManageEntity): + session = self.Session() + prompts = session.query(PromptManageEntity) + if query.chat_scene is not None: + prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) + if query.sub_chat_scene is not None: + prompts = prompts.filter( + PromptManageEntity.sub_chat_scene == query.sub_chat_scene + ) + if query.prompt_type is not None: + prompts = prompts.filter( + PromptManageEntity.prompt_type == query.prompt_type + ) + if query.prompt_type == "private" and query.user_name is not None: + prompts = prompts.filter( + PromptManageEntity.user_name == query.user_name + ) + if query.prompt_name is not None: + prompts = prompts.filter( + PromptManageEntity.prompt_name == query.prompt_name + ) + + prompts = prompts.order_by(PromptManageEntity.gmt_created.desc()) + result = prompts.all() + session.close() + return result + + def update_prompt(self, prompt: PromptManageEntity): + session = self.Session() + session.merge(prompt) + session.commit() + session.close() + + def delete_prompt(self, prompt: PromptManageEntity): + session = self.Session() + if prompt: + session.delete(prompt) + session.commit() + session.close() diff --git a/pilot/server/prompt/request/__init__.py b/pilot/server/prompt/request/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/prompt/request/request.py b/pilot/server/prompt/request/request.py new file mode 100644 index 000000000..c1b0683ec --- /dev/null +++ b/pilot/server/prompt/request/request.py @@ -0,0 +1,24 @@ +from typing import List + +from pydantic import BaseModel + + +class PromptManageRequest(BaseModel): + """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" + + chat_scene: str = None + + """sub_chat_scene: sub chat scene""" + sub_chat_scene: str = None + + """prompt_type: common or private""" + prompt_type: str = None + + """content: prompt content""" + content: str = None + + """user_name: user name""" + user_name: str = None + + """prompt_name: prompt name""" + prompt_name: str = None diff --git a/pilot/server/prompt/request/response.py b/pilot/server/prompt/request/response.py new file mode 100644 index 000000000..4da05e069 --- /dev/null +++ b/pilot/server/prompt/request/response.py @@ -0,0 +1,26 @@ +from typing import List +from pydantic import BaseModel + + +class PromptQueryResponse(BaseModel): + id: int = None + """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" + + chat_scene: str = None + + """sub_chat_scene: sub chat scene""" + sub_chat_scene: str = None + + """prompt_type: common or private""" + prompt_type: str = None + + """content: prompt content""" + content: str = None + + """user_name: user name""" + user_name: str = None + + """prompt_name: prompt name""" + prompt_name: str = None + gmt_created: str = None + gmt_modified: str = None diff --git a/pilot/server/prompt/service.py b/pilot/server/prompt/service.py new file mode 100644 index 000000000..c108d8b88 --- /dev/null +++ b/pilot/server/prompt/service.py @@ -0,0 +1,80 @@ +from datetime import datetime + +from pilot.server.prompt.request.request import PromptManageRequest +from pilot.server.prompt.request.response import PromptQueryResponse +from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity + +prompt_manage_dao = PromptManageDao() + + +class PromptManageService: + def __init__(self): + pass + + """create prompt""" + + def create_prompt(self, request: PromptManageRequest): + query = PromptManageRequest( + prompt_name=request.prompt_name, + ) + prompt_name = prompt_manage_dao.get_prompts(query) + if len(prompt_name) > 0: + raise Exception(f"prompt name:{request.prompt_name} have already named") + prompt_manage_dao.create_prompt(request) + return True + + """get prompts""" + + def get_prompts(self, request: PromptManageRequest): + query = PromptManageRequest( + chat_scene=request.chat_scene, + sub_chat_scene=request.sub_chat_scene, + prompt_type=request.prompt_type, + prompt_name=request.prompt_name, + user_name=request.user_name, + ) + responses = [] + prompts = prompt_manage_dao.get_prompts(query) + for prompt in prompts: + res = PromptQueryResponse() + + res.id = prompt.id + res.chat_scene = prompt.chat_scene + res.sub_chat_scene = prompt.sub_chat_scene + res.prompt_type = prompt.prompt_type + res.content = prompt.content + res.user_name = prompt.user_name + res.prompt_name = prompt.prompt_name + res.gmt_created = prompt.gmt_created + res.gmt_modified = prompt.gmt_modified + responses.append(res) + return responses + + """update prompt""" + + def update_prompt(self, request: PromptManageRequest): + query = PromptManageEntity(prompt_name=request.prompt_name) + prompts = prompt_manage_dao.get_prompts(query) + if len(prompts) != 1: + raise Exception( + f"there are no or more than one space called {request.prompt_name}" + ) + prompt = prompts[0] + prompt.chat_scene = request.chat_scene + prompt.sub_chat_scene = request.sub_chat_scene + prompt.prompt_type = request.prompt_type + prompt.content = request.content + prompt.user_name = request.user_name + prompt.gmt_modified = datetime.now() + return prompt_manage_dao.update_prompt(prompt) + + """delete prompt""" + + def delete_prompt(self, prompt_name: str): + query = PromptManageEntity(prompt_name=prompt_name) + prompts = prompt_manage_dao.get_prompts(query) + if len(prompts) == 0: + raise Exception(f"delete error, no prompt name:{prompt_name} in database ") + # delete prompt + prompt = prompts[0] + return prompt_manage_dao.delete_prompt(prompt) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index f41043601..d4850ec08 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -2,7 +2,7 @@ import json import uuid from pilot.common.schema import DBType -from pilot.componet import SystemApp +from pilot.component import SystemApp from pilot.configs.config import Config from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, @@ -36,7 +36,7 @@ class DBSummaryClient: from pilot.embedding_engine.embedding_factory import EmbeddingFactory db_summary_client = RdbmsSummary(dbname, db_type) - embedding_factory = self.system_app.get_componet( + embedding_factory = self.system_app.get_component( "embedding_factory", EmbeddingFactory ) embeddings = embedding_factory.create( @@ -94,7 +94,7 @@ class DBSummaryClient: "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) knowledge_embedding_client = EmbeddingEngine( @@ -117,7 +117,7 @@ class DBSummaryClient: "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - embedding_factory = CFG.SYSTEM_APP.get_componet( + embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) knowledge_embedding_client = EmbeddingEngine( diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py index b67282442..5be747e23 100644 --- a/pilot/utils/parameter_utils.py +++ b/pilot/utils/parameter_utils.py @@ -12,6 +12,7 @@ class ParameterDescription: param_type: str default_value: Optional[Any] description: str + required: Optional[bool] valid_values: Optional[List[Any]] ext_metadata: Dict @@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type: return type_mapping.get(type_str, str) -def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: +def _get_parameter_descriptions( + dataclass_type: Type, **kwargs +) -> List[ParameterDescription]: descriptions = [] for field in fields(dataclass_type): ext_metadata = { k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"] } - + default_value = field.default if field.default != MISSING else None + if field.name in kwargs: + default_value = kwargs[field.name] descriptions.append( ParameterDescription( param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}", param_name=field.name, param_type=EnvArgumentParser._get_argparse_type_str(field.type), description=field.metadata.get("help", None), - default_value=field.default if field.default != MISSING else None, + required=field.default is MISSING, + default_value=default_value, valid_values=field.metadata.get("valid_values", None), ext_metadata=ext_metadata, )