mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from typing import Optional
|
||||
import click
|
||||
import os
|
||||
import functools
|
||||
from dbgpt.app.base import WebServerParameters
|
||||
from dbgpt.configs.model_config import LOGDIR
|
||||
from dbgpt.util.parameter_utils import EnvArgumentParser
|
||||
@@ -34,3 +36,241 @@ def stop_webserver(port: int):
|
||||
|
||||
def _stop_all_dbgpt_server():
|
||||
_stop_service("webserver", "WebServer")
|
||||
|
||||
|
||||
@click.group("migration")
|
||||
def migration():
|
||||
"""Manage database migration"""
|
||||
pass
|
||||
|
||||
|
||||
def add_migration_options(func):
|
||||
@click.option(
|
||||
"--alembic_ini_path",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Alembic ini path, if not set, use 'pilot/meta_data/alembic.ini'",
|
||||
)
|
||||
@click.option(
|
||||
"--script_location",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Alembic script location, if not set, use 'pilot/meta_data/alembic'",
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
@click.option(
|
||||
"-m",
|
||||
"--message",
|
||||
required=False,
|
||||
type=str,
|
||||
default="Init migration",
|
||||
show_default=True,
|
||||
help="The message for create migration repository",
|
||||
)
|
||||
def init(alembic_ini_path: str, script_location: str, message: str):
|
||||
"""Initialize database migration repository"""
|
||||
from dbgpt.util._db_migration_utils import create_migration_script
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
create_migration_script(alembic_cfg, db_manager.engine, message)
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
@click.option(
|
||||
"-m",
|
||||
"--message",
|
||||
required=False,
|
||||
type=str,
|
||||
default="New migration",
|
||||
show_default=True,
|
||||
help="The message for migration script",
|
||||
)
|
||||
def migrate(alembic_ini_path: str, script_location: str, message: str):
|
||||
"""Create migration script"""
|
||||
from dbgpt.util._db_migration_utils import create_migration_script
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
create_migration_script(alembic_cfg, db_manager.engine, message)
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
def upgrade(alembic_ini_path: str, script_location: str):
|
||||
"""Upgrade database to target version"""
|
||||
from dbgpt.util._db_migration_utils import upgrade_database
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
upgrade_database(alembic_cfg, db_manager.engine)
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
@click.option(
|
||||
"-y",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Confirm to downgrade database",
|
||||
)
|
||||
@click.option(
|
||||
"-r",
|
||||
"--revision",
|
||||
default="-1",
|
||||
show_default=True,
|
||||
help="Revision to downgrade to",
|
||||
)
|
||||
def downgrade(alembic_ini_path: str, script_location: str, y: bool, revision: str):
|
||||
"""Downgrade database to target version"""
|
||||
from dbgpt.util._db_migration_utils import downgrade_database
|
||||
|
||||
if not y:
|
||||
click.confirm("Are you sure you want to downgrade the database?", abort=True)
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
downgrade_database(alembic_cfg, db_manager.engine, revision)
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
@click.option(
|
||||
"--drop_all_tables",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Drop all tables",
|
||||
)
|
||||
@click.option(
|
||||
"-y",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Confirm to clean migration data",
|
||||
)
|
||||
@click.option(
|
||||
"--confirm_drop_all_tables",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Confirm to drop all tables",
|
||||
)
|
||||
def clean(
|
||||
alembic_ini_path: str,
|
||||
script_location: str,
|
||||
drop_all_tables: bool,
|
||||
y: bool,
|
||||
confirm_drop_all_tables: bool,
|
||||
):
|
||||
"""Clean Alembic migration scripts and history"""
|
||||
from dbgpt.util._db_migration_utils import clean_alembic_migration
|
||||
|
||||
if not y:
|
||||
click.confirm(
|
||||
"Are you sure clean alembic migration scripts and history?", abort=True
|
||||
)
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
clean_alembic_migration(alembic_cfg, db_manager.engine)
|
||||
if drop_all_tables:
|
||||
if not confirm_drop_all_tables:
|
||||
click.confirm("\nAre you sure drop all tables?", abort=True)
|
||||
with db_manager.engine.connect() as connection:
|
||||
for tbl in reversed(db_manager.Model.metadata.sorted_tables):
|
||||
print(f"Drop table {tbl.name}")
|
||||
connection.execute(tbl.delete())
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
def list(alembic_ini_path: str, script_location: str):
|
||||
"""List all versions in the migration history, marking the current one"""
|
||||
from alembic.script import ScriptDirectory
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
|
||||
# Set up Alembic environment and script directory
|
||||
script = ScriptDirectory.from_config(alembic_cfg)
|
||||
|
||||
# Get current revision
|
||||
def get_current_revision():
|
||||
with db_manager.engine.connect() as connection:
|
||||
context = MigrationContext.configure(connection)
|
||||
return context.get_current_revision()
|
||||
|
||||
current_rev = get_current_revision()
|
||||
|
||||
# List all revisions and mark the current one
|
||||
for revision in script.walk_revisions():
|
||||
current_marker = "(current)" if revision.revision == current_rev else ""
|
||||
print(f"{revision.revision} {current_marker}: {revision.doc}")
|
||||
|
||||
|
||||
@migration.command()
|
||||
@add_migration_options
|
||||
@click.argument("revision", required=True)
|
||||
def show(alembic_ini_path: str, script_location: str, revision: str):
|
||||
"""Show the migration script for a specific version."""
|
||||
from alembic.script import ScriptDirectory
|
||||
|
||||
alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
|
||||
|
||||
script = ScriptDirectory.from_config(alembic_cfg)
|
||||
|
||||
rev = script.get_revision(revision)
|
||||
if rev is None:
|
||||
print(f"Revision {revision} not found.")
|
||||
return
|
||||
|
||||
# Find the migration script file
|
||||
script_files = os.listdir(os.path.join(script.dir, "versions"))
|
||||
script_file = next((f for f in script_files if f.startswith(revision)), None)
|
||||
|
||||
if script_file is None:
|
||||
print(f"Migration script for revision {revision} not found.")
|
||||
return
|
||||
# Print the migration script
|
||||
script_file_path = os.path.join(script.dir, "versions", script_file)
|
||||
print(f"Migration script for revision {revision}: {script_file_path}")
|
||||
try:
|
||||
with open(script_file_path, "r") as file:
|
||||
print(file.read())
|
||||
except FileNotFoundError:
|
||||
print(f"Migration script {script_file_path} not found.")
|
||||
|
||||
|
||||
def _get_migration_config(
|
||||
alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None
|
||||
):
|
||||
from dbgpt.storage.metadata.db_manager import db as db_manager
|
||||
from dbgpt.util._db_migration_utils import create_alembic_config
|
||||
|
||||
# Must import dbgpt_server for initialize db metadata
|
||||
from dbgpt.app.dbgpt_server import initialize_app as _
|
||||
from dbgpt.app.base import _initialize_db
|
||||
|
||||
# initialize db
|
||||
default_meta_data_path = _initialize_db()
|
||||
alembic_cfg = create_alembic_config(
|
||||
default_meta_data_path,
|
||||
db_manager.engine,
|
||||
db_manager.Model,
|
||||
db_manager.session(),
|
||||
alembic_ini_path,
|
||||
script_location,
|
||||
)
|
||||
return alembic_cfg, db_manager
|
||||
|
@@ -8,7 +8,8 @@ from dataclasses import dataclass, field
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util.parameter_utils import BaseParameters
|
||||
from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade
|
||||
|
||||
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
@@ -36,8 +37,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
|
||||
# init config
|
||||
cfg = Config()
|
||||
cfg.SYSTEM_APP = system_app
|
||||
|
||||
ddl_init_and_upgrade(param.disable_alembic_upgrade)
|
||||
# Initialize db storage first
|
||||
_initialize_db_storage(param)
|
||||
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
@@ -83,6 +84,46 @@ def _create_model_start_listener(system_app: SystemApp):
|
||||
return startup_event
|
||||
|
||||
|
||||
def _initialize_db_storage(param: "WebServerParameters"):
|
||||
"""Initialize the db storage.
|
||||
|
||||
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
|
||||
"""
|
||||
default_meta_data_path = _initialize_db(
|
||||
try_to_create_db=not param.disable_alembic_upgrade
|
||||
)
|
||||
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
|
||||
|
||||
|
||||
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
"""Initialize the database
|
||||
|
||||
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
|
||||
"""
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
from dbgpt.storage.metadata.db_manager import initialize_db
|
||||
from urllib.parse import quote_plus as urlquote, quote
|
||||
|
||||
CFG = Config()
|
||||
db_name = CFG.LOCAL_DB_NAME
|
||||
default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
|
||||
os.makedirs(default_meta_data_path, exist_ok=True)
|
||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
|
||||
else:
|
||||
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
|
||||
db_url = f"sqlite:///{sqlite_db_path}"
|
||||
engine_args = {
|
||||
"pool_size": CFG.LOCAL_DB_POOL_SIZE,
|
||||
"max_overflow": CFG.LOCAL_DB_POOL_OVERFLOW,
|
||||
"pool_timeout": 30,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
|
||||
return default_meta_data_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebServerParameters(BaseParameters):
|
||||
host: Optional[str] = field(
|
||||
|
@@ -13,7 +13,6 @@ from dbgpt.app.base import WebServerParameters
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
@@ -3,19 +3,13 @@ from typing import List
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DocumentChunkEntity(Base):
|
||||
class DocumentChunkEntity(Model):
|
||||
__tablename__ = "document_chunk"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):
|
||||
|
||||
|
||||
class DocumentChunkDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_documents_chunks(self, documents: List):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
docs = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=document.doc_name,
|
||||
@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
|
||||
def get_document_chunks(
|
||||
self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
|
||||
):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
document_chunks = session.query(DocumentChunkEntity)
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
|
||||
return count
|
||||
|
||||
def delete(self, document_id: int):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if document_id is None:
|
||||
raise Exception("document_id is None")
|
||||
query = DocumentChunkEntity(document_id=document_id)
|
||||
|
@@ -2,19 +2,13 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeDocumentEntity(Base):
|
||||
class KnowledgeDocumentEntity(Model):
|
||||
__tablename__ = "knowledge_document"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):
|
||||
|
||||
|
||||
class KnowledgeDocumentDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_document = KnowledgeDocumentEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return doc_id
|
||||
|
||||
def get_knowledge_documents(self, query, page=1, page_size=20):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_documents(self, query):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_knowledge_documents_count_bulk(self, space_names):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
"""
|
||||
Perform a batch query to count the number of documents for each knowledge space.
|
||||
|
||||
@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return docs_count
|
||||
|
||||
def get_knowledge_documents_count(self, query):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return count
|
||||
|
||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
updated_space = session.merge(document)
|
||||
session.commit()
|
||||
return updated_space.id
|
||||
|
||||
#
|
||||
def delete(self, query: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
|
@@ -2,20 +2,14 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
class KnowledgeSpaceEntity(Model):
|
||||
__tablename__ = "knowledge_space"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):
|
||||
|
||||
|
||||
class KnowledgeSpaceDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
vector_type=CFG.VECTOR_STORE_TYPE,
|
||||
@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
return result
|
||||
|
||||
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
session.merge(space)
|
||||
session.commit()
|
||||
session.close()
|
||||
return True
|
||||
|
||||
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if space:
|
||||
session.delete(space)
|
||||
session.commit()
|
||||
|
@@ -2,17 +2,12 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
|
||||
|
||||
class ChatFeedBackEntity(Base):
|
||||
class ChatFeedBackEntity(Model):
|
||||
__tablename__ = "chat_feed_back"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -39,18 +34,10 @@ class ChatFeedBackEntity(Base):
|
||||
|
||||
|
||||
class ChatFeedBackDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
||||
# Todo: We need to have user information first.
|
||||
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
chat_feed_back = ChatFeedBackEntity(
|
||||
conv_uid=feed_back.conv_uid,
|
||||
conv_index=feed_back.conv_index,
|
||||
@@ -84,7 +71,7 @@ class ChatFeedBackDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
||||
|
@@ -2,13 +2,8 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
from dbgpt.app.prompt.request.request import PromptManageRequest
|
||||
@@ -16,7 +11,7 @@ from dbgpt.app.prompt.request.request import PromptManageRequest
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class PromptManageEntity(Base):
|
||||
class PromptManageEntity(Model):
|
||||
__tablename__ = "prompt_manage"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -38,16 +33,8 @@ class PromptManageEntity(Base):
|
||||
|
||||
|
||||
class PromptManageDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_prompt(self, prompt: PromptManageRequest):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
prompt_manage = PromptManageEntity(
|
||||
chat_scene=prompt.chat_scene,
|
||||
sub_chat_scene=prompt.sub_chat_scene,
|
||||
@@ -64,7 +51,7 @@ class PromptManageDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_prompts(self, query: PromptManageEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
prompts = session.query(PromptManageEntity)
|
||||
if query.chat_scene is not None:
|
||||
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||
@@ -93,13 +80,13 @@ class PromptManageDao(BaseDao):
|
||||
return result
|
||||
|
||||
def update_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
session.merge(prompt)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def delete_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if prompt:
|
||||
session.delete(prompt)
|
||||
session.commit()
|
||||
|
@@ -146,7 +146,9 @@ class BaseChat(ABC):
|
||||
input_values = await self.generate_input_values()
|
||||
### Chat sequence advance
|
||||
self.current_message.chat_order = len(self.history_message) + 1
|
||||
self.current_message.add_user_message(self.current_user_input)
|
||||
self.current_message.add_user_message(
|
||||
self.current_user_input, check_duplicate_type=True
|
||||
)
|
||||
self.current_message.start_date = datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
@@ -221,7 +223,7 @@ class BaseChat(ABC):
|
||||
view_msg = self.stream_plugin_call(msg)
|
||||
view_msg = view_msg.replace("\n", "\\n")
|
||||
yield view_msg
|
||||
self.current_message.add_ai_message(msg)
|
||||
self.current_message.add_ai_message(msg, update_if_exist=True)
|
||||
view_msg = self.stream_call_reinforce_fn(view_msg)
|
||||
self.current_message.add_view_message(view_msg)
|
||||
span.end()
|
||||
@@ -257,7 +259,7 @@ class BaseChat(ABC):
|
||||
)
|
||||
)
|
||||
### model result deal
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
ai_response_text
|
||||
@@ -320,7 +322,7 @@ class BaseChat(ABC):
|
||||
)
|
||||
)
|
||||
### model result deal
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
|
||||
prompt_define_response = None
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
@@ -596,7 +598,7 @@ def _load_system_message(
|
||||
prompt_template: PromptTemplate,
|
||||
str_message: bool = True,
|
||||
):
|
||||
system_convs = current_message.get_system_conv()
|
||||
system_convs = current_message.get_system_messages()
|
||||
system_text = ""
|
||||
system_messages = []
|
||||
for system_conv in system_convs:
|
||||
@@ -614,7 +616,7 @@ def _load_user_message(
|
||||
prompt_template: PromptTemplate,
|
||||
str_message: bool = True,
|
||||
):
|
||||
user_conv = current_message.get_user_conv()
|
||||
user_conv = current_message.get_latest_user_message()
|
||||
user_messages = []
|
||||
if user_conv:
|
||||
user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
|
||||
|
@@ -70,7 +70,9 @@ class ChatHistoryManager:
|
||||
|
||||
def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
|
||||
self.current_message.chat_order = len(self.history_message) + 1
|
||||
self.current_message.add_user_message(self._chat_ctx.current_user_input)
|
||||
self.current_message.add_user_message(
|
||||
self._chat_ctx.current_user_input, check_duplicate_type=True
|
||||
)
|
||||
self.current_message.start_date = datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
Reference in New Issue
Block a user