refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

@@ -108,12 +108,24 @@ def migrate(alembic_ini_path: str, script_location: str, message: str):
@migration.command()
@add_migration_options
def upgrade(alembic_ini_path: str, script_location: str):
@click.option(
"--sql-output",
type=str,
default=None,
help="Generate SQL script for migration instead of applying it. ex: --sql-output=upgrade.sql",
)
def upgrade(alembic_ini_path: str, script_location: str, sql_output: str):
"""Upgrade database to target version"""
from dbgpt.util._db_migration_utils import upgrade_database
from dbgpt.util._db_migration_utils import (
upgrade_database,
generate_sql_for_upgrade,
)
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
upgrade_database(alembic_cfg, db_manager.engine)
if sql_output:
generate_sql_for_upgrade(alembic_cfg, db_manager.engine, output_file=sql_output)
else:
upgrade_database(alembic_cfg, db_manager.engine)
@migration.command()
@@ -199,6 +211,7 @@ def clean(
def list(alembic_ini_path: str, script_location: str):
"""List all versions in the migration history, marking the current one"""
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
@@ -259,8 +272,8 @@ def _get_migration_config(
from dbgpt.storage.metadata.db_manager import db as db_manager
from dbgpt.util._db_migration_utils import create_alembic_config
# Must import dbgpt_server for initialize db metadata
from dbgpt.app.dbgpt_server import initialize_app as _
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS
from dbgpt.app.base import _initialize_db
# initialize db
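
The new --sql-output path calls generate_sql_for_upgrade instead of applying the migration. Below is a minimal sketch of how such a helper could be built on Alembic's offline mode (sql=True); only the call signature above comes from the diff, the body is an assumption:

import io

from alembic import command
from alembic.config import Config
from sqlalchemy.engine import Engine


def generate_sql_for_upgrade(
    alembic_cfg: Config, engine: Engine, output_file: str = "upgrade.sql"
) -> None:
    """Write the upgrade DDL to output_file instead of executing it."""
    # The engine is accepted for symmetry with upgrade_database; offline mode
    # does not connect to the database at all.
    buffer = io.StringIO()
    # In offline mode Alembic renders the migration as SQL into the config's
    # output buffer rather than running it against the engine.
    alembic_cfg.output_buffer = buffer
    command.upgrade(alembic_cfg, "head", sql=True)
    with open(output_file, "w") as f:
        f.write(buffer.getvalue())

Invoked with --sql-output=upgrade.sql, the upgrade command then emits a reviewable script instead of touching the database.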

View File

@@ -10,7 +10,6 @@ from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.util.parameter_utils import BaseParameters
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -92,10 +91,27 @@ def _initialize_db_storage(param: "WebServerParameters"):
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
default_meta_data_path = _initialize_db(
try_to_create_db=not param.disable_alembic_upgrade
)
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
def _migration_db_storage(param: "WebServerParameters"):
"""Migration the db storage."""
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS
from dbgpt.configs.model_config import PILOT_PATH
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
if not param.disable_alembic_upgrade:
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
from dbgpt.storage.metadata.db_manager import db
# try to create all tables
try:
db.create_all()
except Exception as e:
logger.warning(f"Create all tables stored in this metadata error: {str(e)}")
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
@@ -112,7 +128,13 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
db_url = (
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:"
f"{urlquote(CFG.LOCAL_DB_PASSWORD)}@"
f"{CFG.LOCAL_DB_HOST}:"
f"{str(CFG.LOCAL_DB_PORT)}/"
f"{db_name}?charset=utf8mb4&collation=utf8mb4_unicode_ci"
)
# Try to create database, if failed, will raise exception
_create_mysql_database(db_name, db_url, try_to_create_db)
else:
@@ -125,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
initialize_db(db_url, db_name, engine_args)
return default_meta_data_path
@@ -161,7 +183,11 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
no_db_name_url = db_url.rsplit("/", 1)[0]
engine_no_db = create_engine(no_db_name_url)
with engine_no_db.connect() as conn:
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
conn.execute(
DDL(
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
)
)
logger.info(f"Database {db_name} successfully created")
except SQLAlchemyError as e:
logger.error(f"Failed to create database {db_name}: {e}")

View File

@@ -1,17 +1,13 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
from dbgpt.component import ComponentType, SystemApp
from dbgpt.component import SystemApp
from dbgpt._private.config import Config
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
from dbgpt.util.executor_utils import DefaultExecutorFactory
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.app.base import WebServerParameters
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
logger = logging.getLogger(__name__)
@@ -24,7 +20,10 @@ def initialize_components(
embedding_model_name: str,
embedding_model_path: str,
):
# Lazy import to avoid high time cost
from dbgpt.model.cluster.controller.controller import controller
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.serve_initialization import register_serve_apps
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
@@ -44,97 +43,8 @@ def initialize_components(
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
def _initialize_embedding_model(
param: WebServerParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
else:
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
LocalEmbeddingFactory,
default_model_name=embedding_model_name,
default_model_path=embedding_model_path,
)
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app):
self.system_app = system_app
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if embedding_cls:
raise NotImplementedError
return self._model
def _load_model(self) -> "Embeddings":
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
# Register serve apps
register_serve_apps(system_app)
def _initialize_model_cache(system_app: SystemApp):
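
The roughly ninety lines removed here are the embedding factory code, which moves verbatim into the new dbgpt/app/initialization/embedding_component.py shown below; what remains in initialize_components is a set of lazy imports plus register_serve_apps. The lazy-import comment is about keeping module import (and therefore CLI --help) cheap. A small, self-contained sketch of how that cost can be checked; the module names are only examples:

import importlib
import time


def timed_import(module_name: str) -> float:
    """Import a module and return how long it took, in seconds."""
    start = time.perf_counter()
    importlib.import_module(module_name)
    return time.perf_counter() - start


if __name__ == "__main__":
    # Importing these at module level would add their cost to every CLI call,
    # including --help; importing inside functions defers it until needed.
    # (Re-running in the same process reports ~0s once a module is cached in
    # sys.modules, so measure in a fresh interpreter.)
    for name in ("fastapi", "sqlalchemy"):
        print(f"import {name}: {timed_import(name):.3f}s")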

View File

@@ -16,28 +16,22 @@ from dbgpt.component import SystemApp
from dbgpt.app.base import (
server_init,
_migration_db_storage,
WebServerParameters,
_create_model_start_listener,
)
# initialize_components import time cost about 0.1s
from dbgpt.app.component_configs import initialize_components
# fastapi import time cost about 0.05s
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from dbgpt.app.knowledge.api import router as knowledge_router
from dbgpt.app.prompt.api import router as prompt_router
from dbgpt.app.llm_manage.api import router as llm_manage_api
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
from dbgpt.app.openapi.base import validation_exception_handler
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from dbgpt.agent.commands.disply_type.show_chart_gen import (
static_message_img_path,
)
from dbgpt.model.cluster import initialize_worker_manager_in_client
from dbgpt.util.utils import (
setup_logging,
_get_logging_level,
@@ -78,16 +72,35 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
def mount_routers(app: FastAPI):
"""Lazy import to avoid high time cost"""
from dbgpt.app.knowledge.api import router as knowledge_router
# from dbgpt.app.prompt.api import router as prompt_router
# prompt has been moved to dbgpt.serve.prompt
from dbgpt.app.llm_manage.api import router as llm_manage_api
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import (
router as api_editor_route_v1,
)
from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
# app.include_router(prompt_router, tags=["Prompt"])
def mount_static_files(app):
def mount_static_files(app: FastAPI):
from dbgpt.agent.commands.disply_type.show_chart_gen import (
static_message_img_path,
)
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
"/images",
@@ -122,14 +135,15 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
if not param:
param = _get_webserver_params(args)
# import after param is initialized, accelerate --help speed
from dbgpt.model.cluster import initialize_worker_manager_in_client
if not param.log_level:
param.log_level = _get_logging_level()
setup_logging(
"dbgpt", logging_level=param.log_level, logger_filename=param.log_file
)
# Before start
system_app.before_start()
model_name = param.model_name or CFG.LLM_MODEL
param.model_name = model_name
print(param)
@@ -138,9 +152,16 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
server_init(param, system_app)
mount_routers(app)
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
# Before start, after initialize_components
# TODO: initialize_worker_manager_in_client as a component register in system_app
system_app.before_start()
# Migrate db storage, so your db models must be imported before this
_migration_db_storage(param)
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")

View File

@@ -0,0 +1,29 @@
"""Import all models to make sure they are registered with SQLAlchemy.
"""
from dbgpt.agent.db.my_plugin_db import MyPluginEntity
from dbgpt.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
ChatHistoryMessageEntity,
)
_MODELS = [
PluginHubEntity,
MyPluginEntity,
PromptManageEntity,
KnowledgeSpaceEntity,
KnowledgeDocumentEntity,
DocumentChunkEntity,
ChatFeedBackEntity,
ConnectConfigEntity,
ChatHistoryEntity,
ChatHistoryMessageEntity,
]
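
This new module exists purely for its import side effects: each import binds an entity's table to the shared declarative metadata, so migration tooling and db.create_all() can see every table just by importing this one module. A small sketch of how the registration can be inspected; it only assumes each entity defines __tablename__, which the model hunks in this commit show:

from dbgpt.app.initialization.db_model_initialization import _MODELS

# Importing the module above is the point: it pulls in every entity, which
# registers the corresponding table on the declarative metadata.
print(f"{len(_MODELS)} model classes registered")
for model in _MODELS:
    print(model.__tablename__)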

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
import logging
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from dbgpt.app.base import WebServerParameters
logger = logging.getLogger(__name__)
def _initialize_embedding_model(
param: "WebServerParameters",
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
if param.remote_embedding:
logger.info("Register remote RemoteEmbeddingFactory")
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
else:
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
LocalEmbeddingFactory,
default_model_name=embedding_model_name,
default_model_path=embedding_model_path,
)
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app):
self.system_app = system_app
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if embedding_cls:
raise NotImplementedError
return self._model
def _load_model(self) -> "Embeddings":
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
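
The factory code itself is unchanged, only relocated out of component_configs.py. On the consumer side the registered factory is resolved through the SystemApp and asked for an Embeddings instance; the component name "embedding_factory" used below is an assumption, since this diff shows only registration, not lookup:

from typing import TYPE_CHECKING

from dbgpt.component import SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings


def get_embeddings(system_app: SystemApp) -> "Embeddings":
    # "embedding_factory" as the component name is assumed, not shown in the diff.
    factory = system_app.get_component("embedding_factory", EmbeddingFactory)
    # Both factories above ignore model_name and return their default model.
    return factory.create()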

View File

@@ -0,0 +1,9 @@
from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe
# Replace old prompt serve
system_app.register(PromptServe, api_prefix="/prompt")

View File

@@ -11,10 +11,6 @@ CFG = Config()
class DocumentChunkEntity(Model):
__tablename__ = "document_chunk"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True)
document_id = Column(Integer)
doc_name = Column(String(100))
@@ -112,7 +108,7 @@ class DocumentChunkDao(BaseDao):
session.close()
return count
def delete(self, document_id: int):
def raw_delete(self, document_id: int):
session = self.get_raw_session()
if document_id is None:
raise Exception("document_id is None")

View File

@@ -10,10 +10,6 @@ CFG = Config()
class KnowledgeDocumentEntity(Model):
__tablename__ = "knowledge_document"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True)
doc_name = Column(String(100))
doc_type = Column(String(100))
@@ -180,7 +176,7 @@ class KnowledgeDocumentDao(BaseDao):
return updated_space.id
#
def delete(self, query: KnowledgeDocumentEntity):
def raw_delete(self, query: KnowledgeDocumentEntity):
session = self.get_raw_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:

View File

@@ -367,9 +367,9 @@ class KnowledgeService:
# delete chunks
documents = knowledge_document_dao.get_documents(document_query)
for document in documents:
document_chunk_dao.delete(document.id)
document_chunk_dao.raw_delete(document.id)
# delete documents
knowledge_document_dao.delete(document_query)
knowledge_document_dao.raw_delete(document_query)
# delete space
return knowledge_space_dao.delete_knowledge_space(space)
@@ -395,9 +395,9 @@ class KnowledgeService:
# delete vector by ids
vector_client.delete_by_ids(vector_ids)
# delete chunks
document_chunk_dao.delete(documents[0].id)
document_chunk_dao.raw_delete(documents[0].id)
# delete document
return knowledge_document_dao.delete(document_query)
return knowledge_document_dao.raw_delete(document_query)
def get_document_chunks(self, request: ChunkQueryRequest):
"""get document chunks

View File

@@ -11,10 +11,6 @@ CFG = Config()
class KnowledgeSpaceEntity(Model):
__tablename__ = "knowledge_space"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True)
name = Column(String(100))
vector_type = Column(String(100))

View File

@@ -9,10 +9,6 @@ from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
class ChatFeedBackEntity(Model):
__tablename__ = "chat_feed_back"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True)
conv_uid = Column(String(128))
conv_index = Column(Integer)

View File

@@ -13,10 +13,6 @@ CFG = Config()
class PromptManageEntity(Model):
__tablename__ = "prompt_manage"
__table_args__ = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
id = Column(Integer, primary_key=True)
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))