refactor: Refactor storage system (#937)

Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -1,5 +1,7 @@
from typing import Optional
import click
import os
import functools
from dbgpt.app.base import WebServerParameters
from dbgpt.configs.model_config import LOGDIR
from dbgpt.util.parameter_utils import EnvArgumentParser
@@ -34,3 +36,241 @@ def stop_webserver(port: int):
def _stop_all_dbgpt_server():
_stop_service("webserver", "WebServer")
@click.group("migration")
def migration():
"""Manage database migration"""
pass
def add_migration_options(func):
@click.option(
"--alembic_ini_path",
required=False,
type=str,
default=None,
show_default=True,
help="Alembic ini path, if not set, use 'pilot/meta_data/alembic.ini'",
)
@click.option(
"--script_location",
required=False,
type=str,
default=None,
show_default=True,
help="Alembic script location, if not set, use 'pilot/meta_data/alembic'",
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@migration.command()
@add_migration_options
@click.option(
"-m",
"--message",
required=False,
type=str,
default="Init migration",
show_default=True,
help="The message for create migration repository",
)
def init(alembic_ini_path: str, script_location: str, message: str):
"""Initialize database migration repository"""
from dbgpt.util._db_migration_utils import create_migration_script
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
create_migration_script(alembic_cfg, db_manager.engine, message)
@migration.command()
@add_migration_options
@click.option(
"-m",
"--message",
required=False,
type=str,
default="New migration",
show_default=True,
help="The message for migration script",
)
def migrate(alembic_ini_path: str, script_location: str, message: str):
"""Create migration script"""
from dbgpt.util._db_migration_utils import create_migration_script
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
create_migration_script(alembic_cfg, db_manager.engine, message)
@migration.command()
@add_migration_options
def upgrade(alembic_ini_path: str, script_location: str):
"""Upgrade database to target version"""
from dbgpt.util._db_migration_utils import upgrade_database
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
upgrade_database(alembic_cfg, db_manager.engine)
@migration.command()
@add_migration_options
@click.option(
"-y",
required=False,
type=bool,
default=False,
is_flag=True,
help="Confirm to downgrade database",
)
@click.option(
"-r",
"--revision",
default="-1",
show_default=True,
help="Revision to downgrade to",
)
def downgrade(alembic_ini_path: str, script_location: str, y: bool, revision: str):
"""Downgrade database to target version"""
from dbgpt.util._db_migration_utils import downgrade_database
if not y:
click.confirm("Are you sure you want to downgrade the database?", abort=True)
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
downgrade_database(alembic_cfg, db_manager.engine, revision)
@migration.command()
@add_migration_options
@click.option(
"--drop_all_tables",
required=False,
type=bool,
default=False,
is_flag=True,
help="Drop all tables",
)
@click.option(
"-y",
required=False,
type=bool,
default=False,
is_flag=True,
help="Confirm to clean migration data",
)
@click.option(
"--confirm_drop_all_tables",
required=False,
type=bool,
default=False,
is_flag=True,
help="Confirm to drop all tables",
)
def clean(
alembic_ini_path: str,
script_location: str,
drop_all_tables: bool,
y: bool,
confirm_drop_all_tables: bool,
):
"""Clean Alembic migration scripts and history"""
from dbgpt.util._db_migration_utils import clean_alembic_migration
if not y:
click.confirm(
"Are you sure clean alembic migration scripts and history?", abort=True
)
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
clean_alembic_migration(alembic_cfg, db_manager.engine)
if drop_all_tables:
if not confirm_drop_all_tables:
click.confirm("\nAre you sure drop all tables?", abort=True)
with db_manager.engine.connect() as connection:
for tbl in reversed(db_manager.Model.metadata.sorted_tables):
print(f"Drop table {tbl.name}")
connection.execute(tbl.delete())
@migration.command()
@add_migration_options
def list(alembic_ini_path: str, script_location: str):
"""List all versions in the migration history, marking the current one"""
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
# Set up Alembic environment and script directory
script = ScriptDirectory.from_config(alembic_cfg)
# Get current revision
def get_current_revision():
with db_manager.engine.connect() as connection:
context = MigrationContext.configure(connection)
return context.get_current_revision()
current_rev = get_current_revision()
# List all revisions and mark the current one
for revision in script.walk_revisions():
current_marker = "(current)" if revision.revision == current_rev else ""
print(f"{revision.revision} {current_marker}: {revision.doc}")
@migration.command()
@add_migration_options
@click.argument("revision", required=True)
def show(alembic_ini_path: str, script_location: str, revision: str):
"""Show the migration script for a specific version."""
from alembic.script import ScriptDirectory
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
script = ScriptDirectory.from_config(alembic_cfg)
rev = script.get_revision(revision)
if rev is None:
print(f"Revision {revision} not found.")
return
# Find the migration script file
script_files = os.listdir(os.path.join(script.dir, "versions"))
script_file = next((f for f in script_files if f.startswith(revision)), None)
if script_file is None:
print(f"Migration script for revision {revision} not found.")
return
# Print the migration script
script_file_path = os.path.join(script.dir, "versions", script_file)
print(f"Migration script for revision {revision}: {script_file_path}")
try:
with open(script_file_path, "r") as file:
print(file.read())
except FileNotFoundError:
print(f"Migration script {script_file_path} not found.")
def _get_migration_config(
alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None
):
from dbgpt.storage.metadata.db_manager import db as db_manager
from dbgpt.util._db_migration_utils import create_alembic_config
# Must import dbgpt_server to initialize the db metadata
from dbgpt.app.dbgpt_server import initialize_app as _
from dbgpt.app.base import _initialize_db
# initialize db
default_meta_data_path = _initialize_db()
alembic_cfg = create_alembic_config(
default_meta_data_path,
db_manager.engine,
db_manager.Model,
db_manager.session(),
alembic_ini_path,
script_location,
)
return alembic_cfg, db_manager
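For reference, a minimal sketch of driving the same helpers programmatically instead of through click, using only functions that appear in this commit (it assumes a fully configured DB-GPT environment, since _get_migration_config imports the server app):
# Hedged sketch: upgrade the metadata database without going through the CLI.
from dbgpt.util._db_migration_utils import upgrade_database
alembic_cfg, db_manager = _get_migration_config()
upgrade_database(alembic_cfg, db_manager.engine)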

View File

@@ -8,7 +8,8 @@ from dataclasses import dataclass, field
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.util.parameter_utils import BaseParameters
from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -36,8 +37,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
# init config
cfg = Config()
cfg.SYSTEM_APP = system_app
ddl_init_and_upgrade(param.disable_alembic_upgrade)
# Initialize db storage first
_initialize_db_storage(param)
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
@@ -83,6 +84,46 @@ def _create_model_start_listener(system_app: SystemApp):
return startup_event
def _initialize_db_storage(param: "WebServerParameters"):
"""Initialize the db storage.
Currently only SQLite and MySQL are supported. If the db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
default_meta_data_path = _initialize_db(
try_to_create_db=not param.disable_alembic_upgrade
)
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"""Initialize the database
Currently only SQLite and MySQL are supported. If the db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.storage.metadata.db_manager import initialize_db
from urllib.parse import quote_plus as urlquote, quote
CFG = Config()
db_name = CFG.LOCAL_DB_NAME
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
else:
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
db_url = f"sqlite:///{sqlite_db_path}"
engine_args = {
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
"max_overflow": CFG.LOCAL_DB_POOL_OVERFLOW,
"pool_timeout": 30,
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
return default_meta_data_path
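Illustrative only, not part of the diff: calling initialize_db directly with a SQLite URL, mirroring the call that _initialize_db assembles above (the database name "dbgpt" here is an assumed stand-in for CFG.LOCAL_DB_NAME):
# Hedged sketch of the initialize_db call shown above, with literal values.
from dbgpt.storage.metadata.db_manager import initialize_db
db_url = "sqlite:///pilot/meta_data/dbgpt.db"  # "dbgpt" stands in for CFG.LOCAL_DB_NAME
engine_args = {"pool_pre_ping": True, "pool_recycle": 3600}
initialize_db(db_url, "dbgpt", engine_args, try_to_create_db=False)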
@dataclass
class WebServerParameters(BaseParameters):
host: Optional[str] = field(

View File

@@ -13,7 +13,6 @@ from dbgpt.app.base import WebServerParameters
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
logger = logging.getLogger(__name__)
CFG = Config()

View File

@@ -3,19 +3,13 @@ from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
CFG = Config()
class DocumentChunkEntity(Base):
class DocumentChunkEntity(Model):
__tablename__ = "document_chunk"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):
class DocumentChunkDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_documents_chunks(self, documents: List):
session = self.get_session()
session = self.get_raw_session()
docs = [
DocumentChunkEntity(
doc_name=document.doc_name,
@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
def get_document_chunks(
self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
):
session = self.get_session()
session = self.get_raw_session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.get_session()
session = self.get_raw_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
return count
def delete(self, document_id: int):
session = self.get_session()
session = self.get_raw_session()
if document_id is None:
raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id)
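The same refactor repeats across the DAOs in this commit: entities now derive from the shared Model declarative base, and DAOs no longer wire an engine/session into BaseDao, fetching one via get_raw_session() instead. A minimal sketch of a DAO written in the new style, assuming BaseDao no longer requires constructor arguments (as the removed __init__ methods suggest); the table and columns are illustrative, not taken from this commit:
# Hypothetical DAO following the pattern introduced above.
from sqlalchemy import Column, Integer, String
from dbgpt.storage.metadata import BaseDao, Model
class ExampleEntity(Model):  # illustrative table, not part of DB-GPT
    __tablename__ = "example_entity"
    id = Column(Integer, primary_key=True)
    name = Column(String(100))
class ExampleDao(BaseDao):
    def create(self, name: str) -> int:
        session = self.get_raw_session()
        entity = ExampleEntity(name=name)
        session.add(entity)
        session.commit()
        entity_id = entity.id
        session.close()
        return entity_id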

View File

@@ -2,19 +2,13 @@ from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
CFG = Config()
class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentEntity(Model):
__tablename__ = "knowledge_document"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
doc_type=document.doc_type,
@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.get_session()
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_documents(self, query):
session = self.get_session()
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_knowledge_documents_count_bulk(self, space_names):
session = self.get_session()
session = self.get_raw_session()
"""
Perform a batch query to count the number of documents for each knowledge space.
@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
return docs_count
def get_knowledge_documents_count(self, query):
session = self.get_session()
session = self.get_raw_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
return count
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.get_session()
session = self.get_raw_session()
updated_space = session.merge(document)
session.commit()
return updated_space.id
#
def delete(self, query: KnowledgeDocumentEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(

View File

@@ -2,20 +2,14 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
CFG = Config()
class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceEntity(Model):
__tablename__ = "knowledge_space"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.get_session()
session = self.get_raw_session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE,
@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(
@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
return result
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.get_session()
session = self.get_raw_session()
session.merge(space)
session.commit()
session.close()
return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.get_session()
session = self.get_raw_session()
if space:
session.delete(space)
session.commit()

View File

@@ -2,17 +2,12 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
class ChatFeedBackEntity(Base):
class ChatFeedBackEntity(Model):
__tablename__ = "chat_feed_back"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -39,18 +34,10 @@ class ChatFeedBackEntity(Base):
class ChatFeedBackDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first.
session = self.get_session()
session = self.get_raw_session()
chat_feed_back = ChatFeedBackEntity(
conv_uid=feed_back.conv_uid,
conv_index=feed_back.conv_index,
@@ -84,7 +71,7 @@ class ChatFeedBackDao(BaseDao):
session.close()
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
session = self.get_session()
session = self.get_raw_session()
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == conv_uid)

View File

@@ -2,13 +2,8 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.prompt.request.request import PromptManageRequest
@@ -16,7 +11,7 @@ from dbgpt.app.prompt.request.request import PromptManageRequest
CFG = Config()
class PromptManageEntity(Base):
class PromptManageEntity(Model):
__tablename__ = "prompt_manage"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -38,16 +33,8 @@ class PromptManageEntity(Base):
class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_prompt(self, prompt: PromptManageRequest):
session = self.get_session()
session = self.get_raw_session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
@@ -64,7 +51,7 @@ class PromptManageDao(BaseDao):
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.get_session()
session = self.get_raw_session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
@@ -93,13 +80,13 @@ class PromptManageDao(BaseDao):
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.get_session()
session = self.get_raw_session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.get_session()
session = self.get_raw_session()
if prompt:
session.delete(prompt)
session.commit()

View File

@@ -146,7 +146,9 @@ class BaseChat(ABC):
input_values = await self.generate_input_values()
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
self.current_message.add_user_message(
self.current_user_input, check_duplicate_type=True
)
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
@@ -221,7 +223,7 @@ class BaseChat(ABC):
view_msg = self.stream_plugin_call(msg)
view_msg = view_msg.replace("\n", "\\n")
yield view_msg
self.current_message.add_ai_message(msg)
self.current_message.add_ai_message(msg, update_if_exist=True)
view_msg = self.stream_call_reinforce_fn(view_msg)
self.current_message.add_view_message(view_msg)
span.end()
@@ -257,7 +259,7 @@ class BaseChat(ABC):
)
)
### Handle the model result
self.current_message.add_ai_message(ai_response_text)
self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
@@ -320,7 +322,7 @@ class BaseChat(ABC):
)
)
### Handle the model result
self.current_message.add_ai_message(ai_response_text)
self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
prompt_define_response = None
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
@@ -596,7 +598,7 @@ def _load_system_message(
prompt_template: PromptTemplate,
str_message: bool = True,
):
system_convs = current_message.get_system_conv()
system_convs = current_message.get_system_messages()
system_text = ""
system_messages = []
for system_conv in system_convs:
@@ -614,7 +616,7 @@ def _load_user_message(
prompt_template: PromptTemplate,
str_message: bool = True,
):
user_conv = current_message.get_user_conv()
user_conv = current_message.get_latest_user_message()
user_messages = []
if user_conv:
user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep

View File

@@ -70,7 +70,9 @@ class ChatHistoryManager:
def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self._chat_ctx.current_user_input)
self.current_message.add_user_message(
self._chat_ctx.current_user_input, check_duplicate_type=True
)
self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
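The message API changes in the last two files (check_duplicate_type, update_if_exist, get_system_messages, get_latest_user_message) are only visible here through their call sites. As a rough mental model of the assumed semantics, not the actual dbgpt implementation:
# Toy sketch of the assumed flag semantics; NOT the real dbgpt message classes.
class ToyConversation:
    def __init__(self):
        self.messages = []  # list of (role, content) tuples
    def add_user_message(self, content: str, check_duplicate_type: bool = False):
        if check_duplicate_type and self.messages and self.messages[-1][0] == "human":
            # Assumed behavior: replace a trailing message of the same role
            # instead of appending a duplicate.
            self.messages[-1] = ("human", content)
        else:
            self.messages.append(("human", content))
    def add_ai_message(self, content: str, update_if_exist: bool = False):
        if update_if_exist and self.messages and self.messages[-1][0] == "ai":
            self.messages[-1] = ("ai", content)
        else:
            self.messages.append(("ai", content))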