mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 20:28:07 +00:00
chore:merge main
This commit is contained in:
commit
86ad276244
16
assets/schema/prompt_management.sql
Normal file
16
assets/schema/prompt_management.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
CREATE DATABASE prompt_management;
|
||||||
|
use prompt_management;
|
||||||
|
CREATE TABLE `prompt_manage` (
|
||||||
|
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||||
|
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景',
|
||||||
|
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景',
|
||||||
|
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private',
|
||||||
|
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字',
|
||||||
|
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容',
|
||||||
|
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名',
|
||||||
|
`gmt_created` datetime DEFAULT NULL,
|
||||||
|
`gmt_modified` datetime DEFAULT NULL,
|
||||||
|
PRIMARY KEY (`id`),
|
||||||
|
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
|
||||||
|
KEY `gmt_created_idx` (`gmt_created`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表';
|
@ -42,15 +42,16 @@ class LifeCycle:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ComponetType(str, Enum):
|
class ComponentType(str, Enum):
|
||||||
WORKER_MANAGER = "dbgpt_worker_manager"
|
WORKER_MANAGER = "dbgpt_worker_manager"
|
||||||
|
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
|
||||||
MODEL_CONTROLLER = "dbgpt_model_controller"
|
MODEL_CONTROLLER = "dbgpt_model_controller"
|
||||||
|
|
||||||
|
|
||||||
class BaseComponet(LifeCycle, ABC):
|
class BaseComponent(LifeCycle, ABC):
|
||||||
"""Abstract Base Component class. All custom components should extend this."""
|
"""Abstract Base Component class. All custom components should extend this."""
|
||||||
|
|
||||||
name = "base_dbgpt_componet"
|
name = "base_dbgpt_component"
|
||||||
|
|
||||||
def __init__(self, system_app: Optional[SystemApp] = None):
|
def __init__(self, system_app: Optional[SystemApp] = None):
|
||||||
if system_app is not None:
|
if system_app is not None:
|
||||||
@ -66,15 +67,15 @@ class BaseComponet(LifeCycle, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseComponet)
|
T = TypeVar("T", bound=BaseComponent)
|
||||||
|
|
||||||
|
|
||||||
class SystemApp(LifeCycle):
|
class SystemApp(LifeCycle):
|
||||||
"""Main System Application class that manages the lifecycle and registration of components."""
|
"""Main System Application class that manages the lifecycle and registration of components."""
|
||||||
|
|
||||||
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
|
def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
|
||||||
self.componets: Dict[
|
self.components: Dict[
|
||||||
str, BaseComponet
|
str, BaseComponent
|
||||||
] = {} # Dictionary to store registered components.
|
] = {} # Dictionary to store registered components.
|
||||||
self._asgi_app = asgi_app
|
self._asgi_app = asgi_app
|
||||||
|
|
||||||
@ -83,58 +84,60 @@ class SystemApp(LifeCycle):
|
|||||||
"""Returns the internal ASGI app."""
|
"""Returns the internal ASGI app."""
|
||||||
return self._asgi_app
|
return self._asgi_app
|
||||||
|
|
||||||
def register(self, componet: Type[BaseComponet], *args, **kwargs):
|
def register(self, component: Type[BaseComponent], *args, **kwargs):
|
||||||
"""Register a new component by its type."""
|
"""Register a new component by its type."""
|
||||||
instance = componet(self, *args, **kwargs)
|
instance = component(self, *args, **kwargs)
|
||||||
self.register_instance(instance)
|
self.register_instance(instance)
|
||||||
|
|
||||||
def register_instance(self, instance: T):
|
def register_instance(self, instance: T):
|
||||||
"""Register an already initialized component."""
|
"""Register an already initialized component."""
|
||||||
name = instance.name
|
name = instance.name
|
||||||
if isinstance(name, ComponetType):
|
if isinstance(name, ComponentType):
|
||||||
name = name.value
|
name = name.value
|
||||||
if name in self.componets:
|
if name in self.components:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Componse name {name} already exists: {self.componets[name]}"
|
f"Componse name {name} already exists: {self.components[name]}"
|
||||||
)
|
)
|
||||||
logger.info(f"Register componet with name {name} and instance: {instance}")
|
logger.info(f"Register component with name {name} and instance: {instance}")
|
||||||
self.componets[name] = instance
|
self.components[name] = instance
|
||||||
instance.init_app(self)
|
instance.init_app(self)
|
||||||
|
|
||||||
def get_componet(self, name: Union[str, ComponetType], componet_type: Type[T]) -> T:
|
def get_component(
|
||||||
|
self, name: Union[str, ComponentType], component_type: Type[T]
|
||||||
|
) -> T:
|
||||||
"""Retrieve a registered component by its name and type."""
|
"""Retrieve a registered component by its name and type."""
|
||||||
if isinstance(name, ComponetType):
|
if isinstance(name, ComponentType):
|
||||||
name = name.value
|
name = name.value
|
||||||
component = self.componets.get(name)
|
component = self.components.get(name)
|
||||||
if not component:
|
if not component:
|
||||||
raise ValueError(f"No component found with name {name}")
|
raise ValueError(f"No component found with name {name}")
|
||||||
if not isinstance(component, componet_type):
|
if not isinstance(component, component_type):
|
||||||
raise TypeError(f"Component {name} is not of type {componet_type}")
|
raise TypeError(f"Component {name} is not of type {component_type}")
|
||||||
return component
|
return component
|
||||||
|
|
||||||
def before_start(self):
|
def before_start(self):
|
||||||
"""Invoke the before_start hooks for all registered components."""
|
"""Invoke the before_start hooks for all registered components."""
|
||||||
for _, v in self.componets.items():
|
for _, v in self.components.items():
|
||||||
v.before_start()
|
v.before_start()
|
||||||
|
|
||||||
async def async_before_start(self):
|
async def async_before_start(self):
|
||||||
"""Asynchronously invoke the before_start hooks for all registered components."""
|
"""Asynchronously invoke the before_start hooks for all registered components."""
|
||||||
tasks = [v.async_before_start() for _, v in self.componets.items()]
|
tasks = [v.async_before_start() for _, v in self.components.items()]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def after_start(self):
|
def after_start(self):
|
||||||
"""Invoke the after_start hooks for all registered components."""
|
"""Invoke the after_start hooks for all registered components."""
|
||||||
for _, v in self.componets.items():
|
for _, v in self.components.items():
|
||||||
v.after_start()
|
v.after_start()
|
||||||
|
|
||||||
async def async_after_start(self):
|
async def async_after_start(self):
|
||||||
"""Asynchronously invoke the after_start hooks for all registered components."""
|
"""Asynchronously invoke the after_start hooks for all registered components."""
|
||||||
tasks = [v.async_after_start() for _, v in self.componets.items()]
|
tasks = [v.async_after_start() for _, v in self.components.items()]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def before_stop(self):
|
def before_stop(self):
|
||||||
"""Invoke the before_stop hooks for all registered components."""
|
"""Invoke the before_stop hooks for all registered components."""
|
||||||
for _, v in self.componets.items():
|
for _, v in self.components.items():
|
||||||
try:
|
try:
|
||||||
v.before_stop()
|
v.before_stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -142,7 +145,7 @@ class SystemApp(LifeCycle):
|
|||||||
|
|
||||||
async def async_before_stop(self):
|
async def async_before_stop(self):
|
||||||
"""Asynchronously invoke the before_stop hooks for all registered components."""
|
"""Asynchronously invoke the before_stop hooks for all registered components."""
|
||||||
tasks = [v.async_before_stop() for _, v in self.componets.items()]
|
tasks = [v.async_before_stop() for _, v in self.components.items()]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def _build(self):
|
def _build(self):
|
@ -189,7 +189,7 @@ class Config(metaclass=Singleton):
|
|||||||
### Log level
|
### Log level
|
||||||
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
from pilot.componet import SystemApp
|
from pilot.component import SystemApp
|
||||||
|
|
||||||
self.SYSTEM_APP: SystemApp = None
|
self.SYSTEM_APP: SystemApp = None
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import asyncio
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
||||||
from pilot.common.schema import DBType
|
from pilot.common.schema import DBType
|
||||||
from pilot.componet import SystemApp
|
from pilot.component import SystemApp
|
||||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||||
from pilot.connections.base import BaseConnect
|
from pilot.connections.base import BaseConnect
|
||||||
|
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
from pyspark import SQLContext
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
|
||||||
from pyspark.sql import SparkSession, DataFrame
|
from pyspark.sql import SparkSession, DataFrame
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
|
|
||||||
|
|
||||||
class SparkConnect:
|
class SparkConnect:
|
||||||
|
@ -2,13 +2,13 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Type, TYPE_CHECKING
|
from typing import Any, Type, TYPE_CHECKING
|
||||||
|
|
||||||
from pilot.componet import BaseComponet
|
from pilot.component import BaseComponent
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFactory(BaseComponet, ABC):
|
class EmbeddingFactory(BaseComponent, ABC):
|
||||||
name = "embedding_factory"
|
name = "embedding_factory"
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypedDict, Optional, Dict, List
|
from typing import TypedDict, Optional, Dict, List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pilot.utils.parameter_utils import ParameterDescription
|
from pilot.utils.parameter_utils import ParameterDescription
|
||||||
|
|
||||||
@ -84,3 +84,25 @@ class WorkerSupportedModel:
|
|||||||
]
|
]
|
||||||
worker_data["models"] = models
|
worker_data["models"] = models
|
||||||
return cls(**worker_data)
|
return cls(**worker_data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlatSupportedModel(SupportedModel):
|
||||||
|
"""For web"""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_supports(
|
||||||
|
supports: List[WorkerSupportedModel],
|
||||||
|
) -> List["FlatSupportedModel"]:
|
||||||
|
results = []
|
||||||
|
for s in supports:
|
||||||
|
host, port, models = s.host, s.port, s.models
|
||||||
|
for m in models:
|
||||||
|
kwargs = asdict(m)
|
||||||
|
kwargs["host"] = host
|
||||||
|
kwargs["port"] = port
|
||||||
|
results.append(FlatSupportedModel(**kwargs))
|
||||||
|
return results
|
||||||
|
@ -5,6 +5,7 @@ from pilot.model.cluster.base import (
|
|||||||
WorkerParameterRequest,
|
WorkerParameterRequest,
|
||||||
WorkerStartupRequest,
|
WorkerStartupRequest,
|
||||||
)
|
)
|
||||||
|
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ from pilot.model.cluster.registry import ModelRegistry
|
|||||||
from pilot.model.cluster.controller.controller import (
|
from pilot.model.cluster.controller.controller import (
|
||||||
ModelRegistryClient,
|
ModelRegistryClient,
|
||||||
run_model_controller,
|
run_model_controller,
|
||||||
|
BaseModelController,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
||||||
@ -28,6 +30,7 @@ __all__ = [
|
|||||||
"WorkerApplyRequest",
|
"WorkerApplyRequest",
|
||||||
"WorkerParameterRequest",
|
"WorkerParameterRequest",
|
||||||
"WorkerStartupRequest",
|
"WorkerStartupRequest",
|
||||||
|
"WorkerManagerFactory",
|
||||||
"ModelWorker",
|
"ModelWorker",
|
||||||
"DefaultModelWorker",
|
"DefaultModelWorker",
|
||||||
"worker_manager",
|
"worker_manager",
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from pilot.componet import BaseComponet, ComponetType, SystemApp
|
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||||
from pilot.model.base import ModelInstance
|
from pilot.model.base import ModelInstance
|
||||||
from pilot.model.parameter import ModelControllerParameters
|
from pilot.model.parameter import ModelControllerParameters
|
||||||
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||||
@ -15,8 +15,8 @@ from pilot.utils.api_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseModelController(BaseComponet, ABC):
|
class BaseModelController(BaseComponent, ABC):
|
||||||
name = ComponetType.MODEL_CONTROLLER
|
name = ComponentType.MODEL_CONTROLLER
|
||||||
|
|
||||||
def init_app(self, system_app: SystemApp):
|
def init_app(self, system_app: SystemApp):
|
||||||
pass
|
pass
|
||||||
|
@ -4,6 +4,7 @@ from typing import List, Optional, Dict, Iterator, Callable
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
|
from pilot.component import BaseComponent, ComponentType, SystemApp
|
||||||
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
|
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
|
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
|
||||||
@ -104,3 +105,14 @@ class WorkerManager(ABC):
|
|||||||
self, worker_type: str, model_name: str
|
self, worker_type: str, model_name: str
|
||||||
) -> List[ParameterDescription]:
|
) -> List[ParameterDescription]:
|
||||||
"""Get parameter descriptions of model"""
|
"""Get parameter descriptions of model"""
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerManagerFactory(BaseComponent, ABC):
|
||||||
|
name = ComponentType.WORKER_MANAGER_FACTORY.value
|
||||||
|
|
||||||
|
def init_app(self, system_app: SystemApp):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(self) -> WorkerManager:
|
||||||
|
"""Create worker manager"""
|
||||||
|
@ -7,10 +7,11 @@ import random
|
|||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Awaitable, Callable, Dict, Iterator, List
|
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pilot.component import SystemApp
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import (
|
from pilot.model.base import (
|
||||||
ModelInstance,
|
ModelInstance,
|
||||||
@ -23,7 +24,11 @@ from pilot.model.cluster.registry import ModelRegistry
|
|||||||
from pilot.model.llm_utils import list_supported_models
|
from pilot.model.llm_utils import list_supported_models
|
||||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData
|
from pilot.model.cluster.manager_base import (
|
||||||
|
WorkerManager,
|
||||||
|
WorkerRunData,
|
||||||
|
WorkerManagerFactory,
|
||||||
|
)
|
||||||
from pilot.model.cluster.base import *
|
from pilot.model.cluster.base import *
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.utils.parameter_utils import (
|
from pilot.utils.parameter_utils import (
|
||||||
@ -548,6 +553,17 @@ class WorkerManagerAdapter(WorkerManager):
|
|||||||
return await self.worker_manager.parameter_descriptions(worker_type, model_name)
|
return await self.worker_manager.parameter_descriptions(worker_type, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
class _DefaultWorkerManagerFactory(WorkerManagerFactory):
|
||||||
|
def __init__(
|
||||||
|
self, system_app: SystemApp | None = None, worker_manager: WorkerManager = None
|
||||||
|
):
|
||||||
|
super().__init__(system_app)
|
||||||
|
self.worker_manager = worker_manager
|
||||||
|
|
||||||
|
def create(self) -> WorkerManager:
|
||||||
|
return self.worker_manager
|
||||||
|
|
||||||
|
|
||||||
worker_manager = WorkerManagerAdapter()
|
worker_manager = WorkerManagerAdapter()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -787,6 +803,7 @@ def initialize_worker_manager_in_client(
|
|||||||
embedding_model_name: str = None,
|
embedding_model_name: str = None,
|
||||||
embedding_model_path: str = None,
|
embedding_model_path: str = None,
|
||||||
start_listener: Callable[["WorkerManager"], None] = None,
|
start_listener: Callable[["WorkerManager"], None] = None,
|
||||||
|
system_app: SystemApp = None,
|
||||||
):
|
):
|
||||||
"""Initialize WorkerManager in client.
|
"""Initialize WorkerManager in client.
|
||||||
If run_locally is True:
|
If run_locally is True:
|
||||||
@ -845,6 +862,8 @@ def initialize_worker_manager_in_client(
|
|||||||
if include_router and app:
|
if include_router and app:
|
||||||
# mount WorkerManager router
|
# mount WorkerManager router
|
||||||
app.include_router(router, prefix="/api")
|
app.include_router(router, prefix="/api")
|
||||||
|
if system_app:
|
||||||
|
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
|
||||||
|
|
||||||
|
|
||||||
def run_worker_manager(
|
def run_worker_manager(
|
||||||
|
@ -172,7 +172,9 @@ def _list_supported_models(
|
|||||||
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
||||||
param_cls = llm_adapter.model_param_class()
|
param_cls = llm_adapter.model_param_class()
|
||||||
model.enabled = True
|
model.enabled = True
|
||||||
params = _get_parameter_descriptions(param_cls)
|
params = _get_parameter_descriptions(
|
||||||
|
param_cls, model_name=model_name, model_path=model_path
|
||||||
|
)
|
||||||
model.params = params
|
model.params = params
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -3,6 +3,7 @@ import uuid
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import logging
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Request,
|
Request,
|
||||||
@ -11,6 +12,7 @@ from fastapi import (
|
|||||||
Form,
|
Form,
|
||||||
Body,
|
Body,
|
||||||
BackgroundTasks,
|
BackgroundTasks,
|
||||||
|
Depends,
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
@ -18,7 +20,7 @@ from fastapi.exceptions import RequestValidationError
|
|||||||
from typing import List
|
from typing import List
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from pilot.componet import ComponetType
|
from pilot.component import ComponentType
|
||||||
from pilot.openapi.api_view_model import (
|
from pilot.openapi.api_view_model import (
|
||||||
Result,
|
Result,
|
||||||
ConversationVo,
|
ConversationVo,
|
||||||
@ -41,10 +43,13 @@ from pilot.scene.message import OnceConversation
|
|||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
|
|
||||||
|
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||||
|
from pilot.model.base import FlatSupportedModel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
CHAT_FACTORY = ChatFactory()
|
CHAT_FACTORY = ChatFactory()
|
||||||
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
|
logger = logging.getLogger(__name__)
|
||||||
knowledge_service = KnowledgeService()
|
knowledge_service = KnowledgeService()
|
||||||
|
|
||||||
model_semaphore = None
|
model_semaphore = None
|
||||||
@ -90,6 +95,20 @@ def knowledge_list():
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_controller() -> BaseModelController:
|
||||||
|
controller = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||||
|
)
|
||||||
|
return controller
|
||||||
|
|
||||||
|
|
||||||
|
def get_worker_manager() -> WorkerManager:
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
|
return worker_manager
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||||
async def db_connect_list():
|
async def db_connect_list():
|
||||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||||
@ -351,15 +370,10 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/v1/model/types")
|
@router.get("/v1/model/types")
|
||||||
async def model_types(request: Request):
|
async def model_types(controller: BaseModelController = Depends(get_model_controller)):
|
||||||
print(f"/controller/model/types")
|
logger.info(f"/controller/model/types")
|
||||||
try:
|
try:
|
||||||
types = set()
|
types = set()
|
||||||
from pilot.model.cluster.controller.controller import BaseModelController
|
|
||||||
|
|
||||||
controller = CFG.SYSTEM_APP.get_componet(
|
|
||||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
|
||||||
)
|
|
||||||
models = await controller.get_all_instances(healthy_only=True)
|
models = await controller.get_all_instances(healthy_only=True)
|
||||||
for model in models:
|
for model in models:
|
||||||
worker_name, worker_type = model.model_name.split("@")
|
worker_name, worker_type = model.model_name.split("@")
|
||||||
@ -371,6 +385,16 @@ async def model_types(request: Request):
|
|||||||
return Result.faild(code="E000X", msg=f"controller model types error {e}")
|
return Result.faild(code="E000X", msg=f"controller model types error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/model/supports")
|
||||||
|
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
|
||||||
|
logger.info(f"/controller/model/supports")
|
||||||
|
try:
|
||||||
|
models = await worker_manager.supported_models()
|
||||||
|
return Result.succ(FlatSupportedModel.from_supports(models))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E000X", msg=f"Fetch supportd models error {e}")
|
||||||
|
|
||||||
|
|
||||||
async def no_stream_generator(chat):
|
async def no_stream_generator(chat):
|
||||||
msg = await chat.nostream_call()
|
msg = await chat.nostream_call()
|
||||||
msg = msg.replace("\n", "\\n")
|
msg = msg.replace("\n", "\\n")
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any, List, Dict
|
|||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.component import ComponentType
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
||||||
@ -142,8 +143,11 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import worker_manager
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
async for output in worker_manager.generate_stream(payload):
|
async for output in worker_manager.generate_stream(payload):
|
||||||
yield output
|
yield output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -160,7 +164,11 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import worker_manager
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
|
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
|
|
||||||
model_output = await worker_manager.generate(payload)
|
model_output = await worker_manager.generate(payload)
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class ChatKnowledge(BaseChat):
|
|||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_componet(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
self.knowledge_embedding_client = EmbeddingEngine(
|
self.knowledge_embedding_client = EmbeddingEngine(
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional, Any
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.componet import SystemApp
|
from pilot.component import SystemApp
|
||||||
from pilot.utils.parameter_utils import BaseParameters
|
from pilot.utils.parameter_utils import BaseParameters
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +71,6 @@ def server_init(args, system_app: SystemApp):
|
|||||||
|
|
||||||
def _create_model_start_listener(system_app: SystemApp):
|
def _create_model_start_listener(system_app: SystemApp):
|
||||||
from pilot.connections.manages.connection_manager import ConnectManager
|
from pilot.connections.manages.connection_manager import ConnectManager
|
||||||
from pilot.model.cluster import worker_manager
|
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Type, TYPE_CHECKING
|
|
||||||
|
|
||||||
from pilot.componet import SystemApp
|
|
||||||
import logging
|
import logging
|
||||||
from pilot.configs.model_config import get_device
|
from typing import TYPE_CHECKING, Any, Type
|
||||||
from pilot.embedding_engine.embedding_factory import (
|
|
||||||
EmbeddingFactory,
|
from pilot.component import ComponentType, SystemApp
|
||||||
DefaultEmbeddingFactory,
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
)
|
|
||||||
from pilot.server.base import WebWerverParameters
|
from pilot.server.base import WebWerverParameters
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -18,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def initialize_componets(
|
def initialize_components(
|
||||||
param: WebWerverParameters,
|
param: WebWerverParameters,
|
||||||
system_app: SystemApp,
|
system_app: SystemApp,
|
||||||
embedding_model_name: str,
|
embedding_model_name: str,
|
||||||
@ -39,13 +35,9 @@ def _initialize_embedding_model(
|
|||||||
embedding_model_name: str,
|
embedding_model_name: str,
|
||||||
embedding_model_path: str,
|
embedding_model_path: str,
|
||||||
):
|
):
|
||||||
from pilot.model.cluster import worker_manager
|
|
||||||
|
|
||||||
if param.remote_embedding:
|
if param.remote_embedding:
|
||||||
logger.info("Register remote RemoteEmbeddingFactory")
|
logger.info("Register remote RemoteEmbeddingFactory")
|
||||||
system_app.register(
|
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
|
||||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Register local LocalEmbeddingFactory")
|
logger.info(f"Register local LocalEmbeddingFactory")
|
||||||
system_app.register(
|
system_app.register(
|
||||||
@ -56,26 +48,28 @@ def _initialize_embedding_model(
|
|||||||
|
|
||||||
|
|
||||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||||
def __init__(
|
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||||
self, system_app, worker_manager, model_name: str = None, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
super().__init__(system_app=system_app)
|
super().__init__(system_app=system_app)
|
||||||
self._worker_manager = worker_manager
|
|
||||||
self._default_model_name = model_name
|
self._default_model_name = model_name
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.system_app = system_app
|
||||||
|
|
||||||
def init_app(self, system_app):
|
def init_app(self, system_app):
|
||||||
pass
|
self.system_app = system_app
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self, model_name: str = None, embedding_cls: Type = None
|
self, model_name: str = None, embedding_cls: Type = None
|
||||||
) -> "Embeddings":
|
) -> "Embeddings":
|
||||||
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
||||||
|
|
||||||
if embedding_cls:
|
if embedding_cls:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
worker_manager = self.system_app.get_component(
|
||||||
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
|
).create()
|
||||||
# Ignore model_name args
|
# Ignore model_name args
|
||||||
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
|
return RemoteEmbeddings(self._default_model_name, worker_manager)
|
||||||
|
|
||||||
|
|
||||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||||
@ -103,13 +97,13 @@ class LocalEmbeddingFactory(EmbeddingFactory):
|
|||||||
return self._model
|
return self._model
|
||||||
|
|
||||||
def _load_model(self) -> "Embeddings":
|
def _load_model(self) -> "Embeddings":
|
||||||
from pilot.model.parameter import (
|
|
||||||
EmbeddingModelParameters,
|
|
||||||
BaseEmbeddingModelParameters,
|
|
||||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
|
||||||
)
|
|
||||||
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
|
||||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||||
|
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||||
|
from pilot.model.parameter import (
|
||||||
|
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||||
|
BaseEmbeddingModelParameters,
|
||||||
|
EmbeddingModelParameters,
|
||||||
|
)
|
||||||
|
|
||||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||||
self._default_model_name, EmbeddingModelParameters
|
self._default_model_name, EmbeddingModelParameters
|
@ -8,14 +8,14 @@ sys.path.append(ROOT_PATH)
|
|||||||
import signal
|
import signal
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
|
||||||
from pilot.componet import SystemApp
|
from pilot.component import SystemApp
|
||||||
|
|
||||||
from pilot.server.base import (
|
from pilot.server.base import (
|
||||||
server_init,
|
server_init,
|
||||||
WebWerverParameters,
|
WebWerverParameters,
|
||||||
_create_model_start_listener,
|
_create_model_start_listener,
|
||||||
)
|
)
|
||||||
from pilot.server.componet_configs import initialize_componets
|
from pilot.server.component_configs import initialize_components
|
||||||
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi import FastAPI, applications
|
from fastapi import FastAPI, applications
|
||||||
@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pilot.server.knowledge.api import router as knowledge_router
|
from pilot.server.knowledge.api import router as knowledge_router
|
||||||
|
from pilot.server.prompt.api import router as prompt_router
|
||||||
from pilot.server.llm_manage.api import router as llm_manage_api
|
from pilot.server.llm_manage.api import router as llm_manage_api
|
||||||
|
|
||||||
|
|
||||||
@ -76,6 +77,7 @@ app.include_router(llm_manage_api, prefix="/api")
|
|||||||
|
|
||||||
# app.include_router(api_v1)
|
# app.include_router(api_v1)
|
||||||
app.include_router(knowledge_router)
|
app.include_router(knowledge_router)
|
||||||
|
app.include_router(prompt_router)
|
||||||
# app.include_router(api_editor_route_v1)
|
# app.include_router(api_editor_route_v1)
|
||||||
|
|
||||||
|
|
||||||
@ -118,7 +120,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
|
|
||||||
server_init(param, system_app)
|
server_init(param, system_app)
|
||||||
model_start_listener = _create_model_start_listener(system_app)
|
model_start_listener = _create_model_start_listener(system_app)
|
||||||
initialize_componets(param, system_app, embedding_model_name, embedding_model_path)
|
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||||
if not param.light:
|
if not param.light:
|
||||||
@ -133,6 +135,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
embedding_model_name=embedding_model_name,
|
embedding_model_name=embedding_model_name,
|
||||||
embedding_model_path=embedding_model_path,
|
embedding_model_path=embedding_model_path,
|
||||||
start_listener=model_start_listener,
|
start_listener=model_start_listener,
|
||||||
|
system_app=system_app,
|
||||||
)
|
)
|
||||||
|
|
||||||
CFG.NEW_SERVER_MODE = True
|
CFG.NEW_SERVER_MODE = True
|
||||||
@ -146,6 +149,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
controller_addr=CFG.MODEL_SERVER,
|
controller_addr=CFG.MODEL_SERVER,
|
||||||
local_port=param.port,
|
local_port=param.port,
|
||||||
start_listener=model_start_listener,
|
start_listener=model_start_listener,
|
||||||
|
system_app=system_app,
|
||||||
)
|
)
|
||||||
CFG.SERVER_LIGHT_MODE = True
|
CFG.SERVER_LIGHT_MODE = True
|
||||||
|
|
||||||
|
@ -181,7 +181,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
|
|||||||
@router.post("/knowledge/{vector_name}/query")
|
@router.post("/knowledge/{vector_name}/query")
|
||||||
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||||
print(f"Received params: {space_name}, {query_request}")
|
print(f"Received params: {space_name}, {query_request}")
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_componet(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
client = EmbeddingEngine(
|
client = EmbeddingEngine(
|
||||||
|
@ -205,7 +205,7 @@ class KnowledgeService:
|
|||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
)
|
)
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_componet(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
client = EmbeddingEngine(
|
client = EmbeddingEngine(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from pilot.componet import ComponetType
|
from pilot.component import ComponentType
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.model.base import ModelInstance, WorkerApplyType
|
from pilot.model.base import ModelInstance, WorkerApplyType
|
||||||
|
|
||||||
@ -31,8 +31,8 @@ async def model_list():
|
|||||||
try:
|
try:
|
||||||
from pilot.model.cluster.controller.controller import BaseModelController
|
from pilot.model.cluster.controller.controller import BaseModelController
|
||||||
|
|
||||||
controller = CFG.SYSTEM_APP.get_componet(
|
controller = CFG.SYSTEM_APP.get_component(
|
||||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||||
)
|
)
|
||||||
responses = []
|
responses = []
|
||||||
managers = await controller.get_all_instances(
|
managers = await controller.get_all_instances(
|
||||||
@ -70,8 +70,8 @@ async def model_start(request: WorkerStartupRequest):
|
|||||||
try:
|
try:
|
||||||
from pilot.model.cluster.controller.controller import BaseModelController
|
from pilot.model.cluster.controller.controller import BaseModelController
|
||||||
|
|
||||||
controller = CFG.SYSTEM_APP.get_componet(
|
controller = CFG.SYSTEM_APP.get_component(
|
||||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||||
)
|
)
|
||||||
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
||||||
worker_instance = None
|
worker_instance = None
|
||||||
@ -98,8 +98,8 @@ async def model_start(request: WorkerStartupRequest):
|
|||||||
try:
|
try:
|
||||||
from pilot.model.cluster.controller.controller import BaseModelController
|
from pilot.model.cluster.controller.controller import BaseModelController
|
||||||
|
|
||||||
controller = CFG.SYSTEM_APP.get_componet(
|
controller = CFG.SYSTEM_APP.get_component(
|
||||||
ComponetType.MODEL_CONTROLLER, BaseModelController
|
ComponentType.MODEL_CONTROLLER, BaseModelController
|
||||||
)
|
)
|
||||||
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
instances = await controller.get_all_instances(model_name="WorkerManager@service", healthy_only=True)
|
||||||
worker_instance = None
|
worker_instance = None
|
||||||
|
0
pilot/server/prompt/__init__.py
Normal file
0
pilot/server/prompt/__init__.py
Normal file
46
pilot/server/prompt/api.py
Normal file
46
pilot/server/prompt/api.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from fastapi import APIRouter, File, UploadFile, Form
|
||||||
|
|
||||||
|
from pilot.openapi.api_view_model import Result
|
||||||
|
from pilot.server.prompt.service import PromptManageService
|
||||||
|
from pilot.server.prompt.request.request import PromptManageRequest
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
prompt_manage_service = PromptManageService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/prompt/add")
|
||||||
|
def prompt_add(request: PromptManageRequest):
|
||||||
|
print(f"/space/add params: {request}")
|
||||||
|
try:
|
||||||
|
prompt_manage_service.create_prompt(request)
|
||||||
|
return Result.succ([])
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E010X", msg=f"prompt add error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/prompt/list")
|
||||||
|
def prompt_list(request: PromptManageRequest):
|
||||||
|
print(f"/prompt/list params: {request}")
|
||||||
|
try:
|
||||||
|
return Result.succ(prompt_manage_service.get_prompts(request))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E010X", msg=f"prompt list error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/prompt/update")
|
||||||
|
def prompt_update(request: PromptManageRequest):
|
||||||
|
print(f"/prompt/update params: {request}")
|
||||||
|
try:
|
||||||
|
return Result.succ(prompt_manage_service.update_prompt(request))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E010X", msg=f"prompt update error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/prompt/delete")
|
||||||
|
def prompt_delete(request: PromptManageRequest):
|
||||||
|
print(f"/prompt/delete params: {request}")
|
||||||
|
try:
|
||||||
|
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E010X", msg=f"prompt delete error {e}")
|
91
pilot/server/prompt/prompt_manage_db.py
Normal file
91
pilot/server/prompt/prompt_manage_db.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.connections.rdbms.base_dao import BaseDao
|
||||||
|
|
||||||
|
from pilot.server.prompt.request.request import PromptManageRequest
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManageEntity(Base):
|
||||||
|
__tablename__ = "prompt_manage"
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
chat_scene = Column(String(100))
|
||||||
|
sub_chat_scene = Column(String(100))
|
||||||
|
prompt_type = Column(String(100))
|
||||||
|
prompt_name = Column(String(512))
|
||||||
|
content = Column(Text)
|
||||||
|
user_name = Column(String(128))
|
||||||
|
gmt_created = Column(DateTime)
|
||||||
|
gmt_modified = Column(DateTime)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManageDao(BaseDao):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
database="prompt_management", orm_base=Base, create_not_exist_table=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_prompt(self, prompt: PromptManageRequest):
|
||||||
|
session = self.Session()
|
||||||
|
prompt_manage = PromptManageEntity(
|
||||||
|
chat_scene=prompt.chat_scene,
|
||||||
|
sub_chat_scene=prompt.sub_chat_scene,
|
||||||
|
prompt_type=prompt.prompt_type,
|
||||||
|
prompt_name=prompt.prompt_name,
|
||||||
|
content=prompt.content,
|
||||||
|
user_name=prompt.user_name,
|
||||||
|
gmt_created=datetime.now(),
|
||||||
|
gmt_modified=datetime.now(),
|
||||||
|
)
|
||||||
|
session.add(prompt_manage)
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def get_prompts(self, query: PromptManageEntity):
|
||||||
|
session = self.Session()
|
||||||
|
prompts = session.query(PromptManageEntity)
|
||||||
|
if query.chat_scene is not None:
|
||||||
|
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||||
|
if query.sub_chat_scene is not None:
|
||||||
|
prompts = prompts.filter(
|
||||||
|
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
|
||||||
|
)
|
||||||
|
if query.prompt_type is not None:
|
||||||
|
prompts = prompts.filter(
|
||||||
|
PromptManageEntity.prompt_type == query.prompt_type
|
||||||
|
)
|
||||||
|
if query.prompt_type == "private" and query.user_name is not None:
|
||||||
|
prompts = prompts.filter(
|
||||||
|
PromptManageEntity.user_name == query.user_name
|
||||||
|
)
|
||||||
|
if query.prompt_name is not None:
|
||||||
|
prompts = prompts.filter(
|
||||||
|
PromptManageEntity.prompt_name == query.prompt_name
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
|
||||||
|
result = prompts.all()
|
||||||
|
session.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def update_prompt(self, prompt: PromptManageEntity):
|
||||||
|
session = self.Session()
|
||||||
|
session.merge(prompt)
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def delete_prompt(self, prompt: PromptManageEntity):
|
||||||
|
session = self.Session()
|
||||||
|
if prompt:
|
||||||
|
session.delete(prompt)
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
0
pilot/server/prompt/request/__init__.py
Normal file
0
pilot/server/prompt/request/__init__.py
Normal file
24
pilot/server/prompt/request/request.py
Normal file
24
pilot/server/prompt/request/request.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManageRequest(BaseModel):
|
||||||
|
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||||
|
|
||||||
|
chat_scene: str = None
|
||||||
|
|
||||||
|
"""sub_chat_scene: sub chat scene"""
|
||||||
|
sub_chat_scene: str = None
|
||||||
|
|
||||||
|
"""prompt_type: common or private"""
|
||||||
|
prompt_type: str = None
|
||||||
|
|
||||||
|
"""content: prompt content"""
|
||||||
|
content: str = None
|
||||||
|
|
||||||
|
"""user_name: user name"""
|
||||||
|
user_name: str = None
|
||||||
|
|
||||||
|
"""prompt_name: prompt name"""
|
||||||
|
prompt_name: str = None
|
26
pilot/server/prompt/request/response.py
Normal file
26
pilot/server/prompt/request/response.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PromptQueryResponse(BaseModel):
|
||||||
|
id: int = None
|
||||||
|
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||||
|
|
||||||
|
chat_scene: str = None
|
||||||
|
|
||||||
|
"""sub_chat_scene: sub chat scene"""
|
||||||
|
sub_chat_scene: str = None
|
||||||
|
|
||||||
|
"""prompt_type: common or private"""
|
||||||
|
prompt_type: str = None
|
||||||
|
|
||||||
|
"""content: prompt content"""
|
||||||
|
content: str = None
|
||||||
|
|
||||||
|
"""user_name: user name"""
|
||||||
|
user_name: str = None
|
||||||
|
|
||||||
|
"""prompt_name: prompt name"""
|
||||||
|
prompt_name: str = None
|
||||||
|
gmt_created: str = None
|
||||||
|
gmt_modified: str = None
|
80
pilot/server/prompt/service.py
Normal file
80
pilot/server/prompt/service.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pilot.server.prompt.request.request import PromptManageRequest
|
||||||
|
from pilot.server.prompt.request.response import PromptQueryResponse
|
||||||
|
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
|
||||||
|
|
||||||
|
prompt_manage_dao = PromptManageDao()
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManageService:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
"""create prompt"""
|
||||||
|
|
||||||
|
def create_prompt(self, request: PromptManageRequest):
|
||||||
|
query = PromptManageRequest(
|
||||||
|
prompt_name=request.prompt_name,
|
||||||
|
)
|
||||||
|
prompt_name = prompt_manage_dao.get_prompts(query)
|
||||||
|
if len(prompt_name) > 0:
|
||||||
|
raise Exception(f"prompt name:{request.prompt_name} have already named")
|
||||||
|
prompt_manage_dao.create_prompt(request)
|
||||||
|
return True
|
||||||
|
|
||||||
|
"""get prompts"""
|
||||||
|
|
||||||
|
def get_prompts(self, request: PromptManageRequest):
|
||||||
|
query = PromptManageRequest(
|
||||||
|
chat_scene=request.chat_scene,
|
||||||
|
sub_chat_scene=request.sub_chat_scene,
|
||||||
|
prompt_type=request.prompt_type,
|
||||||
|
prompt_name=request.prompt_name,
|
||||||
|
user_name=request.user_name,
|
||||||
|
)
|
||||||
|
responses = []
|
||||||
|
prompts = prompt_manage_dao.get_prompts(query)
|
||||||
|
for prompt in prompts:
|
||||||
|
res = PromptQueryResponse()
|
||||||
|
|
||||||
|
res.id = prompt.id
|
||||||
|
res.chat_scene = prompt.chat_scene
|
||||||
|
res.sub_chat_scene = prompt.sub_chat_scene
|
||||||
|
res.prompt_type = prompt.prompt_type
|
||||||
|
res.content = prompt.content
|
||||||
|
res.user_name = prompt.user_name
|
||||||
|
res.prompt_name = prompt.prompt_name
|
||||||
|
res.gmt_created = prompt.gmt_created
|
||||||
|
res.gmt_modified = prompt.gmt_modified
|
||||||
|
responses.append(res)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
"""update prompt"""
|
||||||
|
|
||||||
|
def update_prompt(self, request: PromptManageRequest):
|
||||||
|
query = PromptManageEntity(prompt_name=request.prompt_name)
|
||||||
|
prompts = prompt_manage_dao.get_prompts(query)
|
||||||
|
if len(prompts) != 1:
|
||||||
|
raise Exception(
|
||||||
|
f"there are no or more than one space called {request.prompt_name}"
|
||||||
|
)
|
||||||
|
prompt = prompts[0]
|
||||||
|
prompt.chat_scene = request.chat_scene
|
||||||
|
prompt.sub_chat_scene = request.sub_chat_scene
|
||||||
|
prompt.prompt_type = request.prompt_type
|
||||||
|
prompt.content = request.content
|
||||||
|
prompt.user_name = request.user_name
|
||||||
|
prompt.gmt_modified = datetime.now()
|
||||||
|
return prompt_manage_dao.update_prompt(prompt)
|
||||||
|
|
||||||
|
"""delete prompt"""
|
||||||
|
|
||||||
|
def delete_prompt(self, prompt_name: str):
|
||||||
|
query = PromptManageEntity(prompt_name=prompt_name)
|
||||||
|
prompts = prompt_manage_dao.get_prompts(query)
|
||||||
|
if len(prompts) == 0:
|
||||||
|
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
|
||||||
|
# delete prompt
|
||||||
|
prompt = prompts[0]
|
||||||
|
return prompt_manage_dao.delete_prompt(prompt)
|
@ -2,7 +2,7 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from pilot.common.schema import DBType
|
from pilot.common.schema import DBType
|
||||||
from pilot.componet import SystemApp
|
from pilot.component import SystemApp
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
@ -36,7 +36,7 @@ class DBSummaryClient:
|
|||||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
|
|
||||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||||
embedding_factory = self.system_app.get_componet(
|
embedding_factory = self.system_app.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
embeddings = embedding_factory.create(
|
embeddings = embedding_factory.create(
|
||||||
@ -94,7 +94,7 @@ class DBSummaryClient:
|
|||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_componet(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
knowledge_embedding_client = EmbeddingEngine(
|
knowledge_embedding_client = EmbeddingEngine(
|
||||||
@ -117,7 +117,7 @@ class DBSummaryClient:
|
|||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_componet(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
knowledge_embedding_client = EmbeddingEngine(
|
knowledge_embedding_client = EmbeddingEngine(
|
||||||
|
@ -12,6 +12,7 @@ class ParameterDescription:
|
|||||||
param_type: str
|
param_type: str
|
||||||
default_value: Optional[Any]
|
default_value: Optional[Any]
|
||||||
description: str
|
description: str
|
||||||
|
required: Optional[bool]
|
||||||
valid_values: Optional[List[Any]]
|
valid_values: Optional[List[Any]]
|
||||||
ext_metadata: Dict
|
ext_metadata: Dict
|
||||||
|
|
||||||
@ -460,20 +461,25 @@ def _type_str_to_python_type(type_str: str) -> Type:
|
|||||||
return type_mapping.get(type_str, str)
|
return type_mapping.get(type_str, str)
|
||||||
|
|
||||||
|
|
||||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
|
def _get_parameter_descriptions(
|
||||||
|
dataclass_type: Type, **kwargs
|
||||||
|
) -> List[ParameterDescription]:
|
||||||
descriptions = []
|
descriptions = []
|
||||||
for field in fields(dataclass_type):
|
for field in fields(dataclass_type):
|
||||||
ext_metadata = {
|
ext_metadata = {
|
||||||
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
|
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
|
||||||
}
|
}
|
||||||
|
default_value = field.default if field.default != MISSING else None
|
||||||
|
if field.name in kwargs:
|
||||||
|
default_value = kwargs[field.name]
|
||||||
descriptions.append(
|
descriptions.append(
|
||||||
ParameterDescription(
|
ParameterDescription(
|
||||||
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
|
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
|
||||||
param_name=field.name,
|
param_name=field.name,
|
||||||
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
|
||||||
description=field.metadata.get("help", None),
|
description=field.metadata.get("help", None),
|
||||||
default_value=field.default if field.default != MISSING else None,
|
required=field.default is MISSING,
|
||||||
|
default_value=default_value,
|
||||||
valid_values=field.metadata.get("valid_values", None),
|
valid_values=field.metadata.get("valid_values", None),
|
||||||
ext_metadata=ext_metadata,
|
ext_metadata=ext_metadata,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user