chore:merge main

This commit is contained in:
aries_ckt 2023-09-19 11:52:51 +08:00
commit 86ad276244
30 changed files with 473 additions and 97 deletions

View File

@ -0,0 +1,16 @@
CREATE DATABASE prompt_management;
use prompt_management;
CREATE TABLE `prompt_manage` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private',
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名',
`gmt_created` datetime DEFAULT NULL,
`gmt_modified` datetime DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表';

View File

@ -42,15 +42,16 @@ class LifeCycle:
pass pass
class ComponetType(str, Enum): class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller" MODEL_CONTROLLER = "dbgpt_model_controller"
class BaseComponet(LifeCycle, ABC): class BaseComponent(LifeCycle, ABC):
"""Abstract Base Component class. All custom components should extend this.""" """Abstract Base Component class. All custom components should extend this."""
name = "base_dbgpt_componet" name = "base_dbgpt_component"
def __init__(self, system_app: Optional[SystemApp] = None): def __init__(self, system_app: Optional[SystemApp] = None):
if system_app is not None: if system_app is not None:
@ -66,15 +67,15 @@ class BaseComponet(LifeCycle, ABC):
pass pass
T = TypeVar("T", bound=BaseComponet) T = TypeVar("T", bound=BaseComponent)
class SystemApp(LifeCycle): class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components.""" """Main System Application class that manages the lifecycle and registration of components."""
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None: def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
self.componets: Dict[ self.components: Dict[
str, BaseComponet str, BaseComponent
] = {} # Dictionary to store registered components. ] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app self._asgi_app = asgi_app
@ -83,58 +84,60 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app.""" """Returns the internal ASGI app."""
return self._asgi_app return self._asgi_app
def register(self, componet: Type[BaseComponet], *args, **kwargs): def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type.""" """Register a new component by its type."""
instance = componet(self, *args, **kwargs) instance = component(self, *args, **kwargs)
self.register_instance(instance) self.register_instance(instance)
def register_instance(self, instance: T): def register_instance(self, instance: T):
"""Register an already initialized component.""" """Register an already initialized component."""
name = instance.name name = instance.name
if isinstance(name, ComponetType): if isinstance(name, ComponentType):
name = name.value name = name.value
if name in self.componets: if name in self.components:
raise RuntimeError( raise RuntimeError(
f"Componse name {name} already exists: {self.componets[name]}" f"Componse name {name} already exists: {self.components[name]}"
) )
logger.info(f"Register componet with name {name} and instance: {instance}") logger.info(f"Register component with name {name} and instance: {instance}")
self.componets[name] = instance self.components[name] = instance
instance.init_app(self) instance.init_app(self)
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T: def get_component(
self, name: Union[str, ComponentType], component_type: Type[T]
) -> T:
"""Retrieve a registered component by its name and type.""" """Retrieve a registered component by its name and type."""
if isinstance(name, ComponetType): if isinstance(name, ComponentType):
name = name.value name = name.value
component = self.componets.get(name) component = self.components.get(name)
if not component: if not component:
raise ValueError(f"No component found with name {name}") raise ValueError(f"No component found with name {name}")
if not isinstance(component, componet_type): if not isinstance(component, component_type):
raise TypeError(f"Component {name} is not of type {componet_type}") raise TypeError(f"Component {name} is not of type {component_type}")
return component return component
def before_start(self): def before_start(self):
"""Invoke the before_start hooks for all registered components.""" """Invoke the before_start hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
v.before_start() v.before_start()
async def async_before_start(self): async def async_before_start(self):
"""Asynchronously invoke the before_start hooks for all registered components.""" """Asynchronously invoke the before_start hooks for all registered components."""
tasks = [v.async_before_start() for _, v in self.componets.items()] tasks = [v.async_before_start() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def after_start(self): def after_start(self):
"""Invoke the after_start hooks for all registered components.""" """Invoke the after_start hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
v.after_start() v.after_start()
async def async_after_start(self): async def async_after_start(self):
"""Asynchronously invoke the after_start hooks for all registered components.""" """Asynchronously invoke the after_start hooks for all registered components."""
tasks = [v.async_after_start() for _, v in self.componets.items()] tasks = [v.async_after_start() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def before_stop(self): def before_stop(self):
"""Invoke the before_stop hooks for all registered components.""" """Invoke the before_stop hooks for all registered components."""
for _, v in self.componets.items(): for _, v in self.components.items():
try: try:
v.before_stop() v.before_stop()
except Exception as e: except Exception as e:
@ -142,7 +145,7 @@ class SystemApp(LifeCycle):
async def async_before_stop(self): async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components.""" """Asynchronously invoke the before_stop hooks for all registered components."""
tasks = [v.async_before_stop() for _, v in self.componets.items()] tasks = [v.async_before_stop() for _, v in self.components.items()]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def _build(self): def _build(self):

View File

@ -189,7 +189,7 @@ class Config(metaclass=Singleton):
### Log level ### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO") self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.componet import SystemApp from pilot.component import SystemApp
self.SYSTEM_APP: SystemApp = None self.SYSTEM_APP: SystemApp = None

View File

@ -4,7 +4,7 @@ import asyncio
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.connections.rdbms.conn_mysql import MySQLConnect from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.base import BaseConnect from pilot.connections.base import BaseConnect

View File

@ -1,13 +1,10 @@
import re import re
from typing import Optional, Any from typing import Optional, Any
from pyspark import SQLContext
from sqlalchemy import text from sqlalchemy import text
from pilot.connections.rdbms.base import RDBMSDatabase
from pyspark.sql import SparkSession, DataFrame from pyspark.sql import SparkSession, DataFrame
from sqlalchemy import create_engine
class SparkConnect: class SparkConnect:

View File

@ -2,13 +2,13 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING from typing import Any, Type, TYPE_CHECKING
from pilot.componet import BaseComponet from pilot.component import BaseComponent
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
class EmbeddingFactory(BaseComponet, ABC): class EmbeddingFactory(BaseComponent, ABC):
name = "embedding_factory" name = "embedding_factory"
@abstractmethod @abstractmethod

View File

@ -3,7 +3,7 @@
from enum import Enum from enum import Enum
from typing import TypedDict, Optional, Dict, List from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass from dataclasses import dataclass, asdict
from datetime import datetime from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription from pilot.utils.parameter_utils import ParameterDescription
@ -84,3 +84,25 @@ class WorkerSupportedModel:
] ]
worker_data["models"] = models worker_data["models"] = models
return cls(**worker_data) return cls(**worker_data)
@dataclass
class FlatSupportedModel(SupportedModel):
"""For web"""
host: str
port: int
@staticmethod
def from_supports(
supports: List[WorkerSupportedModel],
) -> List["FlatSupportedModel"]:
results = []
for s in supports:
host, port, models = s.host, s.port, s.models
for m in models:
kwargs = asdict(m)
kwargs["host"] = host
kwargs["port"] = port
results.append(FlatSupportedModel(**kwargs))
return results

View File

@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
WorkerParameterRequest, WorkerParameterRequest,
WorkerStartupRequest, WorkerStartupRequest,
) )
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker from pilot.model.cluster.worker.default_worker import DefaultModelWorker
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import ( from pilot.model.cluster.controller.controller import (
ModelRegistryClient, ModelRegistryClient,
run_model_controller, run_model_controller,
BaseModelController,
) )
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -28,6 +30,7 @@ __all__ = [
"WorkerApplyRequest", "WorkerApplyRequest",
"WorkerParameterRequest", "WorkerParameterRequest",
"WorkerStartupRequest", "WorkerStartupRequest",
"WorkerManagerFactory",
"ModelWorker", "ModelWorker",
"DefaultModelWorker", "DefaultModelWorker",
"worker_manager", "worker_manager",

View File

@ -4,7 +4,7 @@ import logging
from typing import List from typing import List
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from pilot.componet import BaseComponet, ComponetType, SystemApp from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters from pilot.model.parameter import ModelControllerParameters
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
@ -15,8 +15,8 @@ from pilot.utils.api_utils import (
) )
class BaseModelController(BaseComponet, ABC): class BaseModelController(BaseComponent, ABC):
name = ComponetType.MODEL_CONTROLLER name = ComponentType.MODEL_CONTROLLER
def init_app(self, system_app: SystemApp): def init_app(self, system_app: SystemApp):
pass pass

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from concurrent.futures import Future from concurrent.futures import Future
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
@ -104,3 +105,14 @@ class WorkerManager(ABC):
self, worker_type: str, model_name: str self, worker_type: str, model_name: str
) -> List[ParameterDescription]: ) -> List[ParameterDescription]:
"""Get parameter descriptions of model""" """Get parameter descriptions of model"""
class WorkerManagerFactory(BaseComponent, ABC):
name = ComponentType.WORKER_MANAGER_FACTORY.value
def init_app(self, system_app: SystemApp):
pass
@abstractmethod
def create(self) -> WorkerManager:
"""Create worker manager"""

View File

@ -7,10 +7,11 @@ import random
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.model.base import ( from pilot.model.base import (
ModelInstance, ModelInstance,
@ -23,7 +24,11 @@ from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData from pilot.model.cluster.manager_base import (
WorkerManager,
WorkerRunData,
WorkerManagerFactory,
)
from pilot.model.cluster.base import * from pilot.model.cluster.base import *
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.utils.parameter_utils import ( from pilot.utils.parameter_utils import (
@ -548,6 +553,17 @@ class WorkerManagerAdapter(WorkerManager):
return await self.worker_manager.parameter_descriptions(worker_type, model_name) return await self.worker_manager.parameter_descriptions(worker_type, model_name)
class _DefaultWorkerManagerFactory(WorkerManagerFactory):
def __init__(
self, system_app: SystemApp | None = None, worker_manager: WorkerManager = None
):
super().__init__(system_app)
self.worker_manager = worker_manager
def create(self) -> WorkerManager:
return self.worker_manager
worker_manager = WorkerManagerAdapter() worker_manager = WorkerManagerAdapter()
router = APIRouter() router = APIRouter()
@ -787,6 +803,7 @@ def initialize_worker_manager_in_client(
embedding_model_name: str = None, embedding_model_name: str = None,
embedding_model_path: str = None, embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None, start_listener: Callable[["WorkerManager"], None] = None,
system_app: SystemApp = None,
): ):
"""Initialize WorkerManager in client. """Initialize WorkerManager in client.
If run_locally is True: If run_locally is True:
@ -845,6 +862,8 @@ def initialize_worker_manager_in_client(
if include_router and app: if include_router and app:
# mount WorkerManager router # mount WorkerManager router
app.include_router(router, prefix="/api") app.include_router(router, prefix="/api")
if system_app:
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
def run_worker_manager( def run_worker_manager(

View File

@ -172,7 +172,9 @@ def _list_supported_models(
llm_adapter = get_llm_model_adapter(model_name, model_path) llm_adapter = get_llm_model_adapter(model_name, model_path)
param_cls = llm_adapter.model_param_class() param_cls = llm_adapter.model_param_class()
model.enabled = True model.enabled = True
params = _get_parameter_descriptions(param_cls) params = _get_parameter_descriptions(
param_cls, model_name=model_name, model_path=model_path
)
model.params = params model.params = params
except Exception: except Exception:
pass pass

View File

@ -3,6 +3,7 @@ import uuid
import asyncio import asyncio
import os import os
import shutil import shutil
import logging
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Request, Request,
@ -11,6 +12,7 @@ from fastapi import (
Form, Form,
Body, Body,
BackgroundTasks, BackgroundTasks,
Depends,
) )
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -18,7 +20,7 @@ from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
import tempfile import tempfile
from pilot.componet import ComponetType from pilot.component import ComponentType
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
Result, Result,
ConversationVo, ConversationVo,
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
CHAT_FACTORY = ChatFactory() CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log") logger = logging.getLogger(__name__)
knowledge_service = KnowledgeService() knowledge_service = KnowledgeService()
model_semaphore = None model_semaphore = None
@ -90,6 +95,20 @@ def knowledge_list():
return params return params
def get_model_controller() -> BaseModelController:
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
)
return controller
def get_worker_manager() -> WorkerManager:
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return worker_manager
@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) @router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list(): async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
@router.get("/v1/model/types") @router.get("/v1/model/types")
async def model_types(request: Request): async def model_types(controller: BaseModelController = Depends(get_model_controller)):
print(f"/controller/model/types") logger.info(f"/controller/model/types")
try: try:
types = set() types = set()
from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet(
ComponetType.MODEL_CONTROLLER, BaseModelController
)
models = await controller.get_all_instances(healthy_only=True) models = await controller.get_all_instances(healthy_only=True)
for model in models: for model in models:
worker_name, worker_type = model.model_name.split("@") worker_name, worker_type = model.model_name.split("@")
@ -371,6 +385,16 @@ async def model_types(request: Request):
return Result.faild(code="E000X", msg=f"controller model types error {e}") return Result.faild(code="E000X", msg=f"controller model types error {e}")
@router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
logger.info(f"/controller/model/supports")
try:
models = await worker_manager.supported_models()
return Result.succ(FlatSupportedModel.from_supports(models))
except Exception as e:
return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}")
async def no_stream_generator(chat): async def no_stream_generator(chat):
msg = await chat.nostream_call() msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n") msg = msg.replace("\n", "\\n")

View File

@ -6,6 +6,7 @@ from typing import Any, List, Dict
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory
@ -142,8 +143,11 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
from pilot.model.cluster import worker_manager from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
async for output in worker_manager.generate_stream(payload): async for output in worker_manager.generate_stream(payload):
yield output yield output
except Exception as e: except Exception as e:
@ -160,7 +164,11 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
from pilot.model.cluster import worker_manager from pilot.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
model_output = await worker_manager.generate(payload) model_output = await worker_manager.generate(payload)

View File

@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat):
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(

View File

@ -6,7 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.utils.parameter_utils import BaseParameters from pilot.utils.parameter_utils import BaseParameters
@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp): def _create_model_start_listener(system_app: SystemApp):
from pilot.connections.manages.connection_manager import ConnectManager from pilot.connections.manages.connection_manager import ConnectManager
from pilot.model.cluster import worker_manager
cfg = Config() cfg = Config()

View File

@ -1,14 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
import logging import logging
from pilot.configs.model_config import get_device from typing import TYPE_CHECKING, Any, Type
from pilot.embedding_engine.embedding_factory import (
EmbeddingFactory, from pilot.component import ComponentType, SystemApp
DefaultEmbeddingFactory, from pilot.embedding_engine.embedding_factory import EmbeddingFactory
)
from pilot.server.base import WebWerverParameters from pilot.server.base import WebWerverParameters
if TYPE_CHECKING: if TYPE_CHECKING:
@ -18,7 +14,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def initialize_componets( def initialize_components(
param: WebWerverParameters, param: WebWerverParameters,
system_app: SystemApp, system_app: SystemApp,
embedding_model_name: str, embedding_model_name: str,
@ -39,13 +35,9 @@ def _initialize_embedding_model(
embedding_model_name: str, embedding_model_name: str,
embedding_model_path: str, embedding_model_path: str,
): ):
from pilot.model.cluster import worker_manager
if param.remote_embedding: if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory") logger.info("Register remote RemoteEmbeddingFactory")
system_app.register( system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
else: else:
logger.info(f"Register local LocalEmbeddingFactory") logger.info(f"Register local LocalEmbeddingFactory")
system_app.register( system_app.register(
@ -56,26 +48,28 @@ def _initialize_embedding_model(
class RemoteEmbeddingFactory(EmbeddingFactory): class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__( def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
) -> None:
super().__init__(system_app=system_app) super().__init__(system_app=system_app)
self._worker_manager = worker_manager
self._default_model_name = model_name self._default_model_name = model_name
self.kwargs = kwargs self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app): def init_app(self, system_app):
pass self.system_app = system_app
def create( def create(
self, model_name: str = None, embedding_cls: Type = None self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings": ) -> "Embeddings":
from pilot.model.cluster import WorkerManagerFactory
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls: if embedding_cls:
raise NotImplementedError raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args # Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager) return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory): class LocalEmbeddingFactory(EmbeddingFactory):
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
return self._model return self._model
def _load_model(self) -> "Embeddings": def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters self._default_model_name, EmbeddingModelParameters

View File

@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH)
import signal import signal
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.server.base import ( from pilot.server.base import (
server_init, server_init,
WebWerverParameters, WebWerverParameters,
_create_model_start_listener, _create_model_start_listener,
) )
from pilot.server.componet_configs import initialize_componets from pilot.server.component_configs import initialize_components
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications from fastapi import FastAPI, applications
@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.server.llm_manage.api import router as llm_manage_api from pilot.server.llm_manage.api import router as llm_manage_api
@ -76,6 +77,7 @@ app.include_router(llm_manage_api, prefix="/api")
# app.include_router(api_v1) # app.include_router(api_v1)
app.include_router(knowledge_router) app.include_router(knowledge_router)
app.include_router(prompt_router)
# app.include_router(api_editor_route_v1) # app.include_router(api_editor_route_v1)
@ -118,7 +120,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
server_init(param, system_app) server_init(param, system_app)
model_start_listener = _create_model_start_listener(system_app) model_start_listener = _create_model_start_listener(system_app)
initialize_componets(param, system_app, embedding_model_name, embedding_model_path) initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not param.light: if not param.light:
@ -133,6 +135,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
embedding_model_name=embedding_model_name, embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path, embedding_model_path=embedding_model_path,
start_listener=model_start_listener, start_listener=model_start_listener,
system_app=system_app,
) )
CFG.NEW_SERVER_MODE = True CFG.NEW_SERVER_MODE = True
@ -146,6 +149,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
controller_addr=CFG.MODEL_SERVER, controller_addr=CFG.MODEL_SERVER,
local_port=param.port, local_port=param.port,
start_listener=model_start_listener, start_listener=model_start_listener,
system_app=system_app,
) )
CFG.SERVER_LIGHT_MODE = True CFG.SERVER_LIGHT_MODE = True

View File

@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query") @router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest): def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}") print(f"Received params: {space_name}, {query_request}")
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
client = EmbeddingEngine( client = EmbeddingEngine(

View File

@ -205,7 +205,7 @@ class KnowledgeService:
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
client = EmbeddingEngine( client = EmbeddingEngine(

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from pilot.componet import ComponetType from pilot.component import ComponentType
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.model.base import ModelInstance, WorkerApplyType from pilot.model.base import ModelInstance, WorkerApplyType
@ -31,8 +31,8 @@ async def model_list():
try: try:
from pilot.model.cluster.controller.controller import BaseModelController from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet( controller = CFG.SYSTEM_APP.get_component(
ComponetType.MODEL_CONTROLLER, BaseModelController ComponentType.MODEL_CONTROLLER, BaseModelController
) )
responses = [] responses = []
managers = await controller.get_all_instances( managers = await controller.get_all_instances(
@ -70,8 +70,8 @@ async def model_start(request: WorkerStartupRequest):
try: try:
from pilot.model.cluster.controller.controller import BaseModelController from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet( controller = CFG.SYSTEM_APP.get_component(
ComponetType.MODEL_CONTROLLER, BaseModelController ComponentType.MODEL_CONTROLLER, BaseModelController
) )
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True) instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
worker_instance = None worker_instance = None
@ -98,8 +98,8 @@ async def model_start(request: WorkerStartupRequest):
try: try:
from pilot.model.cluster.controller.controller import BaseModelController from pilot.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_componet( controller = CFG.SYSTEM_APP.get_component(
ComponetType.MODEL_CONTROLLER, BaseModelController ComponentType.MODEL_CONTROLLER, BaseModelController
) )
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True) instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
worker_instance = None worker_instance = None

View File

View File

@ -0,0 +1,46 @@
from fastapi import APIRouter, File, UploadFile, Form
from pilot.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest
router = APIRouter()
prompt_manage_service = PromptManageService()
@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
print(f"/space/add params: {request}")
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt add error {e}")
@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
print(f"/prompt/list params: {request}")
try:
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt list error {e}")
@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
print(f"/prompt/update params: {request}")
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt update error {e}")
@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt delete error {e}")

View File

@ -0,0 +1,91 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base):
__tablename__ = "prompt_manage"
id = Column(Integer, primary_key=True)
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True
)
def create_prompt(self, prompt: PromptManageRequest):
session = self.Session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
prompt_type=prompt.prompt_type,
prompt_name=prompt.prompt_name,
content=prompt.content,
user_name=prompt.user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
session.add(prompt_manage)
session.commit()
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.Session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
if query.sub_chat_scene is not None:
prompts = prompts.filter(
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
)
if query.prompt_type is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_type == query.prompt_type
)
if query.prompt_type == "private" and query.user_name is not None:
prompts = prompts.filter(
PromptManageEntity.user_name == query.user_name
)
if query.prompt_name is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
session.close()
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session()
if prompt:
session.delete(prompt)
session.commit()
session.close()

View File

View File

@ -0,0 +1,24 @@
from typing import List
from pydantic import BaseModel
class PromptManageRequest(BaseModel):
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None

View File

@ -0,0 +1,26 @@
from typing import List
from pydantic import BaseModel
class PromptQueryResponse(BaseModel):
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None

View File

@ -0,0 +1,80 @@
from datetime import datetime
from pilot.server.prompt.request.request import PromptManageRequest
from pilot.server.prompt.request.response import PromptQueryResponse
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao()
class PromptManageService:
def __init__(self):
pass
"""create prompt"""
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
)
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
raise Exception(f"prompt name:{request.prompt_name} have already named")
prompt_manage_dao.create_prompt(request)
return True
"""get prompts"""
def get_prompts(self, request: PromptManageRequest):
query = PromptManageRequest(
chat_scene=request.chat_scene,
sub_chat_scene=request.sub_chat_scene,
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
for prompt in prompts:
res = PromptQueryResponse()
res.id = prompt.id
res.chat_scene = prompt.chat_scene
res.sub_chat_scene = prompt.sub_chat_scene
res.prompt_type = prompt.prompt_type
res.content = prompt.content
res.user_name = prompt.user_name
res.prompt_name = prompt.prompt_name
res.gmt_created = prompt.gmt_created
res.gmt_modified = prompt.gmt_modified
responses.append(res)
return responses
"""update prompt"""
def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
f"there are no or more than one space called {request.prompt_name}"
)
prompt = prompts[0]
prompt.chat_scene = request.chat_scene
prompt.sub_chat_scene = request.sub_chat_scene
prompt.prompt_type = request.prompt_type
prompt.content = request.content
prompt.user_name = request.user_name
prompt.gmt_modified = datetime.now()
return prompt_manage_dao.update_prompt(prompt)
"""delete prompt"""
def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
# delete prompt
prompt = prompts[0]
return prompt_manage_dao.delete_prompt(prompt)

View File

@ -2,7 +2,7 @@ import json
import uuid import uuid
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.componet import SystemApp from pilot.component import SystemApp
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -36,7 +36,7 @@ class DBSummaryClient:
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type) db_summary_client = RdbmsSummary(dbname, db_type)
embedding_factory = self.system_app.get_componet( embedding_factory = self.system_app.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
embeddings = embedding_factory.create( embeddings = embedding_factory.create(
@ -94,7 +94,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(
@ -117,7 +117,7 @@ class DBSummaryClient:
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
embedding_factory = CFG.SYSTEM_APP.get_componet( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(

View File

@ -12,6 +12,7 @@ class ParameterDescription:
param_type: str param_type: str
default_value: Optional[Any] default_value: Optional[Any]
description: str description: str
required: Optional[bool]
valid_values: Optional[List[Any]] valid_values: Optional[List[Any]]
ext_metadata: Dict ext_metadata: Dict
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
return type_mapping.get(type_str, str) return type_mapping.get(type_str, str)
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: def _get_parameter_descriptions(
dataclass_type: Type, **kwargs
) -> List[ParameterDescription]:
descriptions = [] descriptions = []
for field in fields(dataclass_type): for field in fields(dataclass_type):
ext_metadata = { ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"] k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
} }
default_value = field.default if field.default != MISSING else None
if field.name in kwargs:
default_value = kwargs[field.name]
descriptions.append( descriptions.append(
ParameterDescription( ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}", param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name, param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type), param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None), description=field.metadata.get("help", None),
default_value=field.default if field.default != MISSING else None, required=field.default is MISSING,
default_value=default_value,
valid_values=field.metadata.get("valid_values", None), valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata, ext_metadata=ext_metadata,
) )