mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: Refactor storage and new serve template (#947)
This commit is contained in:
@@ -108,12 +108,24 @@ def migrate(alembic_ini_path: str, script_location: str, message: str):
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
def upgrade(alembic_ini_path: str, script_location: str):
|
||||
@click.option(
|
||||
"--sql-output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Generate SQL script for migration instead of applying it. ex: --sql-output=upgrade.sql",
|
||||
)
|
||||
def upgrade(alembic_ini_path: str, script_location: str, sql_output: str):
|
||||
"""Upgrade database to target version"""
|
||||
from dbgpt.util._db_migration_utils import upgrade_database
|
||||
from dbgpt.util._db_migration_utils import (
|
||||
upgrade_database,
|
||||
generate_sql_for_upgrade,
|
||||
)
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
upgrade_database(alembic_cfg, db_manager.engine)
|
||||
if sql_output:
|
||||
generate_sql_for_upgrade(alembic_cfg, db_manager.engine, output_file=sql_output)
|
||||
else:
|
||||
upgrade_database(alembic_cfg, db_manager.engine)
|
||||
|
||||
|
||||
@migration.command()
|
||||
@@ -199,6 +211,7 @@ def clean(
|
||||
def list(alembic_ini_path: str, script_location: str):
|
||||
"""List all versions in the migration history, marking the current one"""
|
||||
from alembic.script import ScriptDirectory
|
||||
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
@@ -259,8 +272,8 @@ def _get_migration_config(
|
||||
from dbgpt.storage.metadata.db_manager import db as db_manager
|
||||
from dbgpt.util._db_migration_utils import create_alembic_config
|
||||
|
||||
# Must import dbgpt_server for initialize db metadata
|
||||
from dbgpt.app.dbgpt_server import initialize_app as _
|
||||
# Import all models to make sure they are registered with SQLAlchemy.
|
||||
from dbgpt.app.initialization.db_model_initialization import _MODELS
|
||||
from dbgpt.app.base import _initialize_db
|
||||
|
||||
# initialize db
|
||||
|
@@ -10,7 +10,6 @@ from dbgpt._private.config import Config
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util.parameter_utils import BaseParameters
|
||||
|
||||
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
@@ -92,10 +91,27 @@ def _initialize_db_storage(param: "WebServerParameters"):
|
||||
|
||||
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
|
||||
"""
|
||||
default_meta_data_path = _initialize_db(
|
||||
try_to_create_db=not param.disable_alembic_upgrade
|
||||
)
|
||||
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
|
||||
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
|
||||
|
||||
|
||||
def _migration_db_storage(param: "WebServerParameters"):
|
||||
"""Migration the db storage."""
|
||||
# Import all models to make sure they are registered with SQLAlchemy.
|
||||
from dbgpt.app.initialization.db_model_initialization import _MODELS
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
|
||||
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
|
||||
if not param.disable_alembic_upgrade:
|
||||
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
|
||||
from dbgpt.storage.metadata.db_manager import db
|
||||
|
||||
# try to create all tables
|
||||
try:
|
||||
db.create_all()
|
||||
except Exception as e:
|
||||
logger.warning(f"Create all tables stored in this metadata error: {str(e)}")
|
||||
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
|
||||
|
||||
|
||||
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
@@ -112,7 +128,13 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
|
||||
os.makedirs(default_meta_data_path, exist_ok=True)
|
||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
|
||||
db_url = (
|
||||
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:"
|
||||
f"{urlquote(CFG.LOCAL_DB_PASSWORD)}@"
|
||||
f"{CFG.LOCAL_DB_HOST}:"
|
||||
f"{str(CFG.LOCAL_DB_PORT)}/"
|
||||
f"{db_name}?charset=utf8mb4&collation=utf8mb4_unicode_ci"
|
||||
)
|
||||
# Try to create database, if failed, will raise exception
|
||||
_create_mysql_database(db_name, db_url, try_to_create_db)
|
||||
else:
|
||||
@@ -125,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
|
||||
initialize_db(db_url, db_name, engine_args)
|
||||
return default_meta_data_path
|
||||
|
||||
|
||||
@@ -161,7 +183,11 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F
|
||||
no_db_name_url = db_url.rsplit("/", 1)[0]
|
||||
engine_no_db = create_engine(no_db_name_url)
|
||||
with engine_no_db.connect() as conn:
|
||||
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
|
||||
conn.execute(
|
||||
DDL(
|
||||
f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
|
||||
)
|
||||
)
|
||||
logger.info(f"Database {db_name} successfully created")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Failed to create database {db_name}: {e}")
|
||||
|
@@ -1,17 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
|
||||
from dbgpt.util.executor_utils import DefaultExecutorFactory
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.app.base import WebServerParameters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +20,10 @@ def initialize_components(
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
):
|
||||
# Lazy import to avoid high time cost
|
||||
from dbgpt.model.cluster.controller.controller import controller
|
||||
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
|
||||
from dbgpt.app.initialization.serve_initialization import register_serve_apps
|
||||
|
||||
# Register global default executor factory first
|
||||
system_app.register(DefaultExecutorFactory)
|
||||
@@ -44,97 +43,8 @@ def initialize_components(
|
||||
)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
param: WebServerParameters,
|
||||
system_app: SystemApp,
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
):
|
||||
if param.remote_embedding:
|
||||
logger.info("Register remote RemoteEmbeddingFactory")
|
||||
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
|
||||
else:
|
||||
logger.info(f"Register local LocalEmbeddingFactory")
|
||||
system_app.register(
|
||||
LocalEmbeddingFactory,
|
||||
default_model_name=embedding_model_name,
|
||||
default_model_path=embedding_model_path,
|
||||
)
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = model_name
|
||||
self.kwargs = kwargs
|
||||
self.system_app = system_app
|
||||
|
||||
def init_app(self, system_app):
|
||||
self.system_app = system_app
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
||||
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
# Ignore model_name args
|
||||
return RemoteEmbeddings(self._default_model_name, worker_manager)
|
||||
|
||||
|
||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(
|
||||
self,
|
||||
system_app,
|
||||
default_model_name: str = None,
|
||||
default_model_path: str = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = default_model_name
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
pass
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> "Embeddings":
|
||||
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||
from dbgpt.model.parameter import (
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
BaseEmbeddingModelParameters,
|
||||
EmbeddingModelParameters,
|
||||
)
|
||||
|
||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self._default_model_name, EmbeddingModelParameters
|
||||
)
|
||||
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=self._default_model_name,
|
||||
model_path=self._default_model_path,
|
||||
param_cls=param_cls,
|
||||
**self._kwargs,
|
||||
)
|
||||
logger.info(model_params)
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
||||
# Register serve apps
|
||||
register_serve_apps(system_app)
|
||||
|
||||
|
||||
def _initialize_model_cache(system_app: SystemApp):
|
||||
|
@@ -16,28 +16,22 @@ from dbgpt.component import SystemApp
|
||||
|
||||
from dbgpt.app.base import (
|
||||
server_init,
|
||||
_migration_db_storage,
|
||||
WebServerParameters,
|
||||
_create_model_start_listener,
|
||||
)
|
||||
|
||||
# initialize_components import time cost about 0.1s
|
||||
from dbgpt.app.component_configs import initialize_components
|
||||
|
||||
# fastapi import time cost about 0.05s
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI, applications
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dbgpt.app.knowledge.api import router as knowledge_router
|
||||
from dbgpt.app.prompt.api import router as prompt_router
|
||||
from dbgpt.app.llm_manage.api import router as llm_manage_api
|
||||
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
|
||||
from dbgpt.app.openapi.base import validation_exception_handler
|
||||
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
from dbgpt.agent.commands.disply_type.show_chart_gen import (
|
||||
static_message_img_path,
|
||||
)
|
||||
from dbgpt.model.cluster import initialize_worker_manager_in_client
|
||||
from dbgpt.util.utils import (
|
||||
setup_logging,
|
||||
_get_logging_level,
|
||||
@@ -78,16 +72,35 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_v1, prefix="/api", tags=["Chat"])
|
||||
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
|
||||
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
|
||||
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
|
||||
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
app.include_router(prompt_router, tags=["Prompt"])
|
||||
def mount_routers(app: FastAPI):
|
||||
"""Lazy import to avoid high time cost"""
|
||||
from dbgpt.app.knowledge.api import router as knowledge_router
|
||||
|
||||
# from dbgpt.app.prompt.api import router as prompt_router
|
||||
# prompt has been removed to dbgpt.serve.prompt
|
||||
from dbgpt.app.llm_manage.api import router as llm_manage_api
|
||||
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
|
||||
from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import (
|
||||
router as api_editor_route_v1,
|
||||
)
|
||||
from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
|
||||
app.include_router(api_v1, prefix="/api", tags=["Chat"])
|
||||
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
|
||||
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
|
||||
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
|
||||
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
# app.include_router(prompt_router, tags=["Prompt"])
|
||||
|
||||
|
||||
def mount_static_files(app):
|
||||
def mount_static_files(app: FastAPI):
|
||||
from dbgpt.agent.commands.disply_type.show_chart_gen import (
|
||||
static_message_img_path,
|
||||
)
|
||||
|
||||
os.makedirs(static_message_img_path, exist_ok=True)
|
||||
app.mount(
|
||||
"/images",
|
||||
@@ -122,14 +135,15 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
if not param:
|
||||
param = _get_webserver_params(args)
|
||||
|
||||
# import after param is initialized, accelerate --help speed
|
||||
from dbgpt.model.cluster import initialize_worker_manager_in_client
|
||||
|
||||
if not param.log_level:
|
||||
param.log_level = _get_logging_level()
|
||||
setup_logging(
|
||||
"dbgpt", logging_level=param.log_level, logger_filename=param.log_file
|
||||
)
|
||||
|
||||
# Before start
|
||||
system_app.before_start()
|
||||
model_name = param.model_name or CFG.LLM_MODEL
|
||||
param.model_name = model_name
|
||||
print(param)
|
||||
@@ -138,9 +152,16 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
|
||||
server_init(param, system_app)
|
||||
mount_routers(app)
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
|
||||
# Before start, after initialize_components
|
||||
# TODO: initialize_worker_manager_in_client as a component register in system_app
|
||||
system_app.before_start()
|
||||
# Migration db storage, so you db models must be imported before this
|
||||
_migration_db_storage(param)
|
||||
|
||||
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
|
0
dbgpt/app/initialization/__init__.py
Normal file
0
dbgpt/app/initialization/__init__.py
Normal file
29
dbgpt/app/initialization/db_model_initialization.py
Normal file
29
dbgpt/app/initialization/db_model_initialization.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Import all models to make sure they are registered with SQLAlchemy.
|
||||
"""
|
||||
from dbgpt.agent.db.my_plugin_db import MyPluginEntity
|
||||
from dbgpt.agent.db.plugin_hub_db import PluginHubEntity
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
|
||||
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
|
||||
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
|
||||
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
|
||||
|
||||
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
|
||||
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
|
||||
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
|
||||
from dbgpt.storage.chat_history.chat_history_db import (
|
||||
ChatHistoryEntity,
|
||||
ChatHistoryMessageEntity,
|
||||
)
|
||||
|
||||
_MODELS = [
|
||||
PluginHubEntity,
|
||||
MyPluginEntity,
|
||||
PromptManageEntity,
|
||||
KnowledgeSpaceEntity,
|
||||
KnowledgeDocumentEntity,
|
||||
DocumentChunkEntity,
|
||||
ChatFeedBackEntity,
|
||||
ConnectConfigEntity,
|
||||
ChatHistoryEntity,
|
||||
ChatHistoryMessageEntity,
|
||||
]
|
103
dbgpt/app/initialization/embedding_component.py
Normal file
103
dbgpt/app/initialization/embedding_component.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from dbgpt.app.base import WebServerParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
param: "WebServerParameters",
|
||||
system_app: SystemApp,
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
):
|
||||
if param.remote_embedding:
|
||||
logger.info("Register remote RemoteEmbeddingFactory")
|
||||
system_app.register(RemoteEmbeddingFactory, model_name=embedding_model_name)
|
||||
else:
|
||||
logger.info(f"Register local LocalEmbeddingFactory")
|
||||
system_app.register(
|
||||
LocalEmbeddingFactory,
|
||||
default_model_name=embedding_model_name,
|
||||
default_model_path=embedding_model_path,
|
||||
)
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = model_name
|
||||
self.kwargs = kwargs
|
||||
self.system_app = system_app
|
||||
|
||||
def init_app(self, system_app):
|
||||
self.system_app = system_app
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
|
||||
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
# Ignore model_name args
|
||||
return RemoteEmbeddings(self._default_model_name, worker_manager)
|
||||
|
||||
|
||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(
|
||||
self,
|
||||
system_app,
|
||||
default_model_name: str = None,
|
||||
default_model_path: str = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = default_model_name
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
pass
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> "Embeddings":
|
||||
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||
from dbgpt.model.parameter import (
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
BaseEmbeddingModelParameters,
|
||||
EmbeddingModelParameters,
|
||||
)
|
||||
|
||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self._default_model_name, EmbeddingModelParameters
|
||||
)
|
||||
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=self._default_model_name,
|
||||
model_path=self._default_model_path,
|
||||
param_cls=param_cls,
|
||||
**self._kwargs,
|
||||
)
|
||||
logger.info(model_params)
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
9
dbgpt/app/initialization/serve_initialization.py
Normal file
9
dbgpt/app/initialization/serve_initialization.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from dbgpt.component import SystemApp
|
||||
|
||||
|
||||
def register_serve_apps(system_app: SystemApp):
|
||||
"""Register serve apps"""
|
||||
from dbgpt.serve.prompt.serve import Serve as PromptServe
|
||||
|
||||
# Replace old prompt serve
|
||||
system_app.register(PromptServe, api_prefix="/prompt")
|
@@ -11,10 +11,6 @@ CFG = Config()
|
||||
|
||||
class DocumentChunkEntity(Model):
|
||||
__tablename__ = "document_chunk"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
document_id = Column(Integer)
|
||||
doc_name = Column(String(100))
|
||||
@@ -112,7 +108,7 @@ class DocumentChunkDao(BaseDao):
|
||||
session.close()
|
||||
return count
|
||||
|
||||
def delete(self, document_id: int):
|
||||
def raw_delete(self, document_id: int):
|
||||
session = self.get_raw_session()
|
||||
if document_id is None:
|
||||
raise Exception("document_id is None")
|
||||
|
@@ -10,10 +10,6 @@ CFG = Config()
|
||||
|
||||
class KnowledgeDocumentEntity(Model):
|
||||
__tablename__ = "knowledge_document"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
doc_name = Column(String(100))
|
||||
doc_type = Column(String(100))
|
||||
@@ -180,7 +176,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return updated_space.id
|
||||
|
||||
#
|
||||
def delete(self, query: KnowledgeDocumentEntity):
|
||||
def raw_delete(self, query: KnowledgeDocumentEntity):
|
||||
session = self.get_raw_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
|
@@ -367,9 +367,9 @@ class KnowledgeService:
|
||||
# delete chunks
|
||||
documents = knowledge_document_dao.get_documents(document_query)
|
||||
for document in documents:
|
||||
document_chunk_dao.delete(document.id)
|
||||
document_chunk_dao.raw_delete(document.id)
|
||||
# delete documents
|
||||
knowledge_document_dao.delete(document_query)
|
||||
knowledge_document_dao.raw_delete(document_query)
|
||||
# delete space
|
||||
return knowledge_space_dao.delete_knowledge_space(space)
|
||||
|
||||
@@ -395,9 +395,9 @@ class KnowledgeService:
|
||||
# delete vector by ids
|
||||
vector_client.delete_by_ids(vector_ids)
|
||||
# delete chunks
|
||||
document_chunk_dao.delete(documents[0].id)
|
||||
document_chunk_dao.raw_delete(documents[0].id)
|
||||
# delete document
|
||||
return knowledge_document_dao.delete(document_query)
|
||||
return knowledge_document_dao.raw_delete(document_query)
|
||||
|
||||
def get_document_chunks(self, request: ChunkQueryRequest):
|
||||
"""get document chunks
|
||||
|
@@ -11,10 +11,6 @@ CFG = Config()
|
||||
|
||||
class KnowledgeSpaceEntity(Model):
|
||||
__tablename__ = "knowledge_space"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(100))
|
||||
vector_type = Column(String(100))
|
||||
|
@@ -9,10 +9,6 @@ from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
|
||||
class ChatFeedBackEntity(Model):
|
||||
__tablename__ = "chat_feed_back"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
conv_uid = Column(String(128))
|
||||
conv_index = Column(Integer)
|
||||
|
@@ -13,10 +13,6 @@ CFG = Config()
|
||||
|
||||
class PromptManageEntity(Model):
|
||||
__tablename__ = "prompt_manage"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
chat_scene = Column(String(100))
|
||||
sub_chat_scene = Column(String(100))
|
||||
|
Reference in New Issue
Block a user