mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
refactor: Refactor storage and new serve template (#947)
This commit is contained in:
0
dbgpt/app/initialization/__init__.py
Normal file
0
dbgpt/app/initialization/__init__.py
Normal file
29
dbgpt/app/initialization/db_model_initialization.py
Normal file
29
dbgpt/app/initialization/db_model_initialization.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Import all models to make sure they are registered with SQLAlchemy.
|
||||
"""
|
||||
from dbgpt.agent.db.my_plugin_db import MyPluginEntity
|
||||
from dbgpt.agent.db.plugin_hub_db import PluginHubEntity
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
|
||||
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
|
||||
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
|
||||
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
|
||||
|
||||
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
|
||||
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
|
||||
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
|
||||
from dbgpt.storage.chat_history.chat_history_db import (
|
||||
ChatHistoryEntity,
|
||||
ChatHistoryMessageEntity,
|
||||
)
|
||||
|
||||
_MODELS = [
|
||||
PluginHubEntity,
|
||||
MyPluginEntity,
|
||||
PromptManageEntity,
|
||||
KnowledgeSpaceEntity,
|
||||
KnowledgeDocumentEntity,
|
||||
DocumentChunkEntity,
|
||||
ChatFeedBackEntity,
|
||||
ConnectConfigEntity,
|
||||
ChatHistoryEntity,
|
||||
ChatHistoryMessageEntity,
|
||||
]
|
103
dbgpt/app/initialization/embedding_component.py
Normal file
103
dbgpt/app/initialization/embedding_component.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from dbgpt.app.base import WebServerParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
param: "WebServerParameters",
|
||||
system_app: SystemApp,
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
):
|
||||
if param.remote_embedding:
|
||||
logger.info("Register remote RemoteEmbeddingFactory")
|
||||
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
|
||||
else:
|
||||
logger.info(f"Register local LocalEmbeddingFactory")
|
||||
system_app.register(
|
||||
LocalEmbeddingFactory,
|
||||
default_model_name=embedding_model_name,
|
||||
default_model_path=embedding_model_path,
|
||||
)
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = model_name
|
||||
self.kwargs = kwargs
|
||||
self.system_app = system_app
|
||||
|
||||
def init_app(self, system_app):
|
||||
self.system_app = system_app
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
||||
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
# Ignore model_name args
|
||||
return RemoteEmbeddings(self._default_model_name, worker_manager)
|
||||
|
||||
|
||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(
|
||||
self,
|
||||
system_app,
|
||||
default_model_name: str = None,
|
||||
default_model_path: str = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = default_model_name
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
pass
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> "Embeddings":
|
||||
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||
from dbgpt.model.parameter import (
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
BaseEmbeddingModelParameters,
|
||||
EmbeddingModelParameters,
|
||||
)
|
||||
|
||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self._default_model_name, EmbeddingModelParameters
|
||||
)
|
||||
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=self._default_model_name,
|
||||
model_path=self._default_model_path,
|
||||
param_cls=param_cls,
|
||||
**self._kwargs,
|
||||
)
|
||||
logger.info(model_params)
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
9
dbgpt/app/initialization/serve_initialization.py
Normal file
9
dbgpt/app/initialization/serve_initialization.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
|
||||
def register_serve_apps(system_app: SystemApp):
|
||||
"""Register serve apps"""
|
||||
from dbgpt.serve.prompt.serve import Serve as PromptServe
|
||||
|
||||
# Replace old prompt serve
|
||||
system_app.register(PromptServe, api_prefix="/prompt")
|
Reference in New Issue
Block a user