refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

View File

@@ -0,0 +1,29 @@
"""Import all models to make sure they are registered with SQLAlchemy.
"""
from dbgpt.agent.db.my_plugin_db import MyPluginEntity
from dbgpt.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
ChatHistoryMessageEntity,
)
# Entity classes collected in one place. Importing this module imports every
# class listed here, which is what the module docstring relies on for
# SQLAlchemy registration.
_MODELS = [
    PluginHubEntity,
    MyPluginEntity,
    PromptManageEntity,
    KnowledgeSpaceEntity,
    KnowledgeDocumentEntity,
    DocumentChunkEntity,
    ChatFeedBackEntity,
    ConnectConfigEntity,
    ChatHistoryEntity,
    ChatHistoryMessageEntity,
]

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
import logging
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from dbgpt.app.base import WebServerParameters
logger = logging.getLogger(__name__)
def _initialize_embedding_model(
    param: "WebServerParameters",
    system_app: SystemApp,
    embedding_model_name: str,
    embedding_model_path: str,
):
    """Register an embedding factory on ``system_app``.

    Args:
        param: Web server parameters; only ``remote_embedding`` is read here.
        system_app: Application object the factory class is registered with.
        embedding_model_name: Model name handed to the chosen factory.
        embedding_model_path: Model path; only forwarded to the local factory.
    """
    if param.remote_embedding:
        logger.info("Register remote RemoteEmbeddingFactory")
        system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
    else:
        # Constant message: a plain string literal, no f-prefix needed.
        logger.info("Register local LocalEmbeddingFactory")
        system_app.register(
            LocalEmbeddingFactory,
            default_model_name=embedding_model_name,
            default_model_path=embedding_model_path,
        )
class RemoteEmbeddingFactory(EmbeddingFactory):
    """Embedding factory that hands out ``RemoteEmbeddings`` instances."""

    def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
        """Keep the application, the default model name and extra kwargs."""
        super().__init__(system_app=system_app)
        self.system_app = system_app
        self.kwargs = kwargs
        self._default_model_name = model_name

    def init_app(self, system_app):
        """Remember the application this factory belongs to."""
        self.system_app = system_app

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        """Build a ``RemoteEmbeddings`` for the default model name.

        ``model_name`` is accepted but not used; the name given at
        construction time is always the one passed on.

        Raises:
            NotImplementedError: If ``embedding_cls`` is truthy.
        """
        from dbgpt.model.cluster import WorkerManagerFactory
        from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings

        if embedding_cls:
            raise NotImplementedError

        manager_factory = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        )
        manager = manager_factory.create()
        # The model_name argument is deliberately not consulted.
        return RemoteEmbeddings(self._default_model_name, manager)
class LocalEmbeddingFactory(EmbeddingFactory):
    """Embedding factory that loads one model at construction and reuses it."""

    def __init__(
        self,
        system_app,
        default_model_name: str = None,
        default_model_path: str = None,
        **kwargs: Any,
    ) -> None:
        """Store the model settings, then load the model right away."""
        super().__init__(system_app=system_app)
        self._kwargs = kwargs
        self._default_model_path = default_model_path
        self._default_model_name = default_model_name
        # Must come last: _load_model reads the three attributes set above.
        self._model = self._load_model()

    def init_app(self, system_app):
        """Nothing to do for this factory."""
        pass

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        """Return the model loaded at construction time.

        ``model_name`` is accepted but not used.

        Raises:
            NotImplementedError: If ``embedding_cls`` is truthy.
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model

    def _load_model(self) -> "Embeddings":
        """Parse the embedding parameters and load the default model."""
        from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
        from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
        from dbgpt.model.parameter import (
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
            BaseEmbeddingModelParameters,
            EmbeddingModelParameters,
        )

        name = self._default_model_name
        # Fall back to the generic parameter class when the name has no entry.
        params_type = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            name, EmbeddingModelParameters
        )
        parsed: BaseEmbeddingModelParameters = _parse_embedding_params(
            model_name=name,
            model_path=self._default_model_path,
            param_cls=params_type,
            **self._kwargs,
        )
        logger.info(parsed)
        # Always loads the default model name.
        return EmbeddingLoader().load(name, parsed)

View File

@@ -0,0 +1,9 @@
from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp):
    """Register the serve applications on ``system_app``.

    At present this registers the prompt serve with the ``/prompt`` API
    prefix.
    """
    from dbgpt.serve.prompt.serve import Serve as PromptServe

    # Takes the place of the old prompt serve.
    prompt_options = {"api_prefix": "/prompt"}
    system_app.register(PromptServe, **prompt_options)