refactor: Refactor storage and new serve template (#947)
@@ -1,5 +1,17 @@
from .utils import (
    get_gpu_memory,
    server_error_msg,
    get_or_create_event_loop,
)
from .pagination_utils import PaginationResult
from .parameter_utils import BaseParameters, ParameterDescription, EnvArgumentParser
from .config_utils import AppConfig

__ALL__ = [
    "get_gpu_memory",
    "get_or_create_event_loop",
    "PaginationResult",
    "BaseParameters",
    "ParameterDescription",
    "EnvArgumentParser",
    "AppConfig",
]
@@ -51,19 +51,50 @@ def create_alembic_config(


def create_migration_script(
    alembic_cfg: AlembicConfig, engine: Engine, message: str = "New migration"
) -> None:
    alembic_cfg: AlembicConfig,
    engine: Engine,
    message: str = "New migration",
    create_new_revision_if_noting_to_update: Optional[bool] = True,
) -> str:
    """Create migration script.

    Args:
        alembic_cfg: alembic config
        engine: sqlalchemy engine
        message: migration message
        create_new_revision_if_noting_to_update: Whether to create a new revision if there is
            nothing to update; pass False to skip creating a revision in that case. Defaults to True.

    Returns:
        The path of the generated migration script.
    """
    from alembic.script import ScriptDirectory
    from alembic.runtime.migration import MigrationContext

    # Check if the database is up-to-date
    script_dir = ScriptDirectory.from_config(alembic_cfg)
    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        command.revision(alembic_cfg, message, autogenerate=True)
        context = MigrationContext.configure(connection=connection)
        current_rev = context.get_current_revision()
        head_rev = script_dir.get_current_head()

    logger.info(
        f"alembic migration current revision: {current_rev}, latest revision: {head_rev}"
    )
    should_create_revision = (
        (current_rev is None and head_rev is None)
        or current_rev != head_rev
        or create_new_revision_if_noting_to_update
    )
    if should_create_revision:
        with engine.connect() as connection:
            alembic_cfg.attributes["connection"] = connection
            revision = command.revision(alembic_cfg, message=message, autogenerate=True)
            # Return the path of the generated migration script
            return revision.path
    elif current_rev == head_rev:
        logger.info("No migration script to generate, database is up-to-date")
    # If no new revision is created, return None or an appropriate message
    return None


def upgrade_database(
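For orientation, here is a minimal standalone sketch of the two Alembic operations that `create_migration_script` combines: reading the current and head revisions, then autogenerating a revision and reporting its path. The `alembic.ini` path, script location, and SQLite URL are illustrative assumptions; DB-GPT builds its real config in `create_alembic_config`.

```python
from alembic import command
from alembic.config import Config as AlembicConfig
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine

# Hypothetical paths; adjust to your project layout.
alembic_cfg = AlembicConfig("alembic.ini")
alembic_cfg.set_main_option("script_location", "pilot/meta_data/alembic")
engine = create_engine("sqlite:///pilot/meta_data/dbgpt.db")

# 1) Read the database's current revision and the newest script revision.
script_dir = ScriptDirectory.from_config(alembic_cfg)
with engine.connect() as connection:
    context = MigrationContext.configure(connection=connection)
    current_rev = context.get_current_revision()  # None for a fresh database
head_rev = script_dir.get_current_head()  # None when no scripts exist yet
print(f"current revision: {current_rev}, latest revision: {head_rev}")

# 2) Autogenerate a migration script and report where it was written.
with engine.connect() as connection:
    alembic_cfg.attributes["connection"] = connection
    script = command.revision(alembic_cfg, message="New migration", autogenerate=True)
    print(script.path)  # path of the generated migration script
```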
@@ -82,6 +113,37 @@ def upgrade_database(
    command.upgrade(alembic_cfg, target_version)


def generate_sql_for_upgrade(
    alembic_cfg: AlembicConfig,
    engine: Engine,
    target_version: Optional[str] = "head",
    output_file: Optional[str] = "migration.sql",
) -> None:
    """Generate SQL for upgrading database to target version.

    Args:
        alembic_cfg: alembic config
        engine: sqlalchemy engine
        target_version: target version, default is head (latest version)
        output_file: file to write the SQL script

    TODO: Can't generate SQL for most of the operations.
    """
    import contextlib
    import io

    with engine.connect() as connection, contextlib.redirect_stdout(
        io.StringIO()
    ) as stdout:
        alembic_cfg.attributes["connection"] = connection
        # Generating SQL instead of applying changes
        command.upgrade(alembic_cfg, target_version, sql=True)

    # Write the generated SQL to a file
    with open(output_file, "w", encoding="utf-8") as file:
        file.write(stdout.getvalue())


def downgrade_database(
    alembic_cfg: AlembicConfig, engine: Engine, revision: str = "-1"
):
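As a rough usage sketch, the same offline-SQL pattern can be driven with plain Alembic: in `--sql` mode Alembic prints the statements to stdout, so they are captured and written to a file. The `alembic.ini` path and output filename below are assumptions for illustration.

```python
import contextlib
import io

from alembic import command
from alembic.config import Config as AlembicConfig

# Hypothetical config; in DB-GPT this is built by create_alembic_config().
alembic_cfg = AlembicConfig("alembic.ini")

# Offline mode: Alembic emits SQL to stdout instead of executing it.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    command.upgrade(alembic_cfg, "head", sql=True)

with open("migration.sql", "w", encoding="utf-8") as f:
    f.write(buffer.getvalue())
```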
@@ -160,9 +222,94 @@ or
rm -rf pilot/meta_data/alembic/versions/*
rm -rf pilot/meta_data/alembic/dbgpt.db
```

If your database is a shared database and you run DB-GPT in multiple instances,
you should make sure that all migration scripts are the same in all instances. In this case,
we strongly recommend you disable the migration feature by setting `--disable_alembic_upgrade`
and use the `dbgpt db migration` command to manage migration scripts.
"""


def _check_database_migration_status(alembic_cfg: AlembicConfig, engine: Engine):
    """Check if the database is at the latest migration revision.

    If your database is a shared database and you run DB-GPT in multiple instances,
    you should make sure that all migration scripts are the same in all instances. In this case,
    we strongly recommend you disable the migration feature by setting `disable_alembic_upgrade` to True
    and use the `dbgpt db migration` command to manage migration scripts.

    Args:
        alembic_cfg: Alembic configuration object.
        engine: SQLAlchemy engine instance.
    Raises:
        Exception: If the database is not at the latest revision.
    """
    from alembic.script import ScriptDirectory
    from alembic.runtime.migration import MigrationContext

    script = ScriptDirectory.from_config(alembic_cfg)

    def get_current_revision(engine):
        with engine.connect() as connection:
            context = MigrationContext.configure(connection=connection)
            return context.get_current_revision()

    current_rev = get_current_revision(engine)
    head_rev = script.get_current_head()

    script_info_msg = "Migration versions and their file paths:"
    script_info_msg += f"\n{'='*40}Migration versions{'='*40}\n"
    for revision in script.walk_revisions(base="base"):
        current_marker = "(current)" if revision.revision == current_rev else ""
        script_path = script.get_revision(revision.revision).path
        script_info_msg += f"\n{revision.revision} {current_marker}: {revision.doc} (Path: {script_path})"
    script_info_msg += f"\n{'='*90}"

    logger.info(script_info_msg)

    if current_rev != head_rev:
        logger.error(
            "Database is not at the latest revision. "
            f"Current revision: {current_rev}, latest revision: {head_rev}\n"
            "Please apply existing migration scripts before generating new ones. "
            "Check the listed file paths for migration scripts.\n"
            f"Also you can try the following solutions:\n{_MIGRATION_SOLUTION}\n"
        )
        raise Exception(
            "Check database migration status failed, you can see the error and solutions above"
        )


def _get_latest_revision(alembic_cfg: AlembicConfig, engine: Engine) -> str:
    """Get the latest revision of the database.

    Args:
        alembic_cfg: Alembic configuration object.
        engine: SQLAlchemy engine instance.

    Returns:
        The latest revision as a string.
    """
    from alembic.runtime.migration import MigrationContext

    with engine.connect() as connection:
        context = MigrationContext.configure(connection=connection)
        return context.get_current_revision()


def _delete_migration_script(script_path: str):
    """Delete a migration script.

    Args:
        script_path: The path of the migration script to delete.
    """
    if os.path.exists(script_path):
        os.remove(script_path)
        logger.info(f"Deleted migration script at: {script_path}")
    else:
        logger.warning(f"Migration script not found at: {script_path}")


def _ddl_init_and_upgrade(
    default_meta_data_path: str,
    disable_alembic_upgrade: bool,
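A small sketch of the revision listing that `_check_database_migration_status` logs, using Alembic's `ScriptDirectory` directly; the `alembic.ini` path is an assumption for illustration.

```python
from alembic.config import Config as AlembicConfig
from alembic.script import ScriptDirectory

alembic_cfg = AlembicConfig("alembic.ini")  # hypothetical config path
script = ScriptDirectory.from_config(alembic_cfg)

# Walk every known revision from the base and print the file that defines it,
# mirroring the log message built by _check_database_migration_status.
for revision in script.walk_revisions(base="base"):
    print(f"{revision.revision}: {revision.doc} (Path: {revision.path})")
```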
@@ -203,7 +350,19 @@ def _ddl_init_and_upgrade(
        script_location,
    )
    try:
        create_migration_script(alembic_cfg, db.engine)
        _check_database_migration_status(alembic_cfg, db.engine)
    except Exception as e:
        logger.error(f"Failed to check database migration status: {e}")
        raise
    latest_revision_before = "__latest_revision_before__"
    new_script_path = None
    try:
        latest_revision_before = _get_latest_revision(alembic_cfg, db.engine)
        # create_new_revision_if_noting_to_update=False avoids creating a lot of empty migration scripts
        # TODO Set create_new_revision_if_noting_to_update=False, not working now.
        new_script_path = create_migration_script(
            alembic_cfg, db.engine, create_new_revision_if_noting_to_update=True
        )
        upgrade_database(alembic_cfg, db.engine)
    except CommandError as e:
        if "Target database is not up to date" in str(e):
@@ -216,4 +375,10 @@ def _ddl_init_and_upgrade(
                "you can see the error and solutions above"
            ) from e
        else:
            latest_revision_after = _get_latest_revision(alembic_cfg, db.engine)
            if latest_revision_before != latest_revision_after:
                logger.error(
                    f"Upgrade database failed. Please review the migration script manually. "
                    f"Failed script path: {new_script_path}\nError: {e}"
                )
            raise e
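A condensed sketch of the recovery flow above, assuming `alembic_cfg`, `engine`, `logger`, and the helpers defined in this module are in scope: remember the revision before generating a script, attempt the upgrade, and surface the newly generated script's path if the upgrade fails.

```python
from alembic.util.exc import CommandError

latest_revision_before = _get_latest_revision(alembic_cfg, engine)
new_script_path = None
try:
    new_script_path = create_migration_script(alembic_cfg, engine)
    upgrade_database(alembic_cfg, engine)
except CommandError as e:
    if "Target database is not up to date" in str(e):
        # Existing migration scripts have not been applied yet; see _MIGRATION_SOLUTION.
        raise Exception(
            "Check database migration status failed, "
            "you can see the error and solutions above"
        ) from e
    latest_revision_after = _get_latest_revision(alembic_cfg, engine)
    if latest_revision_before != latest_revision_after:
        # A fresh script was generated in this run; point the operator at it.
        logger.error(f"Upgrade failed, review the new script: {new_script_path}")
    raise e
```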
dbgpt/util/config_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from functools import cache
from typing import Any, Dict, Optional


class AppConfig:
    def __init__(self):
        self.configs = {}

    def set(self, key: str, value: Any) -> None:
        """Set config value by key
        Args:
            key (str): The key of config
            value (Any): The value of config
        """
        self.configs[key] = value

    def get(self, key, default: Optional[Any] = None) -> Any:
        """Get config value by key

        Args:
            key (str): The key of config
            default (Optional[Any], optional): The default value if key not found. Defaults to None.
        """
        return self.configs.get(key, default)

    @cache
    def get_all_by_prefix(self, prefix) -> Dict[str, Any]:
        """Get all config values by prefix
        Args:
            prefix (str): The prefix of config
        """
        return {k: v for k, v in self.configs.items() if k.startswith(prefix)}
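A brief usage sketch for the new `AppConfig` (the key names are made up for illustration). Note that `get_all_by_prefix` is wrapped in `functools.cache`, so its result for a given instance and prefix is memoized on first call; keys set afterwards will not appear in repeated calls with the same prefix.

```python
config = AppConfig()
config.set("web.host", "0.0.0.0")
config.set("web.port", 5000)

assert config.get("web.port") == 5000
assert config.get("missing_key", "fallback") == "fallback"

# Returns every entry whose key starts with the prefix,
# e.g. {"web.host": "0.0.0.0", "web.port": 5000}
web_settings = config.get_all_by_prefix("web")
print(web_settings)
```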
@@ -6,7 +6,6 @@ import logging.handlers
from typing import Any, List

import os
import sys
import asyncio

from dbgpt.configs.model_config import LOGDIR
@@ -81,17 +80,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
    setup_logging_level(logging_level=logging_level)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    # stdout_logger = logging.getLogger("stdout")
    # stdout_logger.setLevel(logging.INFO)
    # sl_1 = StreamToLogger(stdout_logger, logging.INFO)
    # sys.stdout = sl_1
    #
    # stderr_logger = logging.getLogger("stderr")
    # stderr_logger.setLevel(logging.ERROR)
    # sl = StreamToLogger(stderr_logger, logging.ERROR)
    # sys.stderr = sl

    # Add a file handler for all loggers
    if handler is None and logger_filename:
        os.makedirs(LOGDIR, exist_ok=True)
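For context, the branch retained at the end of the hunk above attaches a file handler for all loggers. A minimal sketch of that pattern follows; the directory, filename, format string, and rotation policy are assumptions rather than DB-GPT's exact settings.

```python
import logging
import logging.handlers
import os

LOGDIR = "logs"  # assumption; DB-GPT imports this from dbgpt.configs.model_config
logger_filename = "dbgpt.log"  # hypothetical log file name

os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
)

# Attach the file handler to the root logger and every logger created so far.
logging.getLogger().addHandler(handler)
for name, item in logging.root.manager.loggerDict.items():
    if isinstance(item, logging.Logger):
        item.addHandler(handler)
```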