chore:merge main

This commit is contained in:
aries_ckt 2023-09-19 11:52:51 +08:00
commit 86ad276244
30 changed files with 473 additions and 97 deletions

View File

@ -0,0 +1,16 @@
-- Schema for prompt management: creates the `prompt_management` database
-- and its `prompt_manage` table (table comment: "prompt management table").
CREATE DATABASE prompt_management;
use prompt_management;
CREATE TABLE `prompt_manage` (
`id` int(11) NOT NULL AUTO_INCREMENT,
-- chat scene
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景',
-- sub chat scene
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景',
-- prompt type: common or private
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private',
-- name of the prompt (unique, see prompt_name_uiq below)
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字',
-- content of the prompt
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容',
-- user name
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名',
-- creation / last-modification timestamps
`gmt_created` datetime DEFAULT NULL,
`gmt_modified` datetime DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表';

View File

@ -42,15 +42,16 @@ class LifeCycle:
pass
class ComponetType(str, Enum):
class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponet(LifeCycle, ABC):
class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_componet"
name = "base_dbgpt_component"
def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None:
@ -66,15 +67,15 @@ class BaseComponet(LifeCycle, ABC):
pass
T = TypeVar("T", bound=BaseComponet)
T = TypeVar("T", bound=BaseComponent)
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.componets: Dict[
str, BaseComponet
self.components: Dict[
str, BaseComponent
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
@ -83,58 +84,60 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, componet: Type[BaseComponet], *args, **kwargs):
def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type."""
instance = componet(self, *args, **kwargs)
instance = component(self, *args, **kwargs)
self.register_instance(instance)
def register_instance(self, instance: T):
"""Register an already initialized component."""
name = instance.name
if isinstance(name, ComponetType):
if isinstance(name, ComponentType):
name = name.value
if name in self.componets:
if name in self.components:
raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}"
f"Componse name {name} already exists: {self.components[name]}"
)
logger.info(f"Register componet with name {name} and instance: {instance}")
self.componets[name] = instance
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
def get_component(
self, name: Union[str, ComponentType], component_type: Type[T]
) -> T:
"""Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType):
if isinstance(name, ComponentType):
name = name.value
component = self.componets.get(name)
component = self.components.get(name)
if not component:
raise ValueError(f"No component found with name {name}")
if not isinstance(component, componet_type):
raise TypeError(f"Component {name} is not of type {componet_type}")
if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {component_type}")
return component
def before_start(self):
"""Invoke the before_start hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
v.before_start()
async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.componets.items()]
tasks = [v.async_before_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def after_start(self):
"""Invoke the after_start hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
v.after_start()
async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.componets.items()]
tasks = [v.async_after_start() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def before_stop(self):
"""Invoke the before_stop hooks for all registered components."""
for _, v in self.componets.items():
for _, v in self.components.items():
try:
v.before_stop()
except Exception as e:
@ -142,7 +145,7 @@ class SystemApp(LifeCycle):
async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.componets.items()]
tasks = [v.async_before_stop() for _, v in self.components.items()]
await asyncio.gather(*tasks)
def _build(self):

View File

@ -189,7 +189,7 @@ class Config(metaclass=Singleton):
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.componet import SystemApp
from pilot.component import SystemApp
self.SYSTEM_APP: SystemApp = None

View File

@ -4,7 +4,7 @@ import asyncio
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect

View File

@ -1,13 +1,10 @@
import re
from typing import Optional, Any
from pyspark import SQLContext
from sqlalchemy import text
from pilot.connections.rdbms.base import RDBMSDatabase
from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import create_engine
class SparkConnect:

View File

@ -2,13 +2,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import BaseComponet
from pilot.component import BaseComponent
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
class EmbeddingFactory(BaseComponet, ABC):
class EmbeddingFactory(BaseComponent, ABC):
name = "embedding_factory"
@abstractmethod

View File

@ -3,7 +3,7 @@
from enum import Enum
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@ -84,3 +84,25 @@ class WorkerSupportedModel:
]
worker_data["models"] = models
return cls(**worker_data)
@dataclass
class FlatSupportedModel(SupportedModel):
    """A supported model combined with the host/port of its worker (for web)."""

    host: str
    port: int

    @staticmethod
    def from_supports(
        supports: List[WorkerSupportedModel],
    ) -> List["FlatSupportedModel"]:
        """Flatten per-worker model lists into one list of models.

        Each model of each worker is converted with ``asdict`` and rebuilt as
        a ``FlatSupportedModel`` carrying that worker's ``host`` and ``port``.
        """
        return [
            FlatSupportedModel(
                **{**asdict(model), "host": worker.host, "port": worker.port}
            )
            for worker in supports
            for model in worker.models
        ]

View File

@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
WorkerParameterRequest,
WorkerStartupRequest,
)
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
run_model_controller,
BaseModelController,
)
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -28,6 +30,7 @@ __all__ = [
"WorkerApplyRequest",
"WorkerParameterRequest",
"WorkerStartupRequest",
"WorkerManagerFactory",
"ModelWorker",
"DefaultModelWorker",
"worker_manager",

View File

@ -4,7 +4,7 @@ import logging
from typing import List
from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@ -15,8 +15,8 @@ from pilot.utils.api_utils import (
)
class BaseModelController(BaseComponet, ABC):
name = ComponetType.MODEL_CONTROLLER
class BaseModelController(BaseComponent, ABC):
name = ComponentType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp):
pass

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@ -104,3 +105,14 @@ class WorkerManager(ABC):
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
"""Get parameter descriptions of model"""
class WorkerManagerFactory(BaseComponent, ABC):
    """Abstract component whose ``create`` method yields a WorkerManager."""

    # Name this component is registered under.
    name = ComponentType.WORKER_MANAGER_FACTORY.value

    def init_app(self, system_app: SystemApp):
        """Do nothing; this component needs no setup against the app."""
        return None

    @abstractmethod
    def create(self) -> WorkerManager:
        """Create worker manager"""

View File

@ -7,10 +7,11 @@ import random
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,
@ -23,7 +24,11 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData
from pilot.model.cluster.manager_base import (
WorkerManager,
WorkerRunData,
WorkerManagerFactory,
)
from pilot.model.cluster.base import *
from pilot.utils import build_logger
from pilot.utils.parameter_utils import (
@ -548,6 +553,17 @@ class WorkerManagerAdapter(WorkerManager):
return await self.worker_manager.parameter_descriptions(worker_type, model_name)
class _DefaultWorkerManagerFactory(WorkerManagerFactory):
    """Factory that returns the WorkerManager instance it was built with."""

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        worker_manager: Optional[WorkerManager] = None,
    ):
        """Store the worker manager to hand out.

        Args:
            system_app: Application passed through to the base component.
            worker_manager: Instance returned by every ``create`` call.
        """
        # ``Optional[...]`` rather than ``SystemApp | None``: the PEP 604 form
        # is evaluated when the function is defined and raises TypeError on
        # Python < 3.10 unless annotations are lazy.
        super().__init__(system_app)
        self.worker_manager = worker_manager

    def create(self) -> WorkerManager:
        """Return the stored worker manager (no new object is created)."""
        return self.worker_manager
worker_manager = WorkerManagerAdapter()
router = APIRouter()
@ -787,6 +803,7 @@ def initialize_worker_manager_in_client(
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
system_app: SystemApp = None,
):
"""Initialize WorkerManager in client.
If run_locally is True:
@ -845,6 +862,8 @@ def initialize_worker_manager_in_client(
if include_router and app:
# mount WorkerManager router
app.include_router(router, prefix="/api")
if system_app:
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
def run_worker_manager(

View File

@ -172,7 +172,9 @@ def _list_supported_models(
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_cls = llm_adapter.model_param_class()
model.enabled = True
params = _get_parameter_descriptions(param_cls)
params = _get_parameter_descriptions(
param_cls, model_name=model_name, model_path=model_path
)
model.params = params
except Exception:
pass

View File

@ -3,6 +3,7 @@ import uuid
import asyncio
import os
import shutil
import logging
from fastapi import (
APIRouter,
Request,
@ -11,6 +12,7 @@ from fastapi import (
Form,
Body,
BackgroundTasks,
Depends,
)
from fastapi.responses import StreamingResponse
@ -18,7 +20,7 @@ from fastapi.exceptions import RequestValidationError
from typing import List
import tempfile
from pilot.componet import ComponetType
from pilot.component import ComponentType
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService()
model_semaphore = None
@ -90,6 +95,20 @@ def knowledge_list():
return params
def get_model_controller() -> BaseModelController:
    """Look up the model controller component registered on the system app."""
    return CFG.SYSTEM_APP.get_component(
        ComponentType.MODEL_CONTROLLER, BaseModelController
    )
def get_worker_manager() -> WorkerManager:
    """Fetch the worker manager factory component and return what it creates."""
    factory = CFG.SYSTEM_APP.get_component(
        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
    )
    return factory.create()
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
@router.get("/v1/model/types")
async def model_types(request: Request):
print(f"/controller/model/types")
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
logger.info(f"/controller/model/types")
try:
types = set()
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
models = await controller.get_all_instances(healthy_only=True)
for model in models:
worker_name, worker_type = model.model_name.split("@")
@ -371,6 +385,16 @@ async def model_types(request: Request):
return Result.faild(code="E000X", msg=f"controller model types error {e}")
@router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
    """Return the models supported by the workers, flattened for the web UI.

    Any exception is converted into a failed ``Result`` with code "E000X".
    """
    # NOTE(review): this handler shares its name with the /v1/model/types
    # handler in this module; consider renaming it in a follow-up.
    logger.info("/controller/model/supports")
    try:
        supported = await worker_manager.supported_models()
        flattened = FlatSupportedModel.from_supports(supported)
        return Result.succ(flattened)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}")
async def no_stream_generator(chat):
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")

View File

@ -6,6 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -142,8 +143,11 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import worker_manager
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload):
yield output
except Exception as e:
@ -160,7 +164,11 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import worker_manager
from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload)

View File

@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat):
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
self.knowledge_embedding_client = EmbeddingEngine(

View File

@ -6,7 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.utils.parameter_utils import BaseParameters
@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp):
from pilot.connections.manages.connection_manager import ConnectManager
from pilot.model.cluster import worker_manager
cfg = Config()

View File

@ -1,14 +1,10 @@
from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
import logging
from pilot.configs.model_config import get_device
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from typing import TYPE_CHECKING, Any, Type
from pilot.component import ComponentType, SystemApp
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
if TYPE_CHECKING:
@ -18,7 +14,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def initialize_componets(
def initialize_components(
param: WebWerverParameters,
system_app: SystemApp,
embedding_model_name: str,
@ -39,13 +35,9 @@ def _initialize_embedding_model(
embedding_model_name: str,
embedding_model_path: str,
):
from pilot.model.cluster import worker_manager
if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
else:
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
@ -56,26 +48,28 @@ def _initialize_embedding_model(
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
) -> None:
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._worker_manager = worker_manager
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app):
pass
self.system_app = system_app
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from pilot.model.cluster import WorkerManagerFactory
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
return self._model
def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters

View File

@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.server.base import (
server_init,
WebWerverParameters,
_create_model_start_listener,
)
from pilot.server.componet_configs import initialize_componets
from pilot.server.component_configs import initialize_components
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.server.llm_manage.api import router as llm_manage_api
@ -76,6 +77,7 @@ app.include_router(llm_manage_api, prefix="/api")
# app.include_router(api_v1)
app.include_router(knowledge_router)
app.include_router(prompt_router)
# app.include_router(api_editor_route_v1)
@ -118,7 +120,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light:
@ -133,6 +135,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path,
start_listener=model_start_listener,
system_app=system_app,
)
CFG.NEW_SERVER_MODE = True
@ -146,6 +149,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
controller_addr=CFG.MODEL_SERVER,
local_port=param.port,
start_listener=model_start_listener,
system_app=system_app,
)
CFG.SERVER_LIGHT_MODE = True

View File

@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}")
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(

View File

@ -205,7 +205,7 @@ class KnowledgeService:
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter
from pilot.componet import ComponetType
from pilot.component import ComponentType
from pilot.configs.config import Config
from pilot.model.base import ModelInstance, WorkerApplyType
@ -31,8 +31,8 @@ async def model_list():
try:
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
responses = []
managers = await controller.get_all_instances(
@ -70,8 +70,8 @@ async def model_start(request: WorkerStartupRequest):
try:
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
worker_instance = None
@ -98,8 +98,8 @@ async def model_start(request: WorkerStartupRequest):
try:
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
worker_instance = None

View File

View File

@ -0,0 +1,46 @@
from fastapi import APIRouter, File, UploadFile, Form
from pilot.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest
router = APIRouter()
prompt_manage_service = PromptManageService()
@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
    """Create a prompt from the request payload.

    Returns ``Result.succ([])`` on success; any exception is reported as a
    failed ``Result`` with code "E010X" and the error text.
    """
    # The trace line used to say "/space/add", which did not match this route.
    print(f"/prompt/add params: {request}")
    try:
        prompt_manage_service.create_prompt(request)
        return Result.succ([])
    except Exception as e:
        return Result.faild(code="E010X", msg=f"prompt add error {e}")
@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
    """List prompts matching the request; failures become an "E010X" Result."""
    print(f"/prompt/list params: {request}")
    try:
        prompts = prompt_manage_service.get_prompts(request)
        return Result.succ(prompts)
    except Exception as err:
        return Result.faild(code="E010X", msg=f"prompt list error {err}")
@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
    """Update the prompt named in the request; failures become an "E010X" Result."""
    print(f"/prompt/update params: {request}")
    try:
        outcome = prompt_manage_service.update_prompt(request)
        return Result.succ(outcome)
    except Exception as err:
        return Result.faild(code="E010X", msg=f"prompt update error {err}")
@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
    """Delete the prompt named by ``request.prompt_name``; failures become an "E010X" Result."""
    print(f"/prompt/delete params: {request}")
    try:
        outcome = prompt_manage_service.delete_prompt(request.prompt_name)
        return Result.succ(outcome)
    except Exception as err:
        return Result.faild(code="E010X", msg=f"prompt delete error {err}")

View File

@ -0,0 +1,91 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base):
    """ORM mapping for one row of the ``prompt_manage`` table."""

    __tablename__ = "prompt_manage"

    id = Column(Integer, primary_key=True)
    chat_scene = Column(String(100))
    sub_chat_scene = Column(String(100))
    prompt_type = Column(String(100))
    prompt_name = Column(String(512))
    content = Column(Text)
    user_name = Column(String(128))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        # Adjacent f-strings are concatenated; the spacing (including the
        # missing space before user_name) is kept exactly as before.
        return (
            f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', "
            f"sub_chat_scene='{self.sub_chat_scene}', "
            f"prompt_type='{self.prompt_type}', "
            f"prompt_name='{self.prompt_name}', "
            f"content='{self.content}',"
            f"user_name='{self.user_name}', "
            f"gmt_created='{self.gmt_created}', "
            f"gmt_modified='{self.gmt_modified}')"
        )
class PromptManageDao(BaseDao):
    """Data-access object for ``PromptManageEntity`` rows.

    Every method opens its own session and closes it in a ``finally`` block,
    so a failing query or commit no longer leaks the session.
    """

    def __init__(self):
        super().__init__(
            database="prompt_management", orm_base=Base, create_not_exist_table=True
        )

    def create_prompt(self, prompt: PromptManageRequest):
        """Insert a new prompt row built from the request fields."""
        session = self.Session()
        try:
            # One timestamp so created/modified are identical on insert.
            now = datetime.now()
            prompt_manage = PromptManageEntity(
                chat_scene=prompt.chat_scene,
                sub_chat_scene=prompt.sub_chat_scene,
                prompt_type=prompt.prompt_type,
                prompt_name=prompt.prompt_name,
                content=prompt.content,
                user_name=prompt.user_name,
                gmt_created=now,
                gmt_modified=now,
            )
            session.add(prompt_manage)
            session.commit()
        finally:
            session.close()

    def get_prompts(self, query: PromptManageEntity):
        """Return prompts matching the non-None fields of ``query``.

        Only the attributes ``chat_scene``, ``sub_chat_scene``,
        ``prompt_type``, ``user_name`` and ``prompt_name`` are read from
        ``query``. ``user_name`` is applied as a filter only when
        ``prompt_type`` is "private". Results are ordered by ``gmt_created``
        descending.
        """
        session = self.Session()
        try:
            prompts = session.query(PromptManageEntity)
            if query.chat_scene is not None:
                prompts = prompts.filter(
                    PromptManageEntity.chat_scene == query.chat_scene
                )
            if query.sub_chat_scene is not None:
                prompts = prompts.filter(
                    PromptManageEntity.sub_chat_scene == query.sub_chat_scene
                )
            if query.prompt_type is not None:
                prompts = prompts.filter(
                    PromptManageEntity.prompt_type == query.prompt_type
                )
                if query.prompt_type == "private" and query.user_name is not None:
                    prompts = prompts.filter(
                        PromptManageEntity.user_name == query.user_name
                    )
            if query.prompt_name is not None:
                prompts = prompts.filter(
                    PromptManageEntity.prompt_name == query.prompt_name
                )
            prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
            return prompts.all()
        finally:
            session.close()

    def update_prompt(self, prompt: PromptManageEntity):
        """Merge the given entity's state into the database and commit."""
        session = self.Session()
        try:
            session.merge(prompt)
            session.commit()
        finally:
            session.close()

    def delete_prompt(self, prompt: PromptManageEntity):
        """Delete the given entity; a falsy ``prompt`` is ignored."""
        session = self.Session()
        try:
            if prompt:
                session.delete(prompt)
                session.commit()
        finally:
            session.close()

View File

View File

@ -0,0 +1,24 @@
from typing import List, Optional
from pydantic import BaseModel


class PromptManageRequest(BaseModel):
    """Request payload for the prompt add/list/update/delete endpoints.

    Every field is optional and defaults to None.
    """

    # chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa
    chat_scene: Optional[str] = None

    # sub_chat_scene: sub chat scene
    sub_chat_scene: Optional[str] = None

    # prompt_type: common or private
    prompt_type: Optional[str] = None

    # content: prompt content
    content: Optional[str] = None

    # user_name: user name
    user_name: Optional[str] = None

    # prompt_name: prompt name
    prompt_name: Optional[str] = None

View File

@ -0,0 +1,26 @@
from typing import List, Optional
from pydantic import BaseModel


class PromptQueryResponse(BaseModel):
    """One prompt record as returned by the prompt list endpoint.

    Every field is optional and defaults to None.
    """

    id: Optional[int] = None

    # chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa
    chat_scene: Optional[str] = None

    # sub_chat_scene: sub chat scene
    sub_chat_scene: Optional[str] = None

    # prompt_type: common or private
    prompt_type: Optional[str] = None

    # content: prompt content
    content: Optional[str] = None

    # user_name: user name
    user_name: Optional[str] = None

    # prompt_name: prompt name
    prompt_name: Optional[str] = None

    # NOTE(review): the service assigns the entity's DateTime column values to
    # these fields by attribute; confirm whether they should be typed datetime.
    gmt_created: Optional[str] = None
    gmt_modified: Optional[str] = None

View File

@ -0,0 +1,80 @@
from datetime import datetime
from pilot.server.prompt.request.request import PromptManageRequest
from pilot.server.prompt.request.response import PromptQueryResponse
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao()
class PromptManageService:
    """Business logic for prompts, backed by the module-level ``prompt_manage_dao``."""

    def __init__(self):
        pass

    def create_prompt(self, request: PromptManageRequest):
        """Create a prompt.

        Raises:
            Exception: if a prompt with ``request.prompt_name`` already exists.

        Returns:
            True once the prompt has been stored.
        """
        query = PromptManageRequest(
            prompt_name=request.prompt_name,
        )
        existing = prompt_manage_dao.get_prompts(query)
        if existing:
            raise Exception(f"prompt name:{request.prompt_name} have already named")
        prompt_manage_dao.create_prompt(request)
        return True

    def get_prompts(self, request: PromptManageRequest):
        """Return the prompts matching the request as ``PromptQueryResponse`` objects."""
        query = PromptManageRequest(
            chat_scene=request.chat_scene,
            sub_chat_scene=request.sub_chat_scene,
            prompt_type=request.prompt_type,
            prompt_name=request.prompt_name,
            user_name=request.user_name,
        )
        responses = []
        for prompt in prompt_manage_dao.get_prompts(query):
            # Fields are set by attribute assignment (as before) rather than
            # via the constructor, so the values are copied over as-is.
            res = PromptQueryResponse()
            res.id = prompt.id
            res.chat_scene = prompt.chat_scene
            res.sub_chat_scene = prompt.sub_chat_scene
            res.prompt_type = prompt.prompt_type
            res.content = prompt.content
            res.user_name = prompt.user_name
            res.prompt_name = prompt.prompt_name
            res.gmt_created = prompt.gmt_created
            res.gmt_modified = prompt.gmt_modified
            responses.append(res)
        return responses

    def update_prompt(self, request: PromptManageRequest):
        """Update the prompt identified by ``request.prompt_name``.

        All other fields of the stored prompt are overwritten with the request
        values and ``gmt_modified`` is set to the current time.

        Raises:
            Exception: if zero or more than one prompt has that name.
        """
        query = PromptManageEntity(prompt_name=request.prompt_name)
        prompts = prompt_manage_dao.get_prompts(query)
        if len(prompts) != 1:
            # The message used to say "space", which is the wrong entity.
            raise Exception(
                f"there are no or more than one prompt called {request.prompt_name}"
            )
        prompt = prompts[0]
        prompt.chat_scene = request.chat_scene
        prompt.sub_chat_scene = request.sub_chat_scene
        prompt.prompt_type = request.prompt_type
        prompt.content = request.content
        prompt.user_name = request.user_name
        prompt.gmt_modified = datetime.now()
        return prompt_manage_dao.update_prompt(prompt)

    def delete_prompt(self, prompt_name: str):
        """Delete the prompt called ``prompt_name``.

        Raises:
            Exception: if no prompt with that name exists.
        """
        query = PromptManageEntity(prompt_name=prompt_name)
        prompts = prompt_manage_dao.get_prompts(query)
        if not prompts:
            raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
        # delete prompt
        return prompt_manage_dao.delete_prompt(prompts[0])

View File

@ -2,7 +2,7 @@ import json
import uuid
from pilot.common.schema import DBType
from pilot.componet import SystemApp
from pilot.component import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -36,7 +36,7 @@ class DBSummaryClient:
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type)
embedding_factory = self.system_app.get_componet(
embedding_factory = self.system_app.get_component(
"embedding_factory", EmbeddingFactory
)
embeddings = embedding_factory.create(
@ -94,7 +94,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
@ -117,7 +117,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
embedding_factory = CFG.SYSTEM_APP.get_componet(
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(

View File

@ -12,6 +12,7 @@ class ParameterDescription:
param_type: str
default_value: Optional[Any]
description: str
required: Optional[bool]
valid_values: Optional[List[Any]]
ext_metadata: Dict
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
return type_mapping.get(type_str, str)
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
def _get_parameter_descriptions(
dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
}
default_value = field.default if field.default != MISSING else None
if field.name in kwargs:
default_value = kwargs[field.name]
descriptions.append(
ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default if field.default != MISSING else None,
required=field.default is MISSING,
default_value=default_value,
valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata,
)