refactor: Refactor storage system (#937)

Fangyin Cheng authored 2023-12-15 16:35:45 +08:00, committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -1,188 +1,238 @@
 -- You can change `dbgpt` to your actual metadata database name in your `.env` file
 -- eg. `LOCAL_DB_NAME=dbgpt`
 CREATE DATABASE IF NOT EXISTS dbgpt;
 use dbgpt;

 -- For alembic migration tool
-CREATE TABLE `alembic_version` (
+CREATE TABLE IF NOT EXISTS `alembic_version` (
     version_num VARCHAR(32) NOT NULL,
     CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
 );

-CREATE TABLE `knowledge_space` (
+CREATE TABLE IF NOT EXISTS `knowledge_space` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
     `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
     `desc` varchar(500) NOT NULL COMMENT 'description',
     `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
     `context` TEXT DEFAULT NULL COMMENT 'context argument',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_name` (`name`) COMMENT 'index:idx_name'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table';

-CREATE TABLE `knowledge_document` (
+CREATE TABLE IF NOT EXISTS `knowledge_document` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
     `space` varchar(50) NOT NULL COMMENT 'knowledge space',
     `chunk_size` int NOT NULL COMMENT 'chunk size',
     `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
     `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
     `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
     `result` TEXT NULL COMMENT 'knowledge content',
     `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
     `summary` LONGTEXT NULL COMMENT 'knowledge summary',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table';

-CREATE TABLE `document_chunk` (
+CREATE TABLE IF NOT EXISTS `document_chunk` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
     `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
     `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
     `document_id` int NOT NULL COMMENT 'document parent id',
     `content` longtext NOT NULL COMMENT 'chunk content',
     `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
-) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail';

-CREATE TABLE `connect_config` (
+CREATE TABLE IF NOT EXISTS `connect_config` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `db_type` varchar(255) NOT NULL COMMENT 'db type',
     `db_name` varchar(255) NOT NULL COMMENT 'db name',
     `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path',
     `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)',
     `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)',
     `db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
     `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
     `comment` text COMMENT 'db comment',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     PRIMARY KEY (`id`),
     UNIQUE KEY `uk_db` (`db_name`),
     KEY `idx_q_db_type` (`db_type`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi';

-CREATE TABLE `chat_history` (
+CREATE TABLE IF NOT EXISTS `chat_history` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
     `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode',
     `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
     `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
     `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
+    `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
+    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    UNIQUE KEY `conv_uid` (`conv_uid`),
     PRIMARY KEY (`id`)
-) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';

+CREATE TABLE IF NOT EXISTS `chat_history_message` (
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
+    `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id',
+    `index` int NOT NULL COMMENT 'Message index',
+    `round_index` int NOT NULL COMMENT 'Round of conversation',
+    `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format',
+    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    UNIQUE KEY `message_uid_index` (`conv_uid`, `index`),
+    PRIMARY KEY (`id`)
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message';

-CREATE TABLE `chat_feed_back` (
+CREATE TABLE IF NOT EXISTS `chat_feed_back` (
     `id` bigint(20) NOT NULL AUTO_INCREMENT,
     `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID',
     `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation',
     `score` int(1) DEFAULT NULL COMMENT 'Score of user',
     `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category',
     `question` longtext DEFAULT NULL COMMENT 'User question',
     `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
     `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
     `user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`),
     KEY `idx_conv` (`conv_uid`,`conv_index`)
-) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table';

-CREATE TABLE `my_plugin` (
+CREATE TABLE IF NOT EXISTS `my_plugin` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant',
     `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code',
     `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name',
     `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
     `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name',
     `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
     `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
     `use_count` int DEFAULT NULL COMMENT 'plugin total use count',
     `succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `name` (`name`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table';

-CREATE TABLE `plugin_hub` (
+CREATE TABLE IF NOT EXISTS `plugin_hub` (
     `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
     `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name',
     `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description',
     `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author',
     `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email',
     `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type',
     `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
     `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel',
     `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url',
     `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param',
     `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time',
     `installed` int DEFAULT NULL COMMENT 'plugin already installed count',
     PRIMARY KEY (`id`),
     UNIQUE KEY `name` (`name`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table';

-CREATE TABLE `prompt_manage` (
+CREATE TABLE IF NOT EXISTS `prompt_manage` (
     `id` int(11) NOT NULL AUTO_INCREMENT,
     `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
     `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
     `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
     `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
     `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
     `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
     `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
     `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
     `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
     PRIMARY KEY (`id`),
     UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
     KEY `gmt_created_idx` (`gmt_created`)
-) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';

-CREATE DATABASE EXAMPLE_1;
+CREATE DATABASE IF NOT EXISTS EXAMPLE_1;
 use EXAMPLE_1;
-CREATE TABLE `users` (
+CREATE TABLE IF NOT EXISTS `users` (
     `id` int NOT NULL AUTO_INCREMENT,
     `username` varchar(50) NOT NULL COMMENT 'username',
     `password` varchar(50) NOT NULL COMMENT 'password',
     `email` varchar(50) NOT NULL COMMENT 'email',
     `phone` varchar(20) DEFAULT NULL COMMENT 'phone number',
     PRIMARY KEY (`id`),
     KEY `idx_username` (`username`) COMMENT 'index: query by username'
-) ENGINE=InnoDB AUTO_INCREMENT=101 DEFAULT CHARSET=utf8mb4 COMMENT='chat user table';
+) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='chat user table';

 INSERT INTO users (username, password, email, phone) VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901');
 INSERT INTO users (username, password, email, phone) VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902');
 INSERT INTO users (username, password, email, phone) VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903');
 INSERT INTO users (username, password, email, phone) VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904');
 INSERT INTO users (username, password, email, phone) VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905');
 INSERT INTO users (username, password, email, phone) VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906');
 INSERT INTO users (username, password, email, phone) VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907');
 INSERT INTO users (username, password, email, phone) VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908');
 INSERT INTO users (username, password, email, phone) VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909');
 INSERT INTO users (username, password, email, phone) VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900');
 INSERT INTO users (username, password, email, phone) VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901');
 INSERT INTO users (username, password, email, phone) VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902');
 INSERT INTO users (username, password, email, phone) VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903');
 INSERT INTO users (username, password, email, phone) VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904');
 INSERT INTO users (username, password, email, phone) VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905');
 INSERT INTO users (username, password, email, phone) VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906');
 INSERT INTO users (username, password, email, phone) VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907');
 INSERT INTO users (username, password, email, phone) VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908');
 INSERT INTO users (username, password, email, phone) VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909');
 INSERT INTO users (username, password, email, phone) VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900');

View File

@@ -182,6 +182,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
         self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
         self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
+        self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))
         self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")
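
The new LOCAL_DB_POOL_OVERFLOW setting pairs with LOCAL_DB_POOL_SIZE: together they bound the SQLAlchemy connection pool. A minimal sketch of how these two values feed a pooled engine (the URL is a placeholder; the real one is built in _initialize_db later in this commit):

import os

from sqlalchemy import create_engine

pool_size = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
max_overflow = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))

engine = create_engine(
    "mysql+pymysql://root:aa123456@127.0.0.1:3306/dbgpt",  # placeholder URL
    pool_size=pool_size,        # connections kept open in the pool
    max_overflow=max_overflow,  # extra connections allowed under burst load
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
)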

View File

@@ -2,16 +2,10 @@ from datetime import datetime
 from sqlalchemy import Column, Integer, String, DateTime, func
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model


-class MyPluginEntity(Base):
+class MyPluginEntity(Model):
     __tablename__ = "my_plugin"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class MyPluginEntity(Base):


 class MyPluginDao(BaseDao[MyPluginEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def add(self, engity: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugin = MyPluginEntity(
             tenant=engity.tenant,
             user_code=engity.user_code,
@@ -68,13 +54,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return id

     def update(self, entity: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         updated = session.merge(entity)
         session.commit()
         return updated.id

     def get_by_user(self, user: str) -> list[MyPluginEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         if user:
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -83,7 +69,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result

     def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         if user:
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -93,7 +79,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result

     def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(MyPluginEntity)
         all_count = my_plugins.count()
         if query.id is not None:
@@ -122,7 +108,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return result, total_pages, all_count

     def count(self, query: MyPluginEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         my_plugins = session.query(func.count(MyPluginEntity.id))
         if query.id is not None:
             my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
@@ -143,7 +129,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
         return count

     def delete(self, plugin_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if plugin_id is None:
             raise Exception("plugin_id is None")
         query = MyPluginEntity(id=plugin_id)
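
The same mechanical change repeats in every DAO below: the per-DAO __init__ wiring (database, orm_base, engine, session) disappears, entities extend the new Model base, and get_session() becomes get_raw_session(). Judging by the name and by how the DAO methods above use it, get_raw_session() hands back a plain SQLAlchemy session the caller must commit and close, while the session() context manager used later in AgentHub manages the transaction itself. A sketch of the two styles, with semantics inferred from this diff rather than the BaseDao source:

# MyPluginDao / MyPluginEntity as defined in the file above.
dao = MyPluginDao()

# Raw style: the caller owns commit/close, as in the DAO methods above.
session = dao.get_raw_session()
try:
    plugins = session.query(MyPluginEntity).filter(MyPluginEntity.user_code == "u1").all()
finally:
    session.close()

# Managed style: assumed to commit on success and roll back on error.
with dao.session() as session:
    session.query(MyPluginEntity).filter(MyPluginEntity.name == "demo").delete()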

View File

@@ -3,19 +3,13 @@ import pytz
 from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
 from sqlalchemy import UniqueConstraint

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model

 # TODO We should consider that the production environment does not have permission to execute the DDL
 char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")


-class PluginHubEntity(Base):
+class PluginHubEntity(Model):
     __tablename__ = "plugin_hub"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -43,16 +37,8 @@ class PluginHubEntity(Base):


 class PluginHubDao(BaseDao[PluginHubEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def add(self, engity: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         timezone = pytz.timezone("Asia/Shanghai")
         plugin_hub = PluginHubEntity(
             name=engity.name,
@@ -71,7 +57,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return id

     def update(self, entity: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -82,7 +68,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
     def list(
         self, query: PluginHubEntity, page=1, page_size=20
     ) -> list[PluginHubEntity]:
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         all_count = plugin_hubs.count()
@@ -111,7 +97,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result, total_pages, all_count

     def get_by_storage_url(self, storage_url):
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
         result = plugin_hubs.all()
@@ -119,7 +105,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result

     def get_by_name(self, name: str) -> PluginHubEntity:
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(PluginHubEntity)
         plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
         result = plugin_hubs.first()
@@ -127,7 +113,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return result

     def count(self, query: PluginHubEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         plugin_hubs = session.query(func.count(PluginHubEntity.id))
         if query.id is not None:
             plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
@@ -146,7 +132,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
         return count

     def delete(self, plugin_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if plugin_id is None:
             raise Exception("plugin_id is None")
         plugin_hubs = session.query(PluginHubEntity)

View File

@@ -59,18 +59,12 @@ class AgentHub:
             else:
                 my_plugin_entity.user_code = Default_User
-            with self.hub_dao.get_session() as session:
-                try:
-                    if my_plugin_entity.id is None:
-                        session.add(my_plugin_entity)
-                    else:
-                        session.merge(my_plugin_entity)
-                    session.merge(plugin_entity)
-                    session.commit()
-                    session.close()
-                except Exception as e:
-                    logger.error("install merge roll back!" + str(e))
-                    session.rollback()
+            with self.hub_dao.session() as session:
+                if my_plugin_entity.id is None:
+                    session.add(my_plugin_entity)
+                else:
+                    session.merge(my_plugin_entity)
+                session.merge(plugin_entity)
         except Exception as e:
             logger.error("install pluguin exception!", e)
             raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
@@ -87,19 +81,15 @@ class AgentHub:
         my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
         if plugin_entity is not None:
             plugin_entity.installed = plugin_entity.installed - 1
-        with self.hub_dao.get_session() as session:
-            try:
-                my_plugin_q = session.query(MyPluginEntity).filter(
-                    MyPluginEntity.name == plugin_name
-                )
-                if user:
-                    my_plugin_q.filter(MyPluginEntity.user_code == user)
-                my_plugin_q.delete()
-                if plugin_entity is not None:
-                    session.merge(plugin_entity)
-                session.commit()
-            except:
-                session.rollback()
+        with self.hub_dao.session() as session:
+            my_plugin_q = session.query(MyPluginEntity).filter(
+                MyPluginEntity.name == plugin_name
+            )
+            if user:
+                my_plugin_q.filter(MyPluginEntity.user_code == user)
+            my_plugin_q.delete()
+            if plugin_entity is not None:
+                session.merge(plugin_entity)
         if plugin_entity is not None:
             # delete package file if not use

View File

@@ -1,5 +1,7 @@
+from typing import Optional
 import click
 import os
+import functools
 from dbgpt.app.base import WebServerParameters
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.util.parameter_utils import EnvArgumentParser
@@ -34,3 +36,241 @@ def stop_webserver(port: int):
 def _stop_all_dbgpt_server():
     _stop_service("webserver", "WebServer")


+@click.group("migration")
+def migration():
+    """Manage database migration"""
+    pass
+
+
+def add_migration_options(func):
+    @click.option(
+        "--alembic_ini_path",
+        required=False,
+        type=str,
+        default=None,
+        show_default=True,
+        help="Alembic ini path, if not set, use 'pilot/meta_data/alembic.ini'",
+    )
+    @click.option(
+        "--script_location",
+        required=False,
+        type=str,
+        default=None,
+        show_default=True,
+        help="Alembic script location, if not set, use 'pilot/meta_data/alembic'",
+    )
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-m",
+    "--message",
+    required=False,
+    type=str,
+    default="Init migration",
+    show_default=True,
+    help="The message for create migration repository",
+)
+def init(alembic_ini_path: str, script_location: str, message: str):
+    """Initialize database migration repository"""
+    from dbgpt.util._db_migration_utils import create_migration_script
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    create_migration_script(alembic_cfg, db_manager.engine, message)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-m",
+    "--message",
+    required=False,
+    type=str,
+    default="New migration",
+    show_default=True,
+    help="The message for migration script",
+)
+def migrate(alembic_ini_path: str, script_location: str, message: str):
+    """Create migration script"""
+    from dbgpt.util._db_migration_utils import create_migration_script
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    create_migration_script(alembic_cfg, db_manager.engine, message)
+
+
+@migration.command()
+@add_migration_options
+def upgrade(alembic_ini_path: str, script_location: str):
+    """Upgrade database to target version"""
+    from dbgpt.util._db_migration_utils import upgrade_database
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    upgrade_database(alembic_cfg, db_manager.engine)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "-y",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to downgrade database",
+)
+@click.option(
+    "-r",
+    "--revision",
+    default="-1",
+    show_default=True,
+    help="Revision to downgrade to",
+)
+def downgrade(alembic_ini_path: str, script_location: str, y: bool, revision: str):
+    """Downgrade database to target version"""
+    from dbgpt.util._db_migration_utils import downgrade_database
+
+    if not y:
+        click.confirm("Are you sure you want to downgrade the database?", abort=True)
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    downgrade_database(alembic_cfg, db_manager.engine, revision)
+
+
+@migration.command()
+@add_migration_options
+@click.option(
+    "--drop_all_tables",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Drop all tables",
+)
+@click.option(
+    "-y",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to clean migration data",
+)
+@click.option(
+    "--confirm_drop_all_tables",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm to drop all tables",
+)
+def clean(
+    alembic_ini_path: str,
+    script_location: str,
+    drop_all_tables: bool,
+    y: bool,
+    confirm_drop_all_tables: bool,
+):
+    """Clean Alembic migration scripts and history"""
+    from dbgpt.util._db_migration_utils import clean_alembic_migration
+
+    if not y:
+        click.confirm(
+            "Are you sure clean alembic migration scripts and history?", abort=True
+        )
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    clean_alembic_migration(alembic_cfg, db_manager.engine)
+    if drop_all_tables:
+        if not confirm_drop_all_tables:
+            click.confirm("\nAre you sure drop all tables?", abort=True)
+        with db_manager.engine.connect() as connection:
+            for tbl in reversed(db_manager.Model.metadata.sorted_tables):
+                print(f"Drop table {tbl.name}")
+                connection.execute(tbl.delete())
+
+
+@migration.command()
+@add_migration_options
+def list(alembic_ini_path: str, script_location: str):
+    """List all versions in the migration history, marking the current one"""
+    from alembic.script import ScriptDirectory
+    from alembic.runtime.migration import MigrationContext
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+
+    # Set up Alembic environment and script directory
+    script = ScriptDirectory.from_config(alembic_cfg)
+
+    # Get current revision
+    def get_current_revision():
+        with db_manager.engine.connect() as connection:
+            context = MigrationContext.configure(connection)
+            return context.get_current_revision()
+
+    current_rev = get_current_revision()
+
+    # List all revisions and mark the current one
+    for revision in script.walk_revisions():
+        current_marker = "(current)" if revision.revision == current_rev else ""
+        print(f"{revision.revision} {current_marker}: {revision.doc}")
+
+
+@migration.command()
+@add_migration_options
+@click.argument("revision", required=True)
+def show(alembic_ini_path: str, script_location: str, revision: str):
+    """Show the migration script for a specific version."""
+    from alembic.script import ScriptDirectory
+
+    alembic_cfg, db_manager = _get_migration_config(alembic_ini_path, script_location)
+    script = ScriptDirectory.from_config(alembic_cfg)
+
+    rev = script.get_revision(revision)
+    if rev is None:
+        print(f"Revision {revision} not found.")
+        return
+
+    # Find the migration script file
+    script_files = os.listdir(os.path.join(script.dir, "versions"))
+    script_file = next((f for f in script_files if f.startswith(revision)), None)
+    if script_file is None:
+        print(f"Migration script for revision {revision} not found.")
+        return
+
+    # Print the migration script
+    script_file_path = os.path.join(script.dir, "versions", script_file)
+    print(f"Migration script for revision {revision}: {script_file_path}")
+    try:
+        with open(script_file_path, "r") as file:
+            print(file.read())
+    except FileNotFoundError:
+        print(f"Migration script {script_file_path} not found.")
+
+
+def _get_migration_config(
+    alembic_ini_path: Optional[str] = None, script_location: Optional[str] = None
+):
+    from dbgpt.storage.metadata.db_manager import db as db_manager
+    from dbgpt.util._db_migration_utils import create_alembic_config
+
+    # Must import dbgpt_server for initialize db metadata
+    from dbgpt.app.dbgpt_server import initialize_app as _
+    from dbgpt.app.base import _initialize_db
+
+    # initialize db
+    default_meta_data_path = _initialize_db()
+    alembic_cfg = create_alembic_config(
+        default_meta_data_path,
+        db_manager.engine,
+        db_manager.Model,
+        db_manager.session(),
+        alembic_ini_path,
+        script_location,
+    )
+    return alembic_cfg, db_manager
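
The group gives a small Alembic front end: init and migrate create revision scripts, upgrade and downgrade move the schema, clean resets migration state, and list and show inspect history. A hypothetical in-process smoke test with click's test runner; the dbgpt.app._cli import path is an assumption based on this file's other imports:

from click.testing import CliRunner

from dbgpt.app._cli import migration  # assumed module path

runner = CliRunner()
result = runner.invoke(migration, ["upgrade", "--help"])  # --help avoids touching any database
print(result.output)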

View File

@@ -8,7 +8,8 @@ from dataclasses import dataclass, field
 from dbgpt._private.config import Config
 from dbgpt.component import SystemApp
 from dbgpt.util.parameter_utils import BaseParameters
-from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade
+from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade

 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -36,8 +37,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
     # init config
     cfg = Config()
     cfg.SYSTEM_APP = system_app
-    ddl_init_and_upgrade(param.disable_alembic_upgrade)
+    # Initialize db storage first
+    _initialize_db_storage(param)
     # load_native_plugins(cfg)

     signal.signal(signal.SIGINT, signal_handler)
@@ -83,6 +84,46 @@ def _create_model_start_listener(system_app: SystemApp):
     return startup_event


+def _initialize_db_storage(param: "WebServerParameters"):
+    """Initialize the db storage.
+
+    Now just support sqlite and mysql. If db type is sqlite, the db path is
+    `pilot/meta_data/{db_name}.db`.
+    """
+    default_meta_data_path = _initialize_db(
+        try_to_create_db=not param.disable_alembic_upgrade
+    )
+    _ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
+
+
+def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
+    """Initialize the database
+
+    Now just support sqlite and mysql. If db type is sqlite, the db path is
+    `pilot/meta_data/{db_name}.db`.
+    """
+    from dbgpt.configs.model_config import PILOT_PATH
+    from dbgpt.storage.metadata.db_manager import initialize_db
+    from urllib.parse import quote_plus as urlquote, quote
+
+    CFG = Config()
+    db_name = CFG.LOCAL_DB_NAME
+    default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
+    os.makedirs(default_meta_data_path, exist_ok=True)
+    if CFG.LOCAL_DB_TYPE == "mysql":
+        db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
+    else:
+        sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
+        db_url = f"sqlite:///{sqlite_db_path}"
+    engine_args = {
+        "pool_size": CFG.LOCAL_DB_POOL_SIZE,
+        "max_overflow": CFG.LOCAL_DB_POOL_OVERFLOW,
+        "pool_timeout": 30,
+        "pool_recycle": 3600,
+        "pool_pre_ping": True,
+    }
+    initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
+    return default_meta_data_path
+
+
 @dataclass
 class WebServerParameters(BaseParameters):
     host: Optional[str] = field(
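
For reference, a minimal standalone sketch of the same wiring for the SQLite branch, assuming only the initialize_db signature used above; the paths and pool defaults mirror this file:

import os

from dbgpt.storage.metadata.db_manager import initialize_db

meta_data_path = os.path.join("pilot", "meta_data")
os.makedirs(meta_data_path, exist_ok=True)
db_url = f"sqlite:///{os.path.join(meta_data_path, 'dbgpt.db')}"
engine_args = {
    "pool_size": 10,     # CFG.LOCAL_DB_POOL_SIZE default
    "max_overflow": 20,  # CFG.LOCAL_DB_POOL_OVERFLOW default
    "pool_timeout": 30,
    "pool_recycle": 3600,
    "pool_pre_ping": True,
}
initialize_db(db_url, "dbgpt", engine_args, try_to_create_db=False)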

View File

@@ -13,7 +13,6 @@ from dbgpt.app.base import WebServerParameters
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings

 logger = logging.getLogger(__name__)

 CFG = Config()

View File

@@ -3,19 +3,13 @@ from typing import List
 from sqlalchemy import Column, String, DateTime, Integer, Text, func

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config

 CFG = Config()


-class DocumentChunkEntity(Base):
+class DocumentChunkEntity(Model):
     __tablename__ = "document_chunk"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):


 class DocumentChunkDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_documents_chunks(self, documents: List):
-        session = self.get_session()
+        session = self.get_raw_session()
         docs = [
             DocumentChunkEntity(
                 doc_name=document.doc_name,
@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
     def get_document_chunks(
         self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
     ):
-        session = self.get_session()
+        session = self.get_raw_session()
         document_chunks = session.query(DocumentChunkEntity)
         if query.id is not None:
             document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
         return result

     def get_document_chunks_count(self, query: DocumentChunkEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         document_chunks = session.query(func.count(DocumentChunkEntity.id))
         if query.id is not None:
             document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
         return count

     def delete(self, document_id: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         if document_id is None:
             raise Exception("document_id is None")
         query = DocumentChunkEntity(document_id=document_id)

View File

@@ -2,19 +2,13 @@ from datetime import datetime
 from sqlalchemy import Column, String, DateTime, Integer, Text, func

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config

 CFG = Config()


-class KnowledgeDocumentEntity(Base):
+class KnowledgeDocumentEntity(Model):
     __tablename__ = "knowledge_document"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):


 class KnowledgeDocumentDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_knowledge_document(self, document: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_document = KnowledgeDocumentEntity(
             doc_name=document.doc_name,
             doc_type=document.doc_type,
@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
         return doc_id

     def get_knowledge_documents(self, query, page=1, page_size=20):
-        session = self.get_session()
+        session = self.get_raw_session()
         print(f"current session:{session}")
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:
@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
         return result

     def get_documents(self, query):
-        session = self.get_session()
+        session = self.get_raw_session()
         print(f"current session:{session}")
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:
@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
         return result

     def get_knowledge_documents_count_bulk(self, space_names):
-        session = self.get_session()
+        session = self.get_raw_session()
         """
         Perform a batch query to count the number of documents for each knowledge space.
@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
         return docs_count

     def get_knowledge_documents_count(self, query):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
         if query.id is not None:
             knowledge_documents = knowledge_documents.filter(
@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
         return count

     def update_knowledge_document(self, document: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         updated_space = session.merge(document)
         session.commit()
         return updated_space.id

     #
     def delete(self, query: KnowledgeDocumentEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_documents = session.query(KnowledgeDocumentEntity)
         if query.id is not None:
             knowledge_documents = knowledge_documents.filter(
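
The bulk counter above replaces one count query per space with a single GROUP BY. A sketch of the query shape its docstring describes, using the KnowledgeDocumentEntity model from this file; the helper name is illustrative:

from sqlalchemy import func

def count_documents_by_space(session, space_names):
    # One round trip: count rows per `space` for all requested spaces.
    rows = (
        session.query(
            KnowledgeDocumentEntity.space,
            func.count(KnowledgeDocumentEntity.id),
        )
        .filter(KnowledgeDocumentEntity.space.in_(space_names))
        .group_by(KnowledgeDocumentEntity.space)
        .all()
    )
    return {space: count for space, count in rows}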

View File

@@ -2,20 +2,14 @@ from datetime import datetime
 from sqlalchemy import Column, Integer, Text, String, DateTime

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt._private.config import Config
 from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest

 CFG = Config()


-class KnowledgeSpaceEntity(Base):
+class KnowledgeSpaceEntity(Model):
     __tablename__ = "knowledge_space"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):


 class KnowledgeSpaceDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_knowledge_space(self, space: KnowledgeSpaceRequest):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_space = KnowledgeSpaceEntity(
             name=space.name,
             vector_type=CFG.VECTOR_STORE_TYPE,
@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
         session.close()

     def get_knowledge_space(self, query: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         knowledge_spaces = session.query(KnowledgeSpaceEntity)
         if query.id is not None:
             knowledge_spaces = knowledge_spaces.filter(
@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
         return result

     def update_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         session.merge(space)
         session.commit()
         session.close()
         return True

     def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         if space:
             session.delete(space)
             session.commit()

View File

@@ -2,17 +2,12 @@ from datetime import datetime
 from sqlalchemy import Column, Integer, Text, String, DateTime

-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody


-class ChatFeedBackEntity(Base):
+class ChatFeedBackEntity(Model):
     __tablename__ = "chat_feed_back"
     __table_args__ = {
         "mysql_charset": "utf8mb4",
@@ -39,18 +34,10 @@ class ChatFeedBackEntity(Base):


 class ChatFeedBackDao(BaseDao):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
         # Todo: We need to have user information first.
-        session = self.get_session()
+        session = self.get_raw_session()
         chat_feed_back = ChatFeedBackEntity(
             conv_uid=feed_back.conv_uid,
             conv_index=feed_back.conv_index,
@@ -84,7 +71,7 @@ class ChatFeedBackDao(BaseDao):
         session.close()

     def get_chat_feed_back(self, conv_uid: str, conv_index: int):
-        session = self.get_session()
+        session = self.get_raw_session()
         result = (
             session.query(ChatFeedBackEntity)
             .filter(ChatFeedBackEntity.conv_uid == conv_uid)

View File

@@ -2,13 +2,8 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt._private.config import Config from dbgpt._private.config import Config
from dbgpt.app.prompt.request.request import PromptManageRequest from dbgpt.app.prompt.request.request import PromptManageRequest
@@ -16,7 +11,7 @@ from dbgpt.app.prompt.request.request import PromptManageRequest
CFG = Config() CFG = Config()
class PromptManageEntity(Base): class PromptManageEntity(Model):
__tablename__ = "prompt_manage" __tablename__ = "prompt_manage"
__table_args__ = { __table_args__ = {
"mysql_charset": "utf8mb4", "mysql_charset": "utf8mb4",
@@ -38,16 +33,8 @@ class PromptManageEntity(Base):
class PromptManageDao(BaseDao): class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_prompt(self, prompt: PromptManageRequest): def create_prompt(self, prompt: PromptManageRequest):
session = self.get_session() session = self.get_raw_session()
prompt_manage = PromptManageEntity( prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene, chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene, sub_chat_scene=prompt.sub_chat_scene,
@@ -64,7 +51,7 @@ class PromptManageDao(BaseDao):
session.close() session.close()
def get_prompts(self, query: PromptManageEntity): def get_prompts(self, query: PromptManageEntity):
session = self.get_session() session = self.get_raw_session()
prompts = session.query(PromptManageEntity) prompts = session.query(PromptManageEntity)
if query.chat_scene is not None: if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
@@ -93,13 +80,13 @@ class PromptManageDao(BaseDao):
return result return result
def update_prompt(self, prompt: PromptManageEntity): def update_prompt(self, prompt: PromptManageEntity):
session = self.get_session() session = self.get_raw_session()
session.merge(prompt) session.merge(prompt)
session.commit() session.commit()
session.close() session.close()
def delete_prompt(self, prompt: PromptManageEntity): def delete_prompt(self, prompt: PromptManageEntity):
session = self.get_session() session = self.get_raw_session()
if prompt: if prompt:
session.delete(prompt) session.delete(prompt)
session.commit() session.commit()

View File

@@ -146,7 +146,9 @@ class BaseChat(ABC):
input_values = await self.generate_input_values() input_values = await self.generate_input_values()
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(
self.current_user_input, check_duplicate_type=True
)
self.current_message.start_date = datetime.datetime.now().strftime( self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
) )
@@ -221,7 +223,7 @@ class BaseChat(ABC):
view_msg = self.stream_plugin_call(msg) view_msg = self.stream_plugin_call(msg)
view_msg = view_msg.replace("\n", "\\n") view_msg = view_msg.replace("\n", "\\n")
yield view_msg yield view_msg
self.current_message.add_ai_message(msg) self.current_message.add_ai_message(msg, update_if_exist=True)
view_msg = self.stream_call_reinforce_fn(view_msg) view_msg = self.stream_call_reinforce_fn(view_msg)
self.current_message.add_view_message(view_msg) self.current_message.add_view_message(view_msg)
span.end() span.end()
@@ -257,7 +259,7 @@ class BaseChat(ABC):
) )
) )
### model result deal ### model result deal
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
prompt_define_response = ( prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response( self.prompt_template.output_parser.parse_prompt_response(
ai_response_text ai_response_text
@@ -320,7 +322,7 @@ class BaseChat(ABC):
) )
) )
### model result deal ### model result deal
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text, update_if_exist=True)
prompt_define_response = None prompt_define_response = None
prompt_define_response = ( prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response( self.prompt_template.output_parser.parse_prompt_response(
@@ -596,7 +598,7 @@ def _load_system_message(
prompt_template: PromptTemplate, prompt_template: PromptTemplate,
str_message: bool = True, str_message: bool = True,
): ):
system_convs = current_message.get_system_conv() system_convs = current_message.get_system_messages()
system_text = "" system_text = ""
system_messages = [] system_messages = []
for system_conv in system_convs: for system_conv in system_convs:
@@ -614,7 +616,7 @@ def _load_user_message(
prompt_template: PromptTemplate, prompt_template: PromptTemplate,
str_message: bool = True, str_message: bool = True,
): ):
user_conv = current_message.get_user_conv() user_conv = current_message.get_latest_user_message()
user_messages = [] user_messages = []
if user_conv: if user_conv:
user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
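A hedged sketch of what the two new flags do, based on the OnceConversation changes later in this commit (the texts here are made up):

from dbgpt.core.interface.message import OnceConversation

conv = OnceConversation(chat_mode="chat_normal")
conv.start_new_round()
conv.add_user_message("hi", check_duplicate_type=True)  # first human message: fine
# A second human message with check_duplicate_type=True would raise ValueError.
conv.add_ai_message("partial answer")
conv.add_ai_message("final answer", update_if_exist=True)  # updated in place
assert len(conv.messages) == 2  # one human message, one AI message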

View File

@@ -70,7 +70,9 @@ class ChatHistoryManager:
def _new_chat(self, input_values: Dict) -> List[ModelMessage]: def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self._chat_ctx.current_user_input) self.current_message.add_user_message(
self._chat_ctx.current_user_input, check_duplicate_type=True
)
self.current_message.start_date = datetime.datetime.now().strftime( self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
) )

View File

@@ -51,6 +51,12 @@ def install():
pass pass
@click.group()
def db():
"""Manage your metadata database and your datasources."""
pass
stop_all_func_list = [] stop_all_func_list = []
@@ -64,6 +70,7 @@ def stop_all():
cli.add_command(start) cli.add_command(start)
cli.add_command(stop) cli.add_command(stop)
cli.add_command(install) cli.add_command(install)
cli.add_command(db)
add_command_alias(stop_all, name="all", parent_group=stop) add_command_alias(stop_all, name="all", parent_group=stop)
try: try:
@@ -96,10 +103,13 @@ try:
start_webserver, start_webserver,
stop_webserver, stop_webserver,
_stop_all_dbgpt_server, _stop_all_dbgpt_server,
migration,
) )
add_command_alias(start_webserver, name="webserver", parent_group=start) add_command_alias(start_webserver, name="webserver", parent_group=start)
add_command_alias(stop_webserver, name="webserver", parent_group=stop) add_command_alias(stop_webserver, name="webserver", parent_group=stop)
# Add migration command
add_command_alias(migration, name="migration", parent_group=db)
stop_all_func_list.append(_stop_all_dbgpt_server) stop_all_func_list.append(_stop_all_dbgpt_server)
except ImportError as e: except ImportError as e:
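For context, a self-contained sketch of the click wiring this hunk adds; the migration body is invented here, since the real command is imported from the webserver CLI module:

import click

@click.group()
def cli():
    """Root CLI group (stand-in for the dbgpt cli)."""

@click.group()
def db():
    """Manage your metadata database and your datasources."""

@db.command(name="migration")
def migration():
    """Hypothetical migration command body."""
    click.echo("running migrations...")

cli.add_command(db)

if __name__ == "__main__":
    cli()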

View File

@@ -9,6 +9,10 @@ from dbgpt.core.interface.message import (
ModelMessage, ModelMessage,
ModelMessageRoleType, ModelMessageRoleType,
OnceConversation, OnceConversation,
StorageConversation,
MessageStorageItem,
ConversationIdentifier,
MessageIdentifier,
) )
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
@@ -20,6 +24,16 @@ from dbgpt.core.interface.cache import (
CachePolicy, CachePolicy,
CacheConfig, CacheConfig,
) )
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageItemAdapter,
StorageInterface,
InMemoryStorage,
DefaultStorageItemAdapter,
QuerySpec,
StorageError,
)
__ALL__ = [ __ALL__ = [
"ModelInferenceMetrics", "ModelInferenceMetrics",
@@ -30,6 +44,10 @@ __ALL__ = [
"ModelMessage", "ModelMessage",
"ModelMessageRoleType", "ModelMessageRoleType",
"OnceConversation", "OnceConversation",
"StorageConversation",
"MessageStorageItem",
"ConversationIdentifier",
"MessageIdentifier",
"PromptTemplate", "PromptTemplate",
"PromptTemplateOperator", "PromptTemplateOperator",
"BaseOutputParser", "BaseOutputParser",
@@ -41,4 +59,12 @@ __ALL__ = [
"CacheClient", "CacheClient",
"CachePolicy", "CachePolicy",
"CacheConfig", "CacheConfig",
"ResourceIdentifier",
"StorageItem",
"StorageItemAdapter",
"StorageInterface",
"InMemoryStorage",
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
] ]

View File

@@ -1,7 +1,7 @@
import pytest import pytest
import threading import threading
import asyncio import asyncio
from ..dag import DAG, DAGContext from ..base import DAG, DAGVar
def test_dag_context_sync(): def test_dag_context_sync():
@@ -9,18 +9,18 @@ def test_dag_context_sync():
dag2 = DAG("dag2") dag2 = DAG("dag2")
with dag1: with dag1:
assert DAGContext.get_current_dag() == dag1 assert DAGVar.get_current_dag() == dag1
with dag2: with dag2:
assert DAGContext.get_current_dag() == dag2 assert DAGVar.get_current_dag() == dag2
assert DAGContext.get_current_dag() == dag1 assert DAGVar.get_current_dag() == dag1
assert DAGContext.get_current_dag() is None assert DAGVar.get_current_dag() is None
def test_dag_context_threading(): def test_dag_context_threading():
def thread_function(dag): def thread_function(dag):
DAGContext.enter_dag(dag) DAGVar.enter_dag(dag)
assert DAGContext.get_current_dag() == dag assert DAGVar.get_current_dag() == dag
DAGContext.exit_dag() DAGVar.exit_dag()
dag1 = DAG("dag1") dag1 = DAG("dag1")
dag2 = DAG("dag2") dag2 = DAG("dag2")
@@ -33,19 +33,19 @@ def test_dag_context_threading():
thread1.join() thread1.join()
thread2.join() thread2.join()
assert DAGContext.get_current_dag() is None assert DAGVar.get_current_dag() is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dag_context_async(): async def test_dag_context_async():
async def async_function(dag): async def async_function(dag):
DAGContext.enter_dag(dag) DAGVar.enter_dag(dag)
assert DAGContext.get_current_dag() == dag assert DAGVar.get_current_dag() == dag
DAGContext.exit_dag() DAGVar.exit_dag()
dag1 = DAG("dag1") dag1 = DAG("dag1")
dag2 = DAG("dag2") dag2 = DAG("dag2")
await asyncio.gather(async_function(dag1), async_function(dag2)) await asyncio.gather(async_function(dag1), async_function(dag2))
assert DAGContext.get_current_dag() is None assert DAGVar.get_current_dag() is None

View File

@@ -1,16 +1,26 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union, Optional
from datetime import datetime from datetime import datetime
from dbgpt._private.pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageInterface,
InMemoryStorage,
)
class BaseMessage(BaseModel, ABC): class BaseMessage(BaseModel, ABC):
"""Message object.""" """Message object."""
content: str content: str
index: int = 0
round_index: int = 0
"""The round index of the message in the conversation"""
additional_kwargs: dict = Field(default_factory=dict) additional_kwargs: dict = Field(default_factory=dict)
@property @property
@@ -18,6 +28,24 @@ class BaseMessage(BaseModel, ABC):
def type(self) -> str: def type(self) -> str:
"""Type of the message, used for serialization.""" """Type of the message, used for serialization."""
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model"""
return True
def to_dict(self) -> Dict:
"""Convert to dict
Returns:
Dict: The dict object
"""
return {
"type": self.type,
"data": self.dict(),
"index": self.index,
"round_index": self.round_index,
}
class HumanMessage(BaseMessage): class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human.""" """Type of message that is spoken by the human."""
@@ -51,6 +79,14 @@ class ViewMessage(BaseMessage):
"""Type of the message, used for serialization.""" """Type of the message, used for serialization."""
return "view" return "view"
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
The view message will not be passed to the model
"""
return False
class SystemMessage(BaseMessage): class SystemMessage(BaseMessage):
"""Type of message that is a system message.""" """Type of message that is a system message."""
@@ -141,15 +177,15 @@ class ModelMessage(BaseModel):
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
def _message_to_dict(message: BaseMessage) -> dict: def _message_to_dict(message: BaseMessage) -> Dict:
return {"type": message.type, "data": message.dict()} return message.to_dict()
def _messages_to_dict(messages: List[BaseMessage]) -> List[dict]: def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
return [_message_to_dict(m) for m in messages] return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage: def _message_from_dict(message: Dict) -> BaseMessage:
_type = message["type"] _type = message["type"]
if _type == "human": if _type == "human":
return HumanMessage(**message["data"]) return HumanMessage(**message["data"])
@@ -163,7 +199,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
raise ValueError(f"Got unexpected type: {_type}") raise ValueError(f"Got unexpected type: {_type}")
def _messages_from_dict(messages: List[dict]) -> List[BaseMessage]: def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages] return [_message_from_dict(m) for m in messages]
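The new to_dict shape round-trips through these helpers; a quick sketch with made-up values:

msg = HumanMessage(content="hello", index=3, round_index=1)
d = _message_to_dict(msg)  # {"type": "human", "data": {...}, "index": 3, "round_index": 1}
restored = _message_from_dict(d)
assert isinstance(restored, HumanMessage) and restored.index == 3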
@@ -193,50 +229,119 @@ def _parse_model_messages(
history_messages.append([]) history_messages.append([])
if messages[-1].role != "human": if messages[-1].role != "human":
raise ValueError("Hi! What do you want to talk about") raise ValueError("Hi! What do you want to talk about")
# Keep message pair of [user message, assistant message] # Keep message pairs of [user message, assistant message]
history_messages = list(filter(lambda x: len(x) == 2, history_messages)) history_messages = list(filter(lambda x: len(x) == 2, history_messages))
user_prompt = messages[-1].content user_prompt = messages[-1].content
return user_prompt, system_messages, history_messages return user_prompt, system_messages, history_messages
class OnceConversation: class OnceConversation:
""" """All the information of a conversation, the current single service in memory,
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services can expand cache and database support distributed services.
""" """
def __init__(self, chat_mode, user_name: str = None, sys_code: str = None): def __init__(
self,
chat_mode: str,
user_name: str = None,
sys_code: str = None,
summary: str = None,
**kwargs,
):
self.chat_mode: str = chat_mode self.chat_mode: str = chat_mode
self.messages: List[BaseMessage] = []
self.start_date: str = ""
self.chat_order: int = 0
self.model_name: str = ""
self.param_type: str = ""
self.param_value: str = ""
self.cost: int = 0
self.tokens: int = 0
self.user_name: str = user_name self.user_name: str = user_name
self.sys_code: str = sys_code self.sys_code: str = sys_code
self.summary: str = summary
def add_user_message(self, message: str) -> None: self.messages: List[BaseMessage] = kwargs.get("messages", [])
"""Add a user message to the store""" self.start_date: str = kwargs.get("start_date", "")
has_message = any( # After each complete round of dialogue, the current value will be increased by 1
isinstance(instance, HumanMessage) for instance in self.messages self.chat_order: int = int(kwargs.get("chat_order", 0))
) self.model_name: str = kwargs.get("model_name", "")
if has_message: self.param_type: str = kwargs.get("param_type", "")
raise ValueError("Already Have Human message") self.param_value: str = kwargs.get("param_value", "")
self.messages.append(HumanMessage(content=message)) self.cost: int = int(kwargs.get("cost", 0))
self.tokens: int = int(kwargs.get("tokens", 0))
self._message_index: int = int(kwargs.get("message_index", 0))
def add_ai_message(self, message: str) -> None: def _append_message(self, message: BaseMessage) -> None:
"""Add an AI message to the store""" index = self._message_index
self._message_index += 1
message.index = index
message.round_index = self.chat_order
self.messages.append(message)
def start_new_round(self) -> None:
"""Start a new round of conversation
Example:
>>> conversation = OnceConversation()
>>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0
>>> conversation.start_new_round()
>>> # Now the chat order will be 1
>>> assert conversation.chat_order == 1
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> # Now the chat order will be 1, then we start a new round of conversation
>>> conversation.start_new_round()
>>> # Now the chat order will be 2
>>> assert conversation.chat_order == 2
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> assert conversation.chat_order == 2
"""
self.chat_order += 1
def end_current_round(self) -> None:
"""End the current round of conversation
We do nothing here; the method exists just for the interface
"""
pass
def add_user_message(
self, message: str, check_duplicate_type: Optional[bool] = False
) -> None:
"""Add a user message to the conversation
Args:
message (str): The message content
check_duplicate_type (bool): Whether to check for a duplicate message type
Raises:
ValueError: If the message is duplicate and check_duplicate_type is True
"""
if check_duplicate_type:
has_message = any(
isinstance(instance, HumanMessage) for instance in self.messages
)
if has_message:
raise ValueError("Already Have Human message")
self._append_message(HumanMessage(content=message))
def add_ai_message(
self, message: str, update_if_exist: Optional[bool] = False
) -> None:
"""Add an AI message to the conversation
Args:
message (str): The message content
update_if_exist (bool): Whether to update the existing AI message in place instead of appending a new one
"""
if not update_if_exist:
self._append_message(AIMessage(content=message))
return
has_message = any(isinstance(instance, AIMessage) for instance in self.messages) has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message: if has_message:
self.__update_ai_message(message) self._update_ai_message(message)
else: else:
self.messages.append(AIMessage(content=message)) self._append_message(AIMessage(content=message))
""" """
def __update_ai_message(self, new_message: str) -> None: def _update_ai_message(self, new_message: str) -> None:
""" """
stream out message update stream out message update
Args: Args:
@@ -252,13 +357,11 @@ class OnceConversation:
def add_view_message(self, message: str) -> None: def add_view_message(self, message: str) -> None:
"""Add an AI message to the store""" """Add an AI message to the store"""
self._append_message(ViewMessage(content=message))
self.messages.append(ViewMessage(content=message))
""" """
def add_system_message(self, message: str) -> None: def add_system_message(self, message: str) -> None:
"""Add an AI message to the store""" """Add a system message to the store"""
self.messages.append(SystemMessage(content=message)) self._append_message(SystemMessage(content=message))
def set_start_time(self, datatime: datetime): def set_start_time(self, datatime: datetime):
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S") dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
@@ -267,23 +370,369 @@ class OnceConversation:
def clear(self) -> None: def clear(self) -> None:
"""Remove all messages from the store""" """Remove all messages from the store"""
self.messages.clear() self.messages.clear()
self.session_id = None
def get_user_conv(self): def get_latest_user_message(self) -> Optional[HumanMessage]:
for message in self.messages: """Get the latest user message"""
for message in self.messages[::-1]:
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
return message return message
return None return None
def get_system_conv(self): def get_system_messages(self) -> List[SystemMessage]:
system_convs = [] """Get all system messages"""
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
def _to_dict(self) -> Dict:
return _conversation_to_dict(self)
def from_conversation(self, conversation: OnceConversation) -> None:
"""Load the conversation from the storage"""
self.chat_mode = conversation.chat_mode
self.messages = conversation.messages
self.start_date = conversation.start_date
self.chat_order = conversation.chat_order
self.model_name = conversation.model_name
self.param_type = conversation.param_type
self.param_value = conversation.param_value
self.cost = conversation.cost
self.tokens = conversation.tokens
self.user_name = conversation.user_name
self.sys_code = conversation.sys_code
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
Args:
round_index (int): The round index
Returns:
List[BaseMessage]: The messages
"""
return list(filter(lambda x: x.round_index == round_index, self.messages))
def get_latest_round(self) -> List[BaseMessage]:
"""Get the latest round messages
Returns:
List[BaseMessage]: The messages
"""
return self.get_messages_by_round(self.chat_order)
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
"""Get the messages with round count
If the round count is 1, the history messages will not be included.
Example:
.. code-block:: python
conversation = OnceConversation()
conversation.start_new_round()
conversation.add_user_message("hello, this is the first round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is the second round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is the third round")
conversation.add_ai_message("hi")
conversation.end_current_round()
assert len(conversation.get_messages_with_round(1)) == 2
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
assert conversation.get_messages_with_round(1)[1].content == "hi"
assert len(conversation.get_messages_with_round(2)) == 4
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
assert conversation.get_messages_with_round(2)[1].content == "hi"
Args:
round_count (int): The round count
Returns:
List[BaseMessage]: The messages
"""
latest_round_index = self.chat_order
start_round_index = max(1, latest_round_index - round_count + 1)
messages = []
for round_index in range(start_round_index, latest_round_index + 1):
messages.extend(self.get_messages_by_round(round_index))
return messages
def get_model_messages(self) -> List[ModelMessage]:
"""Get the model messages
Model messages just include human, ai and system messages.
Model messages may include the history messages. The order of the messages is the same as the order of
the messages in the conversation, and the last message is the latest one.
If you want to handle the messages with your own logic, you can override this method.
Examples:
If you do not need the history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
for message in self.get_latest_round():
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
)
return messages
If you want to include only one round of history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
latest_round_index = self.chat_order
round_count = 1
start_round_index = max(1, latest_round_index - round_count + 1)
for round_index in range(start_round_index, latest_round_index + 1):
for message in self.get_messages_by_round(round_index):
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
)
return messages
Returns:
List[ModelMessage]: The model messages
"""
messages = []
for message in self.messages: for message in self.messages:
if isinstance(message, SystemMessage): if message.pass_to_model:
system_convs.append(message) messages.append(
return system_convs ModelMessage(role=message.type, content=message.content)
)
return messages
def _conversation_to_dict(once: OnceConversation) -> dict: class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
self.conv_uid = conv_uid
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}:{self.conv_uid}"
def to_dict(self) -> Dict:
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
class MessageIdentifier(ResourceIdentifier):
"""Message identifier"""
identifier_split = "___"
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
self.conv_uid = conv_uid
self.index = index
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
@staticmethod
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
"""Convert from str identifier
Args:
str_identifier (str): The str identifier
Returns:
MessageIdentifier: The message identifier
"""
parts = str_identifier.split(MessageIdentifier.identifier_split)
if len(parts) != 3:
raise ValueError(f"Invalid str identifier: {str_identifier}")
return MessageIdentifier(parts[1], int(parts[2]))
def to_dict(self) -> Dict:
return {
"conv_uid": self.conv_uid,
"index": self.index,
"identifier_type": self.identifier_type,
}
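A quick round-trip check of the identifier format defined above (values made up):

mid = MessageIdentifier("conv1", 7)
assert mid.str_identifier == "message___conv1___7"
restored = MessageIdentifier.from_str_identifier(mid.str_identifier)
assert restored.conv_uid == "conv1" and restored.index == 7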
class MessageStorageItem(StorageItem):
@property
def identifier(self) -> MessageIdentifier:
return self._id
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
self.conv_uid = conv_uid
self.index = index
self.message_detail = message_detail
self._id = MessageIdentifier(conv_uid, index)
def to_dict(self) -> Dict:
return {
"conv_uid": self.conv_uid,
"index": self.index,
"message_detail": self.message_detail,
}
def to_message(self) -> BaseMessage:
"""Convert to message object
Returns:
BaseMessage: The message object
Raises:
ValueError: If the message type is not supported
"""
return _message_from_dict(self.message_detail)
def merge(self, other: "StorageItem") -> None:
"""Merge the other message to self
Args:
other (StorageItem): The other message
"""
if not isinstance(other, MessageStorageItem):
raise ValueError(f"Can not merge {other} to {self}")
self.message_detail = other.message_detail
class StorageConversation(OnceConversation, StorageItem):
"""All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
@property
def identifier(self) -> ConversationIdentifier:
return self._id
def to_dict(self) -> Dict:
dict_data = self._to_dict()
messages: Dict = dict_data.pop("messages")
message_ids = []
index = 0
for message in messages:
if "index" in message:
message_idx = message["index"]
else:
message_idx = index
index += 1
message_ids.append(
MessageIdentifier(self.conv_uid, message_idx).str_identifier
)
# Replace message with message ids
dict_data["conv_uid"] = self.conv_uid
dict_data["message_ids"] = message_ids
dict_data["save_message_independent"] = self.save_message_independent
return dict_data
def merge(self, other: "StorageItem") -> None:
"""Merge the other conversation to self
Args:
other (StorageItem): The other conversation
"""
if not isinstance(other, StorageConversation):
raise ValueError(f"Can not merge {other} to {self}")
self.from_conversation(other)
def __init__(
self,
conv_uid: str,
chat_mode: str = None,
user_name: str = None,
sys_code: str = None,
message_ids: List[str] = None,
summary: str = None,
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
**kwargs,
):
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
self.conv_uid = conv_uid
self._message_ids = message_ids
self.save_message_independent = save_message_independent
self._id = ConversationIdentifier(conv_uid)
if conv_storage is None:
conv_storage = InMemoryStorage()
if message_storage is None:
message_storage = InMemoryStorage()
self.conv_storage = conv_storage
self.message_storage = message_storage
# Load from storage
self.load_from_storage(self.conv_storage, self.message_storage)
@property
def message_ids(self) -> List[str]:
"""Get the message ids
Returns:
List[str]: The message ids
"""
return self._message_ids if self._message_ids else []
def end_current_round(self) -> None:
"""End the current round of conversation
Save the conversation to the storage after a round of conversation
"""
self.save_to_storage()
def _get_message_items(self) -> List[MessageStorageItem]:
return [
MessageStorageItem(self.conv_uid, message.index, message.to_dict())
for message in self.messages
]
def save_to_storage(self) -> None:
"""Save the conversation to the storage"""
# Save messages first
message_list = self._get_message_items()
self._message_ids = [
message.identifier.str_identifier for message in message_list
]
self.message_storage.save_list(message_list)
# Save conversation
self.conv_storage.save_or_update(self)
def load_from_storage(
self, conv_storage: StorageInterface, message_storage: StorageInterface
) -> None:
"""Load the conversation from the storage
Warning: This will overwrite the current conversation.
Args:
conv_storage (StorageInterface): The storage to load the conversation from
message_storage (StorageInterface): The storage to load the messages from
"""
# Load conversation first
conversation: StorageConversation = conv_storage.load(
self._id, StorageConversation
)
if conversation is None:
return
message_ids = conversation._message_ids or []
# Load messages
message_list = message_storage.load_list(
[
MessageIdentifier.from_str_identifier(message_id)
for message_id in message_ids
],
MessageStorageItem,
)
messages = [message.to_message() for message in message_list]
conversation.messages = messages
self._message_ids = message_ids
self.from_conversation(conversation)
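Putting the pieces together, a minimal usage sketch with in-memory backends (uid and texts invented; this mirrors the tests added later in this commit):

storage = InMemoryStorage()
conv = StorageConversation(
    "conv-42",
    chat_mode="chat_normal",
    conv_storage=storage,
    message_storage=storage,
)
conv.start_new_round()
conv.add_user_message("hello")
conv.add_ai_message("hi")
conv.end_current_round()  # persists the conversation and its messages

reloaded = StorageConversation(
    "conv-42", conv_storage=storage, message_storage=storage
)
assert [m.content for m in reloaded.messages] == ["hello", "hi"]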
def _conversation_to_dict(once: OnceConversation) -> Dict:
start_str: str = "" start_str: str = ""
if hasattr(once, "start_date") and once.start_date: if hasattr(once, "start_date") and once.start_date:
if isinstance(once.start_date, datetime): if isinstance(once.start_date, datetime):
@@ -303,6 +752,7 @@ def _conversation_to_dict(once: OnceConversation) -> dict:
"param_value": once.param_value, "param_value": once.param_value,
"user_name": once.user_name, "user_name": once.user_name,
"sys_code": once.sys_code, "sys_code": once.sys_code,
"summary": once.summary if once.summary else "",
} }

View File

@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}""" f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
) )
def __illegal_json_ends(self, s): def _illegal_json_ends(self, s):
temp_json = s temp_json = s
illegal_json_ends_1 = [", }", ",}"] illegal_json_ends_1 = [", }", ",}"]
illegal_json_ends_2 = ", ]", ",]" illegal_json_ends_2 = ", ]", ",]"
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
temp_json = temp_json.replace(illegal_json_end, " ]") temp_json = temp_json.replace(illegal_json_end, " ]")
return temp_json return temp_json
def __extract_json(self, s): def _extract_json(self, s):
try: try:
# Get the dual-mode analysis first and get the maximum result # Get the dual-mode analysis first and get the maximum result
temp_json_simple = self.__json_interception(s) temp_json_simple = self._json_interception(s)
temp_json_array = self.__json_interception(s, True) temp_json_array = self._json_interception(s, True)
if len(temp_json_simple) > len(temp_json_array): if len(temp_json_simple) > len(temp_json_array):
temp_json = temp_json_simple temp_json = temp_json_simple
else: else:
temp_json = temp_json_array temp_json = temp_json_array
if not temp_json: if not temp_json:
temp_json = self.__json_interception(s) temp_json = self._json_interception(s)
temp_json = self.__illegal_json_ends(temp_json) temp_json = self._illegal_json_ends(temp_json)
return temp_json return temp_json
except Exception as e: except Exception as e:
raise ValueError("Failed to find a valid json in LLM response" + temp_json) raise ValueError("Failed to find a valid json in LLM response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False): def _json_interception(self, s, is_json_array: bool = False):
try: try:
if is_json_array: if is_json_array:
i = s.find("[") i = s.find("[")
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
cleaned_output = cleaned_output.strip() cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing:\n" + cleaned_output) logger.info("illegal json processing:\n" + cleaned_output)
cleaned_output = self.__extract_json(cleaned_output) cleaned_output = self._extract_json(cleaned_output)
if not cleaned_output or len(cleaned_output) <= 0: if not cleaned_output or len(cleaned_output) <= 0:
return model_out_text return model_out_text
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
.replace("\\", " ") .replace("\\", " ")
.replace("\_", "_") .replace("\_", "_")
) )
cleaned_output = self.__illegal_json_ends(cleaned_output) cleaned_output = self._illegal_json_ends(cleaned_output)
return cleaned_output return cleaned_output
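The double-underscore to single-underscore renames in this hunk are not cosmetic: a name like __extract_json is mangled to _BaseOutputParser__extract_json, which blocks overriding in subclasses. A tiny illustration of the mangling:

class Base:
    def __hidden(self):  # stored as _Base__hidden
        return "base"

class Child(Base):
    def __hidden(self):  # stored as _Child__hidden; does NOT override
        return "child"

c = Child()
assert c._Base__hidden() == "base"
assert c._Child__hidden() == "child"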
def parse_view_response( def parse_view_response(
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
"""Instructions on how the LLM output should be formatted.""" """Instructions on how the LLM output should be formatted."""
raise NotImplementedError raise NotImplementedError
# @property
# def _type(self) -> str:
# """Return the type key."""
# raise NotImplementedError(
# f"_type property is not implemented in class {self.__class__.__name__}."
# " This is required for serialization."
# )
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
async def map(self, input_value: ModelOutput) -> Any: async def map(self, input_value: ModelOutput) -> Any:
"""Parse the output of an LLM call. """Parse the output of an LLM call.

View File

@@ -1,19 +1,34 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Type, Dict from typing import Type, Dict
class Serializable(ABC): class Serializable(ABC):
serializer: "Serializer" = None
@abstractmethod @abstractmethod
def to_dict(self) -> Dict:
"""Convert the object's state to a dictionary."""
def serialize(self) -> bytes: def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission. """Convert the object into bytes for storage or transmission.
Returns: Returns:
bytes: The byte array after serialization bytes: The byte array after serialization
""" """
if self.serializer is None:
raise ValueError(
"Serializer is not set. Please set the serializer before serialization."
)
return self.serializer.serialize(self)
@abstractmethod def set_serializer(self, serializer: "Serializer") -> None:
def to_dict(self) -> Dict: """Set the serializer for current serializable object.
"""Convert the object's state to a dictionary."""
Args:
serializer (Serializer): The serializer to set
"""
self.serializer = serializer
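A minimal sketch of the new contract, assuming JsonSerializer serializes via to_dict; the Point class is invented for illustration:

from typing import Dict
from dbgpt.core.interface.serialization import Serializable
from dbgpt.util.serialization.json_serialization import JsonSerializer

class Point(Serializable):
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

    def to_dict(self) -> Dict:
        return {"x": self.x, "y": self.y}

p = Point(1, 2)
p.set_serializer(JsonSerializer())
raw = p.serialize()  # bytes, ready for storage or transmission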
class Serializer(ABC): class Serializer(ABC):

View File

@@ -0,0 +1,409 @@
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
from abc import ABC, abstractmethod
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
"""The resource identifier interface for resource identifiers."""
@property
@abstractmethod
def str_identifier(self) -> str:
"""Get the string identifier of the resource.
The string identifier is used to uniquely identify the resource.
Returns:
str: The string identifier of the resource
"""
def __hash__(self) -> int:
"""Return the hash value of the key."""
return hash(self.str_identifier)
def __eq__(self, other: Any) -> bool:
"""Check equality with another key."""
if not isinstance(other, ResourceIdentifier):
return False
return self.str_identifier == other.str_identifier
@PublicAPI(stability="beta")
class StorageItem(Serializable, ABC):
"""The storage item interface for storage items."""
@property
@abstractmethod
def identifier(self) -> ResourceIdentifier:
"""Get the resource identifier of the storage item.
Returns:
ResourceIdentifier: The resource identifier of the storage item
"""
@abstractmethod
def merge(self, other: "StorageItem") -> None:
"""Merge the other storage item into the current storage item.
Args:
other (StorageItem): The other storage item
"""
T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")
class StorageItemAdapter(Generic[T, TDataRepresentation]):
"""The storage item adapter for converting storage items to and from the storage format.
Sometimes, the storage item is not the same as the storage format,
so we need to convert the storage item to the storage format and vice versa.
In database storage, the storage format is a database model, while the StorageItem is the user-defined object.
"""
@abstractmethod
def to_storage_format(self, item: T) -> TDataRepresentation:
"""Convert the storage item to the storage format.
Args:
item (T): The storage item
Returns:
TDataRepresentation: The data in the storage format
"""
@abstractmethod
def from_storage_format(self, data: TDataRepresentation) -> T:
"""Convert the storage format to the storage item.
Args:
data (TDataRepresentation): The data in the storage format
Returns:
T: The storage item
"""
@abstractmethod
def get_query_for_identifier(
self,
storage_format: Type[TDataRepresentation],
resource_id: ResourceIdentifier,
**kwargs,
) -> Any:
"""Get the query for the resource identifier.
Args:
storage_format (Type[TDataRepresentation]): The storage format
resource_id (ResourceIdentifier): The resource identifier
kwargs: The additional arguments
Returns:
Any: The query for the resource identifier
"""
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
"""The default storage item adapter for converting storage items to and from the storage format.
The storage item is the same as the storage format, so no conversion is required.
"""
def to_storage_format(self, item: T) -> T:
return item
def from_storage_format(self, data: T) -> T:
return data
def get_query_for_identifier(
self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
) -> bool:
return True
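When the stored representation differs from the item, a custom adapter does the mapping. A hedged sketch that stores the MessageStorageItem defined earlier in this commit as a plain dict; the query helper returns a filter-style dict, since the real query shape depends on the backend:

from typing import Any, Dict, Type
from dbgpt.core.interface.message import MessageStorageItem
from dbgpt.core.interface.storage import StorageItemAdapter, ResourceIdentifier

class DictAdapter(StorageItemAdapter[MessageStorageItem, Dict]):
    def to_storage_format(self, item: MessageStorageItem) -> Dict:
        return item.to_dict()

    def from_storage_format(self, data: Dict) -> MessageStorageItem:
        return MessageStorageItem(
            data["conv_uid"], data["index"], data["message_detail"]
        )

    def get_query_for_identifier(
        self, storage_format: Type[Dict], resource_id: ResourceIdentifier, **kwargs
    ) -> Any:
        return {"str_identifier": resource_id.str_identifier}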
@PublicAPI(stability="beta")
class StorageError(Exception):
"""The base exception class for storage errors."""
def __init__(self, message: str):
super().__init__(message)
@PublicAPI(stability="beta")
class QuerySpec:
"""The query specification for querying data from the storage.
Attributes:
conditions (Dict[str, Any]): The conditions for querying data
limit (int): The maximum number of items to return
offset (int): The offset of the first item to return
"""
def __init__(
self, conditions: Dict[str, Any], limit: int = None, offset: int = 0
) -> None:
self.conditions = conditions
self.limit = limit
self.offset = offset
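For example, the first ten items whose user_name attribute equals "user1"; as InMemoryStorage below shows, condition keys are matched against attributes of the stored items:

spec = QuerySpec(conditions={"user_name": "user1"}, limit=10, offset=0)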
@PublicAPI(stability="beta")
class StorageInterface(Generic[T, TDataRepresentation], ABC):
"""The storage interface for storing and loading data."""
def __init__(
self,
serializer: Optional[Serializer] = None,
adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
):
self._serializer = serializer or JsonSerializer()
self._storage_item_adapter = adapter or DefaultStorageItemAdapter()
@property
def serializer(self) -> Serializer:
"""Get the serializer of the storage.
Returns:
Serializer: The serializer of the storage
"""
return self._serializer
@property
def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
"""Get the adapter of the storage.
Returns:
StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
"""
return self._storage_item_adapter
@abstractmethod
def save(self, data: T) -> None:
"""Save the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If the data already exists in the storage or data is None
"""
@abstractmethod
def update(self, data: T) -> None:
"""Update the data to the storage.
Args:
data (T): The data to update
Raises:
StorageError: If data is None
"""
@abstractmethod
def save_or_update(self, data: T) -> None:
"""Save or update the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If data is None
"""
def save_list(self, data: List[T]) -> None:
"""Save the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If the data already exists in the storage or data is None
"""
for d in data:
self.save(d)
def save_or_update_list(self, data: List[T]) -> None:
"""Save or update the data to the storage.
Args:
data (T): The data to save
"""
for d in data:
self.save_or_update(d)
@abstractmethod
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
Loading data by resource_id is faster than querying with conditions,
so we suggest using load where possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
Optional[T]: The loaded data
"""
def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
Load data with resource_id will be faster than query data with conditions,
so we suggest to use load if possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
Optional[T]: The loaded data
"""
result = []
for r in resource_id:
item = self.load(r, cls)
if item is not None:
result.append(item)
return result
@abstractmethod
def delete(self, resource_id: ResourceIdentifier) -> None:
"""Delete the data from the storage.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
"""
@abstractmethod
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Querying data by resource_id is faster than querying with conditions, so please use load where possible.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
List[T]: The queried data
"""
@abstractmethod
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
"""Count the number of data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
int: The number of matching items
"""
def paginate_query(
self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
) -> PaginationResult[T]:
"""Paginate the query result.
Args:
page (int): The page number
page_size (int): The number of items per page
cls (Type[T]): The type of the data
spec (Optional[QuerySpec], optional): The query specification. Defaults to None.
Returns:
PaginationResult[T]: The pagination result
"""
if spec is None:
spec = QuerySpec(conditions={})
spec.limit = page_size
spec.offset = (page - 1) * page_size
items = self.query(spec, cls)
total = self.count(spec, cls)
return PaginationResult(
items=items,
total_count=total,
total_pages=(total + page_size - 1) // page_size,
page=page,
page_size=page_size,
)
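total_pages is ceiling division; for instance, 45 matching items at 20 per page yield 3 pages:

assert (45 + 20 - 1) // 20 == 3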
@PublicAPI(stability="alpha")
class InMemoryStorage(StorageInterface[T, T]):
"""The in-memory storage for storing and loading data."""
def __init__(
self,
serializer: Optional[Serializer] = None,
):
super().__init__(serializer)
self._data = {}  # Key: str identifier of the resource, Value: serialized data
def save(self, data: T) -> None:
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
data.set_serializer(self.serializer)
if data.identifier.str_identifier in self._data:
raise StorageError(
f"Data with identifier {data.identifier.str_identifier} already exists"
)
self._data[data.identifier.str_identifier] = data.serialize()
def update(self, data: T) -> None:
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
data.set_serializer(self.serializer)
self._data[data.identifier.str_identifier] = data.serialize()
def save_or_update(self, data: T) -> None:
self.update(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
serialized_data = self._data.get(resource_id.str_identifier)
if serialized_data is None:
return None
return self.serializer.deserialize(serialized_data, cls)
def delete(self, resource_id: ResourceIdentifier) -> None:
if resource_id.str_identifier in self._data:
del self._data[resource_id.str_identifier]
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
result = []
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
if all(
getattr(data, key) == value for key, value in spec.conditions.items()
):
result.append(data)
# Apply limit and offset
if spec.limit is not None:
result = result[spec.offset : spec.offset + spec.limit]
else:
result = result[spec.offset :]
return result
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
count = 0
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
if all(
getattr(data, key) == value for key, value in spec.conditions.items()
):
count += 1
return count

View File

View File

@@ -0,0 +1,14 @@
import pytest
from dbgpt.core.interface.storage import InMemoryStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def in_memory_storage(serializer):
return InMemoryStorage(serializer)

View File

@@ -0,0 +1,307 @@
import pytest
from dbgpt.core.interface.tests.conftest import in_memory_storage
from dbgpt.core.interface.message import *
@pytest.fixture
def basic_conversation():
return OnceConversation(chat_mode="chat_normal", user_name="user1", sys_code="sys1")
@pytest.fixture
def human_message():
return HumanMessage(content="Hello")
@pytest.fixture
def ai_message():
return AIMessage(content="Hi there")
@pytest.fixture
def system_message():
return SystemMessage(content="System update")
@pytest.fixture
def view_message():
return ViewMessage(content="View this")
@pytest.fixture
def conversation_identifier():
return ConversationIdentifier("conv1")
@pytest.fixture
def message_identifier():
return MessageIdentifier("conv1", 1)
@pytest.fixture
def message_storage_item():
message = HumanMessage(content="Hello", index=1)
message_detail = message.to_dict()
return MessageStorageItem("conv1", 1, message_detail)
@pytest.fixture
def storage_conversation():
return StorageConversation("conv1", chat_mode="chat_normal", user_name="user1")
@pytest.fixture
def conversation_with_messages():
conv = OnceConversation(chat_mode="chat_normal", user_name="user1")
conv.start_new_round()
conv.add_user_message("Hello")
conv.add_ai_message("Hi")
conv.end_current_round()
conv.start_new_round()
conv.add_user_message("How are you?")
conv.add_ai_message("I'm good, thanks")
conv.end_current_round()
return conv
def test_init(basic_conversation):
assert basic_conversation.chat_mode == "chat_normal"
assert basic_conversation.user_name == "user1"
assert basic_conversation.sys_code == "sys1"
assert basic_conversation.messages == []
assert basic_conversation.start_date == ""
assert basic_conversation.chat_order == 0
assert basic_conversation.model_name == ""
assert basic_conversation.param_type == ""
assert basic_conversation.param_value == ""
assert basic_conversation.cost == 0
assert basic_conversation.tokens == 0
assert basic_conversation._message_index == 0
def test_add_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], HumanMessage)
def test_add_ai_message(basic_conversation, ai_message):
basic_conversation.add_ai_message(ai_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], AIMessage)
def test_add_system_message(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], SystemMessage)
def test_add_view_message(basic_conversation, view_message):
basic_conversation.add_view_message(view_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], ViewMessage)
def test_set_start_time(basic_conversation):
now = datetime.now()
basic_conversation.set_start_time(now)
assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S")
def test_clear_messages(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
basic_conversation.clear()
assert len(basic_conversation.messages) == 0
def test_get_latest_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
latest_message = basic_conversation.get_latest_user_message()
assert latest_message == human_message
def test_get_system_messages(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
system_messages = basic_conversation.get_system_messages()
assert len(system_messages) == 1
assert system_messages[0] == system_message
def test_from_conversation(basic_conversation):
new_conversation = OnceConversation(chat_mode="chat_advanced", user_name="user2")
basic_conversation.from_conversation(new_conversation)
assert basic_conversation.chat_mode == "chat_advanced"
assert basic_conversation.user_name == "user2"
def test_get_messages_by_round(conversation_with_messages):
# Test first round
round1_messages = conversation_with_messages.get_messages_by_round(1)
assert len(round1_messages) == 2
assert round1_messages[0].content == "Hello"
assert round1_messages[1].content == "Hi"
# Test a non-existing round
no_messages = conversation_with_messages.get_messages_by_round(3)
assert len(no_messages) == 0
def test_get_latest_round(conversation_with_messages):
latest_round_messages = conversation_with_messages.get_latest_round()
assert len(latest_round_messages) == 2
assert latest_round_messages[0].content == "How are you?"
assert latest_round_messages[1].content == "I'm good, thanks"
def test_get_messages_with_round(conversation_with_messages):
# Test last round
last_round_messages = conversation_with_messages.get_messages_with_round(1)
assert len(last_round_messages) == 2
assert last_round_messages[0].content == "How are you?"
assert last_round_messages[1].content == "I'm good, thanks"
# Test last two rounds
last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2)
assert len(last_two_rounds_messages) == 4
assert last_two_rounds_messages[0].content == "Hello"
assert last_two_rounds_messages[1].content == "Hi"
def test_get_model_messages(conversation_with_messages):
model_messages = conversation_with_messages.get_model_messages()
assert len(model_messages) == 4
assert all(isinstance(msg, ModelMessage) for msg in model_messages)
assert model_messages[0].content == "Hello"
assert model_messages[1].content == "Hi"
assert model_messages[2].content == "How are you?"
assert model_messages[3].content == "I'm good, thanks"
def test_conversation_identifier(conversation_identifier):
assert conversation_identifier.conv_uid == "conv1"
assert conversation_identifier.identifier_type == "conversation"
assert conversation_identifier.str_identifier == "conversation:conv1"
assert conversation_identifier.to_dict() == {
"conv_uid": "conv1",
"identifier_type": "conversation",
}
def test_message_identifier(message_identifier):
assert message_identifier.conv_uid == "conv1"
assert message_identifier.index == 1
assert message_identifier.identifier_type == "message"
assert message_identifier.str_identifier == "message___conv1___1"
assert message_identifier.to_dict() == {
"conv_uid": "conv1",
"index": 1,
"identifier_type": "message",
}
def test_message_storage_item(message_storage_item):
assert message_storage_item.conv_uid == "conv1"
assert message_storage_item.index == 1
assert message_storage_item.message_detail == {
"type": "human",
"data": {
"content": "Hello",
"index": 1,
"round_index": 0,
"additional_kwargs": {},
"example": False,
},
"index": 1,
"round_index": 0,
}
assert isinstance(message_storage_item.identifier, MessageIdentifier)
assert message_storage_item.to_dict() == {
"conv_uid": "conv1",
"index": 1,
"message_detail": {
"type": "human",
"index": 1,
"data": {
"content": "Hello",
"index": 1,
"round_index": 0,
"additional_kwargs": {},
"example": False,
},
"round_index": 0,
},
}
assert isinstance(message_storage_item.to_message(), BaseMessage)
def test_storage_conversation_init(storage_conversation):
assert storage_conversation.conv_uid == "conv1"
assert storage_conversation.chat_mode == "chat_normal"
assert storage_conversation.user_name == "user1"
def test_storage_conversation_add_user_message(storage_conversation):
storage_conversation.add_user_message("Hi")
assert len(storage_conversation.messages) == 1
assert isinstance(storage_conversation.messages[0], HumanMessage)
def test_storage_conversation_add_ai_message(storage_conversation):
storage_conversation.add_ai_message("Hello")
assert len(storage_conversation.messages) == 1
assert isinstance(storage_conversation.messages[0], AIMessage)
def test_save_to_storage(storage_conversation, in_memory_storage):
# Set storage
storage_conversation.conv_storage = in_memory_storage
storage_conversation.message_storage = in_memory_storage
# Add messages
storage_conversation.add_user_message("User message")
storage_conversation.add_ai_message("AI response")
# Save to storage
storage_conversation.save_to_storage()
# Create a new StorageConversation instance to load the data
saved_conversation = StorageConversation(
storage_conversation.conv_uid,
conv_storage=in_memory_storage,
message_storage=in_memory_storage,
)
assert saved_conversation.conv_uid == storage_conversation.conv_uid
assert len(saved_conversation.messages) == 2
assert isinstance(saved_conversation.messages[0], HumanMessage)
assert isinstance(saved_conversation.messages[1], AIMessage)
def test_load_from_storage(storage_conversation, in_memory_storage):
# Set storage
storage_conversation.conv_storage = in_memory_storage
storage_conversation.message_storage = in_memory_storage
# Add messages and save to storage
storage_conversation.add_user_message("User message")
storage_conversation.add_ai_message("AI response")
storage_conversation.save_to_storage()
# Create a new StorageConversation instance to load the data
new_conversation = StorageConversation(
"conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
)
# Check if the data is loaded correctly
assert new_conversation.conv_uid == storage_conversation.conv_uid
assert len(new_conversation.messages) == 2
assert new_conversation.messages[0].content == "User message"
assert new_conversation.messages[1].content == "AI response"
assert isinstance(new_conversation.messages[0], HumanMessage)
assert isinstance(new_conversation.messages[1], AIMessage)
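
Note: taken together, the two tests above pin down the round-trip contract of the refactored message storage. For reviewers, the same flow as a minimal standalone sketch (using only classes that appear in these tests):

# A minimal sketch of the save/load round trip exercised above.
from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage
from dbgpt.core.interface.storage import InMemoryStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer

storage = InMemoryStorage(JsonSerializer())
conv = StorageConversation(
    "conv1",
    chat_mode="chat_normal",
    user_name="user1",
    conv_storage=storage,
    message_storage=storage,
)
conv.add_user_message("User message")
conv.add_ai_message("AI response")
conv.save_to_storage()

# Re-loading by conv_uid rebuilds the message list from storage.
loaded = StorageConversation("conv1", conv_storage=storage, message_storage=storage)
assert isinstance(loaded.messages[0], HumanMessage)
assert isinstance(loaded.messages[1], AIMessage)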

View File

@@ -0,0 +1,129 @@
import pytest
from typing import Dict, Type, Union
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageError,
QuerySpec,
InMemoryStorage,
StorageItem,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer
class MockResourceIdentifier(ResourceIdentifier):
def __init__(self, identifier: str):
self._identifier = identifier
@property
def str_identifier(self) -> str:
return self._identifier
def to_dict(self) -> Dict:
return {"identifier": self._identifier}
class MockStorageItem(StorageItem):
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, MockStorageItem):
raise ValueError("other must be a MockStorageItem")
self.data = other.data
def __init__(self, identifier: Union[str, MockResourceIdentifier], data):
self._identifier_str = (
identifier if isinstance(identifier, str) else identifier.str_identifier
)
self.data = data
def to_dict(self) -> Dict:
return {"identifier": self._identifier_str, "data": self.data}
@property
def identifier(self) -> ResourceIdentifier:
return MockResourceIdentifier(self._identifier_str)
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def in_memory_storage(serializer):
return InMemoryStorage(serializer)
def test_save_and_load(in_memory_storage):
resource_id = MockResourceIdentifier("1")
item = MockStorageItem(resource_id, "test_data")
in_memory_storage.save(item)
loaded_item = in_memory_storage.load(resource_id, MockStorageItem)
assert loaded_item.data == "test_data"
def test_duplicate_save(in_memory_storage):
item = MockStorageItem("1", "test_data")
in_memory_storage.save(item)
# Should raise StorageError when saving the same data
with pytest.raises(StorageError):
in_memory_storage.save(item)
def test_delete(in_memory_storage):
resource_id = MockResourceIdentifier("1")
item = MockStorageItem(resource_id, "test_data")
in_memory_storage.save(item)
in_memory_storage.delete(resource_id)
# Storage should not contain the data after deletion
assert in_memory_storage.load(resource_id, MockStorageItem) is None
def test_query(in_memory_storage):
resource_id1 = MockResourceIdentifier("1")
item1 = MockStorageItem(resource_id1, "test_data1")
resource_id2 = MockResourceIdentifier("2")
item2 = MockStorageItem(resource_id2, "test_data2")
in_memory_storage.save(item1)
in_memory_storage.save(item2)
query_spec = QuerySpec(conditions={"data": "test_data1"})
results = in_memory_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
assert results[0].data == "test_data1"
def test_count(in_memory_storage):
item1 = MockStorageItem("1", "test_data1")
item2 = MockStorageItem("2", "test_data2")
in_memory_storage.save(item1)
in_memory_storage.save(item2)
query_spec = QuerySpec(conditions={})
count = in_memory_storage.count(query_spec, MockStorageItem)
assert count == 2
def test_paginate_query(in_memory_storage):
for i in range(10):
resource_id = MockResourceIdentifier(str(i))
item = MockStorageItem(resource_id, f"test_data{i}")
in_memory_storage.save(item)
page_size = 3
query_spec = QuerySpec(conditions={})
page_result = in_memory_storage.paginate_query(
2, page_size, MockStorageItem, query_spec
)
assert len(page_result.items) == page_size
assert page_result.total_count == 10
assert page_result.total_pages == 4
assert page_result.page == 2
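
The expected values here follow from the ceiling-division page math used by the storage implementations in this commit; a quick sketch of the invariant:

# Sketch: the page arithmetic behind the assertions above.
total_count = 10
page_size = 3
total_pages = (total_count - 1) // page_size + 1  # ceiling division
assert total_pages == 4
# Page 2 covers items 3..5 (0-based), so it is a full page of `page_size` items.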

View File

@@ -91,6 +91,10 @@ class BaseConnect(ABC):
         """Get column fields about specified table."""
         pass
 
+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self.get_fields(table_name)
+
     def get_show_create_table(self, table_name):
         """Get the creation table sql about specified table."""
         pass

View File

@@ -1,16 +1,10 @@
 from sqlalchemy import Column, Integer, String, Index, Text, text
 from sqlalchemy import UniqueConstraint
 
-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 
 
-class ConnectConfigEntity(Base):
+class ConnectConfigEntity(Model):
     """db connect config entity"""
 
     __tablename__ = "connect_config"
@@ -38,17 +32,9 @@ class ConnectConfigEntity(Base):
 class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     """db connect config dao"""
 
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def update(self, entity: ConnectConfigEntity):
         """update db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -58,7 +44,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def delete(self, db_name: int):
         """ "delete db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         if db_name is None:
             raise Exception("db_name is None")
@@ -70,7 +56,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def get_by_names(self, db_name: str) -> ConnectConfigEntity:
         """get db connect info by name"""
-        session = self.get_session()
+        session = self.get_raw_session()
         db_connect = session.query(ConnectConfigEntity)
         db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
         result = db_connect.first()
@@ -99,7 +85,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
             comment: comment
         """
         try:
-            session = self.get_session()
+            session = self.get_raw_session()
             from sqlalchemy import text
@@ -144,7 +130,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
         old_db_conf = self.get_db_config(db_name)
         if old_db_conf:
             try:
-                session = self.get_session()
+                session = self.get_raw_session()
                 if not db_path:
                     update_statement = text(
                         f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
@@ -164,7 +150,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
         """add file db connect info"""
         try:
-            session = self.get_session()
+            session = self.get_raw_session()
             insert_statement = text(
                 """
                 INSERT INTO connect_config(
@@ -194,7 +180,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def get_db_config(self, db_name):
         """get db config by name"""
-        session = self.get_session()
+        session = self.get_raw_session()
         if db_name:
             select_statement = text(
                 """
@@ -221,7 +207,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def get_db_list(self):
         """get db list"""
-        session = self.get_session()
+        session = self.get_raw_session()
         result = session.execute(text("SELECT * FROM connect_config"))
         fields = [field[0] for field in result.cursor.description]
@@ -235,7 +221,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
     def delete_db(self, db_name):
         """delete db connect info"""
-        session = self.get_session()
+        session = self.get_raw_session()
         delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
         params = {"db_name": db_name}
         session.execute(delete_statement, params)

View File

@@ -270,7 +270,12 @@ class RDBMSDatabase(BaseConnect):
         """Format the error message"""
         return f"Error: {e}"
 
-    def __write(self, write_sql):
+    def _write(self, write_sql: str):
+        """Run a SQL write command and return the results as a list of tuples.
+
+        Args:
+            write_sql (str): SQL write command to run
+        """
         print(f"Write[{write_sql}]")
         db_cache = self._engine.url.database
         result = self.session.execute(text(write_sql))
@@ -280,16 +285,12 @@ class RDBMSDatabase(BaseConnect):
         print(f"SQL[{write_sql}], result:{result.rowcount}")
         return result.rowcount
 
-    def __query(self, query, fetch: str = "all"):
-        """
-        only for query
-        Args:
-            session:
-            query:
-            fetch:
-        Returns:
-        """
+    def _query(self, query: str, fetch: str = "all"):
+        """Run a SQL query and return the results as a list of tuples.
+
+        Args:
+            query (str): SQL query to run
+            fetch (str): fetch type
+        """
         print(f"Query[{query}]")
         if not query:
@@ -308,6 +309,10 @@ class RDBMSDatabase(BaseConnect):
         result.insert(0, field_names)
         return result
 
+    def query_table_schema(self, table_name):
+        sql = f"select * from {table_name} limit 1"
+        return self._query(sql)
+
     def query_ex(self, query, fetch: str = "all"):
         """
         only for query
@@ -325,7 +330,7 @@ class RDBMSDatabase(BaseConnect):
         if fetch == "all":
             result = cursor.fetchall()
         elif fetch == "one":
-            result = cursor.fetchone()[0]  # type: ignore
+            result = cursor.fetchone()  # type: ignore
         else:
             raise ValueError("Fetch parameter must be either 'one' or 'all'")
         field_names = list(i[0:] for i in cursor.keys())
@@ -342,12 +347,12 @@ class RDBMSDatabase(BaseConnect):
         parsed, ttype, sql_type, table_name = self.__sql_parse(command)
         if ttype == sqlparse.tokens.DML:
             if sql_type == "SELECT":
-                return self.__query(command, fetch)
+                return self._query(command, fetch)
             else:
-                self.__write(command)
+                self._write(command)
                 select_sql = self.convert_sql_write_to_select(command)
                 print(f"write result query:{select_sql}")
-                return self.__query(select_sql)
+                return self._query(select_sql)
         else:
             print(f"DDL execution determines whether to enable through configuration ")
@@ -360,10 +365,11 @@ class RDBMSDatabase(BaseConnect):
             result.insert(0, field_names)
             print("DDL Result:" + str(result))
             if not result:
-                return self.__query(f"SHOW COLUMNS FROM {table_name}")
+                # return self._query(f"SHOW COLUMNS FROM {table_name}")
+                return self.get_simple_fields(table_name)
             return result
         else:
-            return self.__query(f"SHOW COLUMNS FROM {table_name}")
+            return self.get_simple_fields(table_name)
 
     def run_to_df(self, command: str, fetch: str = "all"):
         result_lst = self.run(command, fetch)
@@ -451,13 +457,23 @@ class RDBMSDatabase(BaseConnect):
         sql = sql.strip()
         parsed = sqlparse.parse(sql)[0]
         sql_type = parsed.get_type()
-        table_name = parsed.get_name()
+        if sql_type == "CREATE":
+            table_name = self._extract_table_name_from_ddl(parsed)
+        else:
+            table_name = parsed.get_name()
         first_token = parsed.token_first(skip_ws=True, skip_cm=False)
         ttype = first_token.ttype
         print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
         return parsed, ttype, sql_type, table_name
 
+    def _extract_table_name_from_ddl(self, parsed):
+        """Extract table name from CREATE TABLE statement."""
+        for token in parsed.tokens:
+            if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
+                return token.get_real_name()
+        return None
+
     def get_indexes(self, table_name):
         """Get table indexes about specified table."""
         session = self._db_sessions()
@@ -485,6 +501,10 @@ class RDBMSDatabase(BaseConnect):
         fields = cursor.fetchall()
         return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
 
+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self._query(f"SHOW COLUMNS FROM {table_name}")
+
     def get_charset(self):
         """Get character_set."""
         session = self._db_sessions()
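
A note on `_extract_table_name_from_ddl`: `parsed.get_name()` is unreliable for CREATE statements, so the new helper walks the token stream for the first grouped Identifier. A hedged sketch of what that yields for a simple statement (sqlparse groups the table name and column list into an Identifier here):

# Sketch: the token walk used by _extract_table_name_from_ddl.
import sqlparse

parsed = sqlparse.parse("CREATE TABLE test (id INTEGER);")[0]
table_name = None
for token in parsed.tokens:
    # The table name is the first grouped Identifier in the statement.
    if token.ttype is None and isinstance(token, sqlparse.sql.Identifier):
        table_name = token.get_real_name()
        break
print(table_name)  # expected: "test"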

View File

@@ -56,6 +56,10 @@ class SQLiteConnect(RDBMSDatabase):
         print(fields)
         return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
 
+    def get_simple_fields(self, table_name):
+        """Get column fields about specified table."""
+        return self.get_fields(table_name)
+
     def get_users(self):
         return []
@@ -88,8 +92,9 @@ class SQLiteConnect(RDBMSDatabase):
         self._metadata.reflect(bind=self._engine)
         return self._all_tables
 
-    def _write(self, session, write_sql):
+    def _write(self, write_sql):
         print(f"Write[{write_sql}]")
+        session = self.session
         result = session.execute(text(write_sql))
         session.commit()
         # TODO Subsequent optimization of dynamically specified database submission loss target problem

View File

@@ -25,41 +25,41 @@ def test_get_table_info(db):
 def test_get_table_info_with_table(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER);")
+    db.run("CREATE TABLE test (id INTEGER);")
     print(db._sync_tables_from_db())
     table_info = db.get_table_info()
     assert "CREATE TABLE test" in table_info
 
 
 def test_run_sql(db):
-    result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
-    assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
+    result = db.run("CREATE TABLE test(id INTEGER);")
+    assert result[0] == ("id", "INTEGER", 0, None, 0)
 
 
 def test_run_no_throw(db):
-    assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
+    assert db.run_no_throw("this is a error sql").startswith("Error:")
 
 
 def test_get_indexes(db):
-    db.run(db.session, "CREATE TABLE test (name TEXT);")
-    db.run(db.session, "CREATE INDEX idx_name ON test(name);")
+    db.run("CREATE TABLE test (name TEXT);")
+    db.run("CREATE INDEX idx_name ON test(name);")
     assert db.get_indexes("test") == [("idx_name", "c")]
 
 
 def test_get_indexes_empty(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_indexes("test") == []
 
 
 def test_get_show_create_table(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert (
         db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
     )
 
 
 def test_get_fields(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
@@ -72,26 +72,26 @@ def test_get_collation(db):
 def test_table_simple_info(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.table_simple_info() == ["test(id);"]
 
 
 def test_get_table_info_no_throw(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
 
 
 def test_query_ex(db):
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
-    db.run(db.session, "insert into test(id) values (1)")
-    db.run(db.session, "insert into test(id) values (2)")
-    field_names, result = db.query_ex(db.session, "select * from test")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("insert into test(id) values (1)")
+    db.run("insert into test(id) values (2)")
+    field_names, result = db.query_ex("select * from test")
     assert field_names == ["id"]
     assert result == [(1,), (2,)]
 
-    field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
+    field_names, result = db.query_ex("select * from test", fetch="one")
     assert field_names == ["id"]
-    assert result == [(1,)]
+    assert result == [1]
 
 
 def test_convert_sql_write_to_select(db):
@@ -109,7 +109,7 @@ def test_get_users(db):
 def test_get_table_comments(db):
     assert db.get_table_comments() == []
-    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
+    db.run("CREATE TABLE test (id INTEGER PRIMARY KEY);")
     assert db.get_table_comments() == [
         ("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
     ]

View File

@@ -4,6 +4,7 @@ from aioresponses import aioresponses
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from httpx import AsyncClient, HTTPError
+import importlib.metadata as metadata
 
 from dbgpt.component import SystemApp
 from dbgpt.util.openai_utils import chat_completion_stream, chat_completion
@@ -190,12 +191,26 @@ async def test_chat_completions_with_openai_lib_async_stream(
     )
     stream_stream_resp = ""
-    async for stream_resp in await openai.ChatCompletion.acreate(
-        model=model_name,
-        messages=[{"role": "user", "content": "Hello! What is your name?"}],
-        stream=True,
-    ):
+    if metadata.version("openai") >= "1.0.0":
+        from openai import OpenAI
+
+        client = OpenAI(
+            **{"base_url": "http://test/api/v1", "api_key": client_api_key}
+        )
+        res = await client.chat.completions.create(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! What is your name?"}],
+            stream=True,
+        )
+    else:
+        res = openai.ChatCompletion.acreate(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! What is your name?"}],
+            stream=True,
+        )
+    async for stream_resp in res:
         stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
     assert stream_stream_resp == expected_messages
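
One caveat on the gate above: `metadata.version("openai") >= "1.0.0"` compares version strings lexicographically, which happens to work for the 0.x/1.x split but misorders multi-digit components (e.g. "10.0.0" < "9.0.0" as strings). A more robust check, sketched with the `packaging` library (an extra dependency assumption, not part of this commit):

# Sketch: a lexicographic-safe version gate, assuming `packaging` is available.
import importlib.metadata as metadata
from packaging.version import Version

is_openai_v1 = Version(metadata.version("openai")) >= Version("1.0.0")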

View File

@@ -75,9 +75,8 @@ class LLMCacheValueData:
 class LLMCacheKey(CacheKey[LLMCacheKeyData]):
-    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+    def __init__(self, **kwargs) -> None:
         super().__init__()
-        self._serializer = serializer
         self.config = LLMCacheKeyData(**kwargs)
 
     def __hash__(self) -> int:
@@ -96,30 +95,23 @@ class LLMCacheKey(CacheKey[LLMCacheKeyData]):
     def to_dict(self) -> Dict:
         return asdict(self.config)
 
-    def serialize(self) -> bytes:
-        return self._serializer.serialize(self)
-
     def get_value(self) -> LLMCacheKeyData:
         return self.config
 
 
 class LLMCacheValue(CacheValue[LLMCacheValueData]):
-    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+    def __init__(self, **kwargs) -> None:
         super().__init__()
-        self._serializer = serializer
         self.value = LLMCacheValueData.from_dict(**kwargs)
 
     def to_dict(self) -> Dict:
         return self.value.to_dict()
 
-    def serialize(self) -> bytes:
-        return self._serializer.serialize(self)
-
     def get_value(self) -> LLMCacheValueData:
         return self.value
 
     def __str__(self) -> str:
-        return f"vaue: {str(self.value)}"
+        return f"value: {str(self.value)}"
 
 
 class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
@@ -146,7 +138,11 @@ class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
         return await self.get(key, cache_config) is not None
 
     def new_key(self, **kwargs) -> LLMCacheKey:
-        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
+        key = LLMCacheKey(**kwargs)
+        key.set_serializer(self._cache_manager.serializer)
+        return key
 
     def new_value(self, **kwargs) -> LLMCacheValue:
-        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
+        value = LLMCacheValue(**kwargs)
+        value.set_serializer(self._cache_manager.serializer)
+        return value
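
The serializer is now injected after construction instead of being threaded through `__init__`, so the per-class `serialize()` overrides can be dropped. A minimal sketch of the contract this relies on (assuming `set_serializer` and the shared `serialize()` live on the Serializable base class elsewhere in this commit):

# Sketch of the assumed base-class contract behind set_serializer().
class Serializable:
    def __init__(self):
        self._serializer = None

    def set_serializer(self, serializer) -> None:
        # Attach the serializer after construction.
        self._serializer = serializer

    def serialize(self) -> bytes:
        # One shared implementation replaces the removed per-class overrides.
        return self._serializer.serialize(self)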

View File

@@ -1,17 +1,12 @@
 from typing import Optional
 
-from sqlalchemy import Column, Integer, String, Index, Text
+from datetime import datetime
+from sqlalchemy import Column, Integer, String, Index, Text, DateTime
 from sqlalchemy import UniqueConstraint
 
-from dbgpt.storage.metadata import BaseDao
-from dbgpt.storage.metadata.meta_data import (
-    Base,
-    engine,
-    session,
-    META_DATA_DATABASE,
-)
+from dbgpt.storage.metadata import BaseDao, Model
 
 
-class ChatHistoryEntity(Base):
+class ChatHistoryEntity(Model):
     __tablename__ = "chat_history"
     id = Column(
         Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
@@ -20,38 +15,60 @@ class ChatHistoryEntity(Base):
         "mysql_charset": "utf8mb4",
         "mysql_collate": "utf8mb4_unicode_ci",
     }
     conv_uid = Column(
         String(255),
-        unique=False,
+        unique=True,
         nullable=False,
         comment="Conversation record unique id",
     )
     chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
     summary = Column(String(255), nullable=False, comment="Conversation record summary")
     user_name = Column(String(255), nullable=True, comment="interlocutor")
     messages = Column(
         Text(length=2**31 - 1), nullable=True, comment="Conversation details"
     )
+    message_ids = Column(
+        Text(length=2**31 - 1), nullable=True, comment="Message ids, split by comma"
+    )
     sys_code = Column(String(128), index=True, nullable=True, comment="System code")
-    UniqueConstraint("conv_uid", name="uk_conversation")
+    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
+    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
     Index("idx_q_user", "user_name")
     Index("idx_q_mode", "chat_mode")
     Index("idx_q_conv", "summary")
 
 
+class ChatHistoryMessageEntity(Model):
+    __tablename__ = "chat_history_message"
+    id = Column(
+        Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
+    )
+    __table_args__ = {
+        "mysql_charset": "utf8mb4",
+        "mysql_collate": "utf8mb4_unicode_ci",
+    }
+    conv_uid = Column(
+        String(255),
+        unique=False,
+        nullable=False,
+        comment="Conversation record unique id",
+    )
+    index = Column(Integer, nullable=False, comment="Message index")
+    round_index = Column(Integer, nullable=False, comment="Message round index")
+    message_detail = Column(
+        Text(length=2**31 - 1), nullable=True, comment="Message details, json format"
+    )
+    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
+    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
+    UniqueConstraint("conv_uid", "index", name="uk_conversation_message")
+
+
 class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
-    def __init__(self):
-        super().__init__(
-            database=META_DATA_DATABASE,
-            orm_base=Base,
-            db_engine=engine,
-            session=session,
-        )
-
     def list_last_20(
         self, user_name: Optional[str] = None, sys_code: Optional[str] = None
     ):
-        session = self.get_session()
+        session = self.get_raw_session()
         chat_history = session.query(ChatHistoryEntity)
         if user_name:
             chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
@@ -65,7 +82,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
         return result
 
     def update(self, entity: ChatHistoryEntity):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             updated = session.merge(entity)
             session.commit()
@@ -74,7 +91,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
             session.close()
 
     def update_message_by_uid(self, message: str, conv_uid: str):
-        session = self.get_session()
+        session = self.get_raw_session()
         try:
             chat_history = session.query(ChatHistoryEntity)
             chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
@@ -85,20 +102,12 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
             session.close()
 
     def delete(self, conv_uid: int):
-        session = self.get_session()
         if conv_uid is None:
             raise Exception("conv_uid is None")
-        chat_history = session.query(ChatHistoryEntity)
-        chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
-        chat_history.delete()
-        session.commit()
-        session.close()
+        with self.session() as session:
+            chat_history = session.query(ChatHistoryEntity)
+            chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
+            chat_history.delete()
 
     def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
-        session = self.get_session()
-        chat_history = session.query(ChatHistoryEntity)
-        chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
-        result = chat_history.first()
-        session.close()
-        return result
+        return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first()
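
With the split schema, `chat_history.message_ids` holds only a comma-separated list of message indexes while the message bodies live in `chat_history_message`. A hypothetical helper sketching how one conversation's message details can be reassembled from the new table:

# Sketch (hypothetical helper): reassemble message details for one conversation.
import json

def load_message_details(session, conv_uid: str):
    rows = (
        session.query(ChatHistoryMessageEntity)
        .filter(ChatHistoryMessageEntity.conv_uid == conv_uid)
        .order_by(ChatHistoryMessageEntity.index)
        .all()
    )
    return [json.loads(row.message_detail) for row in rows if row.message_detail]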

View File

@@ -0,0 +1,116 @@
from typing import List, Dict, Type
import json
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import StorageItemAdapter
from dbgpt.core.interface.message import (
StorageConversation,
ConversationIdentifier,
MessageIdentifier,
MessageStorageItem,
_messages_from_dict,
_conversation_to_dict,
BaseMessage,
)
from .chat_history_db import ChatHistoryEntity, ChatHistoryMessageEntity
class DBStorageConversationItemAdapter(
StorageItemAdapter[StorageConversation, ChatHistoryEntity]
):
def to_storage_format(self, item: StorageConversation) -> ChatHistoryEntity:
message_ids = ",".join(item.message_ids)
messages = None
if not item.save_message_independent and item.messages:
messages = _conversation_to_dict(item)
return ChatHistoryEntity(
conv_uid=item.conv_uid,
chat_mode=item.chat_mode,
summary=item.summary or item.get_latest_user_message().content,
user_name=item.user_name,
# In the new design, messages are no longer saved to the chat_history table
messages=messages,
message_ids=message_ids,
sys_code=item.sys_code,
)
def from_storage_format(self, model: ChatHistoryEntity) -> StorageConversation:
message_ids = model.message_ids.split(",") if model.message_ids else []
old_conversations: List[Dict] = (
json.loads(model.messages) if model.messages else []
)
old_messages = []
save_message_independent = True
if old_conversations:
# Load old messages from old conversations, in old design, we save messages to chat_history table
old_messages_dict = []
for old_conversation in old_conversations:
old_messages_dict.extend(
old_conversation["messages"]
if "messages" in old_conversation
else []
)
save_message_independent = False
old_messages: List[BaseMessage] = _messages_from_dict(old_messages_dict)
return StorageConversation(
conv_uid=model.conv_uid,
chat_mode=model.chat_mode,
summary=model.summary,
user_name=model.user_name,
message_ids=message_ids,
sys_code=model.sys_code,
save_message_independent=save_message_independent,
messages=old_messages,
)
def get_query_for_identifier(
self,
storage_format: Type[ChatHistoryEntity],
resource_id: ConversationIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
return session.query(ChatHistoryEntity).filter(
ChatHistoryEntity.conv_uid == resource_id.conv_uid
)
class DBMessageStorageItemAdapter(
StorageItemAdapter[MessageStorageItem, ChatHistoryMessageEntity]
):
def to_storage_format(self, item: MessageStorageItem) -> ChatHistoryMessageEntity:
round_index = item.message_detail.get("round_index", 0)
message_detail = json.dumps(item.message_detail, ensure_ascii=False)
return ChatHistoryMessageEntity(
conv_uid=item.conv_uid,
index=item.index,
round_index=round_index,
message_detail=message_detail,
)
def from_storage_format(
self, model: ChatHistoryMessageEntity
) -> MessageStorageItem:
message_detail = (
json.loads(model.message_detail) if model.message_detail else {}
)
return MessageStorageItem(
conv_uid=model.conv_uid,
index=model.index,
message_detail=message_detail,
)
def get_query_for_identifier(
self,
storage_format: Type[ChatHistoryMessageEntity],
resource_id: MessageIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
return session.query(ChatHistoryMessageEntity).filter(
ChatHistoryMessageEntity.conv_uid == resource_id.conv_uid,
ChatHistoryMessageEntity.index == resource_id.index,
)
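
Each adapter is a pure translation layer between the domain object and its ORM row, so items round-trip through it losslessly. A quick sketch with the message adapter above:

# Sketch: round-tripping a message item through the adapter above.
adapter = DBMessageStorageItemAdapter()
item = MessageStorageItem(
    conv_uid="conv1",
    index=1,
    message_detail={"type": "human", "round_index": 1, "data": {"content": "hi"}},
)
entity = adapter.to_storage_format(item)  # ChatHistoryMessageEntity with JSON detail
restored = adapter.from_storage_format(entity)
assert restored.message_detail == item.message_detail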

View File

@@ -87,7 +87,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
                 [
                     self.chat_seesion_id,
                     once_message.chat_mode,
-                    once_message.get_user_conv().content,
+                    once_message.get_latest_user_message().content,
                     once_message.user_name,
                     once_message.sys_code,
                     json.dumps(conversations, ensure_ascii=False),

View File

@@ -52,14 +52,14 @@ class DbHistoryMemory(BaseChatHistoryMemory):
             if context:
                 conversations = json.loads(context)
             else:
-                chat_history.summary = once_message.get_user_conv().content
+                chat_history.summary = once_message.get_latest_user_message().content
         else:
             chat_history: ChatHistoryEntity = ChatHistoryEntity()
             chat_history.conv_uid = self.chat_seesion_id
             chat_history.chat_mode = once_message.chat_mode
             chat_history.user_name = once_message.user_name
             chat_history.sys_code = once_message.sys_code
-            chat_history.summary = once_message.get_user_conv().content
+            chat_history.summary = once_message.get_latest_user_message().content
 
         conversations.append(_conversation_to_dict(once_message))
         chat_history.messages = json.dumps(conversations, ensure_ascii=False)

View File

@@ -0,0 +1,219 @@
import pytest
from typing import List
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.core.interface.message import StorageConversation, HumanMessage, AIMessage
from dbgpt.core.interface.storage import QuerySpec
from dbgpt.storage.metadata import db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
ChatHistoryMessageEntity,
)
from dbgpt.storage.chat_history.storage_adapter import (
DBStorageConversationItemAdapter,
DBMessageStorageItemAdapter,
)
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
# return "sqlite:///test.db"
@pytest.fixture
def db_manager(db_url):
db.init_db(db_url)
db.create_all()
return db
@pytest.fixture
def storage_adapter():
return DBStorageConversationItemAdapter()
@pytest.fixture
def storage_message_adapter():
return DBMessageStorageItemAdapter()
@pytest.fixture
def conv_storage(db_manager, serializer, storage_adapter):
storage = SQLAlchemyStorage(
db_manager,
ChatHistoryEntity,
storage_adapter,
serializer,
)
return storage
@pytest.fixture
def message_storage(db_manager, serializer, storage_message_adapter):
storage = SQLAlchemyStorage(
db_manager,
ChatHistoryMessageEntity,
storage_message_adapter,
serializer,
)
return storage
@pytest.fixture
def conversation(conv_storage, message_storage):
return StorageConversation(
"conv1",
chat_mode="chat_normal",
user_name="user1",
conv_storage=conv_storage,
message_storage=message_storage,
)
@pytest.fixture
def four_round_conversation(conv_storage, message_storage):
conversation = StorageConversation(
"conv1",
chat_mode="chat_normal",
user_name="user1",
conv_storage=conv_storage,
message_storage=message_storage,
)
conversation.start_new_round()
conversation.add_user_message("hello, this is first round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is second round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is third round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is fourth round")
conversation.add_ai_message("hi")
conversation.end_current_round()
return conversation
@pytest.fixture
def conversation_list(request, conv_storage, message_storage):
params = request.param if hasattr(request, "param") else {}
conv_count = params.get("conv_count", 4)
result = []
for i in range(conv_count):
conversation = StorageConversation(
f"conv{i}",
chat_mode="chat_normal",
user_name="user1",
conv_storage=conv_storage,
message_storage=message_storage,
)
conversation.start_new_round()
conversation.add_user_message("hello, this is first round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is second round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is third round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is fourth round")
conversation.add_ai_message("hi")
conversation.end_current_round()
result.append(conversation)
return result
def test_save_and_load(
conversation: StorageConversation, conv_storage, message_storage
):
conversation.start_new_round()
conversation.add_user_message("hello")
conversation.add_ai_message("hi")
conversation.end_current_round()
saved_conversation = StorageConversation(
conv_uid=conversation.conv_uid,
conv_storage=conv_storage,
message_storage=message_storage,
)
assert saved_conversation.conv_uid == conversation.conv_uid
assert len(saved_conversation.messages) == 2
assert isinstance(saved_conversation.messages[0], HumanMessage)
assert isinstance(saved_conversation.messages[1], AIMessage)
assert saved_conversation.messages[0].content == "hello"
assert saved_conversation.messages[0].round_index == 1
assert saved_conversation.messages[1].content == "hi"
assert saved_conversation.messages[1].round_index == 1
def test_query_message(
conversation: StorageConversation, conv_storage, message_storage
):
conversation.start_new_round()
conversation.add_user_message("hello")
conversation.add_ai_message("hi")
conversation.end_current_round()
saved_conversation = StorageConversation(
conv_uid=conversation.conv_uid,
conv_storage=conv_storage,
message_storage=message_storage,
)
assert saved_conversation.conv_uid == conversation.conv_uid
assert len(saved_conversation.messages) == 2
query_spec = QuerySpec(conditions={"conv_uid": conversation.conv_uid})
results = conversation.conv_storage.query(query_spec, StorageConversation)
assert len(results) == 1
def test_complex_query(
conversation_list: List[StorageConversation], conv_storage, message_storage
):
query_spec = QuerySpec(conditions={"user_name": "user1"})
results = conv_storage.query(query_spec, StorageConversation)
assert len(results) == len(conversation_list)
for i, result in enumerate(results):
assert result.user_name == "user1"
assert result.conv_uid == f"conv{i}"
saved_conversation = StorageConversation(
conv_uid=result.conv_uid,
conv_storage=conv_storage,
message_storage=message_storage,
)
assert len(saved_conversation.messages) == 8
assert isinstance(saved_conversation.messages[0], HumanMessage)
assert isinstance(saved_conversation.messages[1], AIMessage)
assert saved_conversation.messages[0].content == "hello, this is first round"
assert saved_conversation.messages[1].content == "hi"
def test_query_with_page(
conversation_list: List[StorageConversation], conv_storage, message_storage
):
query_spec = QuerySpec(conditions={"user_name": "user1"})
page_result: PaginationResult = conv_storage.paginate_query(
page=1, page_size=2, cls=StorageConversation, spec=query_spec
)
assert page_result.total_count == len(conversation_list)
assert page_result.total_pages == 2
assert page_result.page_size == 2
assert len(page_result.items) == 2
assert page_result.items[0].conv_uid == "conv0"

View File

@@ -1 +1,17 @@
+from dbgpt.storage.metadata.db_manager import (
+    db,
+    Model,
+    DatabaseManager,
+    create_model,
+    BaseModel,
+)
 from dbgpt.storage.metadata._base_dao import BaseDao
+
+__ALL__ = [
+    "db",
+    "Model",
+    "DatabaseManager",
+    "create_model",
+    "BaseModel",
+    "BaseDao",
+]

View File

@@ -1,25 +1,72 @@
-from typing import TypeVar, Generic, Any
-from sqlalchemy.orm import sessionmaker
+from contextlib import contextmanager
+from typing import TypeVar, Generic, Any, Optional
+from sqlalchemy.orm.session import Session
 
 T = TypeVar("T")
 
+from .db_manager import db, DatabaseManager
+
 
 class BaseDao(Generic[T]):
+    """The base class for all DAOs.
+
+    Examples:
+        .. code-block:: python
+
+            class UserDao(BaseDao[User]):
+                def get_user_by_name(self, name: str) -> User:
+                    with self.session() as session:
+                        return session.query(User).filter(User.name == name).first()
+
+                def get_user_by_id(self, id: int) -> User:
+                    with self.session() as session:
+                        return User.get(id)
+
+                def create_user(self, name: str) -> User:
+                    return User.create(**{"name": name})
+
+    Args:
+        db_manager (DatabaseManager, optional): The database manager. Defaults to None.
+            If None, the default database manager (db) will be used.
+    """
+
     def __init__(
         self,
-        orm_base=None,
-        database: str = None,
-        db_engine: Any = None,
-        session: Any = None,
+        db_manager: Optional[DatabaseManager] = None,
     ) -> None:
-        """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
-        self._orm_base = orm_base
-        self._database = database
-        self._db_engine = db_engine
-        self._session = session
+        self._db_manager = db_manager or db
 
-    def get_session(self):
-        Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
-        session = Session()
-        return session
+    def get_raw_session(self) -> Session:
+        """Get a raw session object.
+
+        You should commit or rollback the session manually.
+        We suggest you use :meth:`session` instead.
+
+        Example:
+            .. code-block:: python
+
+                user = User(name="Edward Snowden")
+                session = self.get_raw_session()
+                session.add(user)
+                session.commit()
+                session.close()
+        """
+        return self._db_manager._session()
+
+    @contextmanager
+    def session(self) -> Session:
+        """Provide a transactional scope around a series of operations.
+
+        If an exception is raised, the session is rolled back automatically;
+        otherwise it is committed.
+
+        Example:
+            .. code-block:: python
+
+                with self.session() as session:
+                    session.query(User).filter(User.name == 'Edward Snowden').first()
+
+        Returns:
+            Session: A session object.
+
+        Raises:
+            Exception: Any exception will be raised.
+        """
+        with self._db_manager.session() as session:
+            yield session
View File

@@ -0,0 +1,432 @@
from __future__ import annotations
import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
from sqlalchemy.orm import (
scoped_session,
sessionmaker,
Session,
declarative_base,
DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
from dbgpt.util.pagination_utils import PaginationResult
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __init__(self, db_manager: "DatabaseManager"):
self._db_manager = db_manager
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=self._db_manager._session())
except UnmappedClassError:
return None
class BaseQuery(orm.Query):
def paginate_query(
self, page: Optional[int] = 1, per_page: Optional[int] = 20
) -> PaginationResult:
"""Paginate the query.
Example:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
pagination = session.query(User).paginate_query(page=1, per_page=10)
print(pagination)
# Or you can use the query object
with db.session() as session:
pagination = User.query.paginate_query(page=1, per_page=10)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
Returns:
PaginationResult: The pagination result.
"""
if page < 1:
raise ValueError("Page must be greater than 0")
if per_page < 1:
raise ValueError("Per page must be greater than 0")
items = self.limit(per_page).offset((page - 1) * per_page).all()
total = self.order_by(None).count()
total_pages = (total - 1) // per_page + 1
return PaginationResult(
items=items,
total_count=total,
total_pages=total_pages,
page=page,
page_size=per_page,
)
class _Model:
"""Base class for SQLAlchemy declarative base model.
With this class, we can use the query object to query the database.
Examples:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
# User is an instance of _Model, and we can use the query object to query the database.
User.query.filter(User.name == "test").all()
"""
query_class = None
query: Optional[BaseQuery] = None
def __repr__(self):
identity = inspect(self).identity
if identity is None:
pk = "(transient {0})".format(id(self))
else:
pk = ", ".join(_to_str(value) for value in identity)
return "<{0} {1}>".format(type(self).__name__, pk)
class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
from dbgpt.storage.metadata import DatabaseManager, create_model
db = DatabaseManager()
# Use sqlite with memory storage.
url = f"sqlite:///:memory:"
engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
db.init_db(url, engine_args=engine_args)
Model = create_model(db)
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
session.add(User(name="test", fullname="test"))
# db will commit the session automatically default.
# session.commit()
print(User.query.filter(User.name == "test").all())
# Use CURDMixin APIs to create, update, delete, query the database.
with db.session() as session:
User.create(**{"name": "test1", "fullname": "test1"})
User.create(**{"name": "test2", "fullname": "test1"})
users = User.all()
print(users)
user = users[0]
user.update(**{"name": "test1_1111"})
user2 = users[1]
# Update user2 by save
user2.name = "test2_1111"
user2.save()
# Delete user2
user2.delete()
"""
Query = BaseQuery
def __init__(self):
self._db_url = None
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
self._engine: Optional[Engine] = None
self._session: Optional[scoped_session] = None
@property
def Model(self) -> _Model:
"""Get the declarative base."""
return self._base
@property
def metadata(self) -> MetaData:
"""Get the metadata."""
return self.Model.metadata
@property
def engine(self):
"""Get the engine.""" ""
return self._engine
@contextmanager
def session(self) -> Session:
"""Get the session with context manager.
If any exception is raised, the session is rolled back automatically; otherwise, it is committed automatically.
Example:
>>> with db.session() as session:
>>> session.query(...)
Returns:
Session: The session.
Raises:
RuntimeError: The database manager is not initialized.
Exception: Any exception.
"""
if not self._session:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def _make_declarative_base(
self, model: Union[Type[DeclarativeMeta], Type[_Model]]
) -> DeclarativeMeta:
"""Make the declarative base.
Args:
base (DeclarativeMeta): The base class.
Returns:
DeclarativeMeta: The declarative base.
"""
if not isinstance(model, DeclarativeMeta):
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query = _QueryObject(self)
return model
def init_db(
self,
db_url: Union[str, URL],
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
):
"""Initialize the database manager.
Args:
db_url (Union[str, URL]): The database url.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
"""
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query"):
base.query = _QueryObject(self)
if not getattr(base, "query_class", None):
base.query_class = self.Query
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
self._session = scoped_session(session_factory)
self._base.metadata.bind = self._engine
def init_default_db(
self,
sqlite_path: str,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
):
"""Initialize the database manager with default config.
Examples:
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
>>> session.query(...)
Args:
sqlite_path (str): The sqlite path.
engine_args (Optional[Dict], optional): The engine arguments.
Defaults to None, if None, we will use connection pool.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
"""
if not engine_args:
engine_args = {}
# Pool class
engine_args["poolclass"] = QueuePool
# The number of connections to keep open inside the connection pool.
engine_args["pool_size"] = 10
# The maximum overflow size of the pool when the number of checked-out connections exceeds pool_size.
engine_args["max_overflow"] = 20
# The number of seconds to wait before giving up on getting a connection from the pool.
engine_args["pool_timeout"] = 30
# Recycle the connection if it has been idle for this many seconds.
engine_args["pool_recycle"] = 3600
# Enable the connection pool "pre-ping" feature that tests connections for liveness upon each checkout.
engine_args["pool_pre_ping"] = True
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
def create_all(self):
self.Model.metadata.create_all(self._engine)
db = DatabaseManager()
"""The global database manager.
Examples:
>>> from dbgpt.storage.metadata import db
>>> sqlite_path = "/tmp/dbgpt.db"
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
>>> session.query(...)
>>> from dbgpt.storage.metadata import db, Model
>>> from urllib.parse import quote_plus as urlquote, quote
>>> db_name = "dbgpt"
>>> db_host = "localhost"
>>> db_port = 3306
>>> user = "root"
>>> password = "123456"
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
>>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
>>> db.init_db(url, engine_args=engine_args)
>>> class User(Model):
>>> __tablename__ = "user"
>>> id = Column(Integer, primary_key=True)
>>> name = Column(String(50))
>>> fullname = Column(String(50))
>>> with db.session() as session:
>>> session.add(User(name="test", fullname="test"))
>>> session.commit()
"""
class BaseCRUDMixin(Generic[T]):
"""The base CRUD mixin."""
__abstract__ = True
@classmethod
def create(cls: Type[T], **kwargs) -> T:
instance = cls(**kwargs)
return instance.save()
@classmethod
def all(cls: Type[T]) -> List[T]:
return cls.query.all()
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@abc.abstractmethod
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
@abc.abstractmethod
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = db_manager._session()
session.add(self)
if commit:
session.commit()
return self
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = db_manager._session()
session.delete(self)
return commit and session.commit()
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
return _NewModel
Model = create_model(db)
def initialize_db(
db_url: Union[str, URL],
db_name: str,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
try_to_create_db: Optional[bool] = False,
) -> DatabaseManager:
"""Initialize the database manager.
Args:
db_url (Union[str, URL]): The database url.
db_name (str): The database name.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
db.init_db(db_url, engine_args, base)
if try_to_create_db:
try:
db.create_all()
except Exception as e:
logger.error(f"Failed to create database {db_name}: {e}")
return db
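
For orientation, here is a minimal sketch of the intended call sequence with this entry point; the SQLite URL and the `Document` model are illustrative assumptions, mirroring the docstring example above:

```python
from sqlalchemy import Column, Integer, String

from dbgpt.storage.metadata.db_manager import Model, initialize_db

# Initialize the global manager first, then declare models against `Model`.
db = initialize_db("sqlite:////tmp/dbgpt_example.db", "dbgpt")


class Document(Model):
    """Hypothetical model used only for illustration."""

    __tablename__ = "document"
    id = Column(Integer, primary_key=True)
    title = Column(String(100))


db.create_all()  # create tables for all models declared so far

with db.session() as session:
    session.add(Document(title="hello"))
    session.commit()
```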

View File

@@ -0,0 +1,128 @@
from contextlib import contextmanager
from typing import Type, List, Optional, Union, Dict
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
StorageInterface,
QuerySpec,
ResourceIdentifier,
StorageItemAdapter,
T,
)
from sqlalchemy import URL
from sqlalchemy.orm import Session, DeclarativeMeta
from .db_manager import BaseModel, DatabaseManager, BaseQuery
def _copy_public_properties(src: BaseModel, dest: BaseModel):
"""Simple copy public properties from src to dest"""
for column in src.__table__.columns:
if column.name != "id":
setattr(dest, column.name, getattr(src, column.name))
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
def __init__(
self,
db_url_or_db: Union[str, URL, DatabaseManager],
model_class: Type[BaseModel],
adapter: StorageItemAdapter[T, BaseModel],
serializer: Optional[Serializer] = None,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
):
super().__init__(serializer=serializer, adapter=adapter)
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(db_url_or_db, engine_args, base, query_class)
self.db_manager = db_manager
elif isinstance(db_url_or_db, DatabaseManager):
self.db_manager = db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
self._model_class = model_class
@contextmanager
def session(self) -> Session:
with self.db_manager.session() as session:
yield session
def save(self, data: T) -> None:
with self.session() as session:
model_instance = self.adapter.to_storage_format(data)
session.add(model_instance)
def update(self, data: T) -> None:
with self.session() as session:
model_instance = self.adapter.to_storage_format(data)
session.merge(model_instance)
def save_or_update(self, data: T) -> None:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, data.identifier, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
new_instance = self.adapter.to_storage_format(data)
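                # Copy non-id columns onto the loaded row so its primary key is preserved.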
_copy_public_properties(new_instance, model_instance)
session.merge(model_instance)
return
self.save(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
return self.adapter.from_storage_format(model_instance)
return None
def delete(self, resource_id: ResourceIdentifier) -> None:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
session.delete(model_instance)
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
"""
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
if spec.limit is not None:
query = query.limit(spec.limit)
if spec.offset is not None:
query = query.offset(spec.offset)
model_instances = query.all()
return [
self.adapter.from_storage_format(instance)
for instance in model_instances
]
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
"""Count the number of data in the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
"""
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
return query.count()
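
A hedged usage sketch for the query path, assuming `storage` is a `SQLAlchemyStorage` wired to a hypothetical item type `MyItem` whose model has a `data` column, and assuming `QuerySpec` accepts `limit` and `offset` keywords as the checks above suggest (the tests later in this commit show the full adapter wiring):

```python
# Filter on data == "test_data_2", then apply LIMIT/OFFSET.
spec = QuerySpec(conditions={"data": "test_data_2"}, limit=10, offset=0)
items = storage.query(spec, MyItem)

# Count all rows: an empty conditions dict applies no filters.
total = storage.count(QuerySpec(conditions={}), MyItem)
```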

View File

@@ -1,94 +0,0 @@
import os
import sqlite3
import logging
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
from urllib.parse import quote_plus as urlquote
logger = logging.getLogger(__name__)
# DB-GPT metadata database config, now support mysql and sqlite
CFG = Config()
default_db_path = os.path.join(PILOT_PATH, "meta_data")
os.makedirs(default_db_path, exist_ok=True)
# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
)
# check and auto create mysqldatabase
try:
# try to connect
with engine_temp.connect() as conn:
# TODO We should consider that the production environment does not have permission to execute the DDL
conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
print(f"Already connect '{db_name}'")
except OperationalError as e:
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
engine = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
)
else:
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
Base = declarative_base()
# Base.metadata.create_all()
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session
def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
"""Initialize and upgrade database metadata
Args:
disable_alembic_upgrade (bool): Whether to enable alembic to initialize and upgrade database metadata
"""
if disable_alembic_upgrade:
logger.info(
"disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
)
return
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))
command.revision(alembic_cfg, "dbgpt ddl upate", True)
command.upgrade(alembic_cfg, "head")

View File

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import pytest
from typing import Type
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
PaginationResult,
create_model,
BaseModel,
)
from sqlalchemy import Column, Integer, String
@pytest.fixture
def db():
db = DatabaseManager()
db.init_db("sqlite:///:memory:")
return db
@pytest.fixture
def Model(db):
return create_model(db)
def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
assert db.engine is not None
assert db.session is not None
with db.session() as session:
assert session is not None
def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
assert db.metadata.tables == {}
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
assert list(db.metadata.tables.keys())[0] == "user"
def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
# Create
with db.session() as session:
user = User.create(name="John Doe")
session.add(user)
session.commit()
# Read
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
assert user is not None
# Update
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
user.update(name="Jane Doe")
# Delete
with db.session() as session:
user = session.query(User).filter_by(name="Jane Doe").first()
user.delete()
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
# Create
user = User.create(name="John Doe")
assert User.get(user.id) is not None
users = User.all()
assert len(users) == 1
# Update
user.update(name="Bob Doe")
assert User.get(user.id).name == "Bob Doe"
user = User.get(user.id)
user.delete()
assert User.get(user.id) is None
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
    # Add data
with db.session() as session:
for i in range(30):
user = User(name=f"User {i}")
session.add(user)
session.commit()
users_page_1 = User.query.paginate_query(page=1, per_page=10)
assert len(users_page_1.items) == 10
assert users_page_1.total_pages == 3
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
with pytest.raises(ValueError):
User.query.paginate_query(page=0, per_page=10)
with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1)

View File

@@ -0,0 +1,173 @@
from typing import Dict, Type
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, Integer, String
import pytest
from dbgpt.core.interface.storage import (
StorageItem,
ResourceIdentifier,
StorageItemAdapter,
QuerySpec,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.util.serialization.json_serialization import JsonSerializer
Base = declarative_base()
class MockModel(Base):
"""The SQLAlchemy model for the mock data."""
__tablename__ = "mock_data"
id = Column(Integer, primary_key=True)
data = Column(String)
class MockStorageItem(StorageItem):
"""The mock storage item."""
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, MockStorageItem):
raise ValueError("other must be a MockStorageItem")
self.data = other.data
def __init__(self, identifier: ResourceIdentifier, data: str):
self._identifier = identifier
self.data = data
@property
def identifier(self) -> ResourceIdentifier:
return self._identifier
def to_dict(self) -> Dict:
return {"identifier": self._identifier, "data": self.data}
def serialize(self) -> bytes:
return str(self.data).encode()
class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
"""The adapter for the mock storage item."""
def to_storage_format(self, item: MockStorageItem) -> MockModel:
return MockModel(id=int(item.identifier.str_identifier), data=item.data)
def from_storage_format(self, model: MockModel) -> MockStorageItem:
return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data)
def get_query_for_identifier(
self,
storage_format: Type[MockModel],
resource_id: ResourceIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise ValueError("session is required for this adapter")
return session.query(storage_format).filter(
storage_format.id == int(resource_id.str_identifier)
)
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
@pytest.fixture
def sqlalchemy_storage(db_url, serializer):
adapter = MockStorageItemAdapter()
storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base)
Base.metadata.create_all(storage.db_manager.engine)
return storage
def test_save_and_load(sqlalchemy_storage):
item = MockStorageItem(MockResourceIdentifier("1"), "test_data")
sqlalchemy_storage.save(item)
loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem)
assert loaded_item.data == "test_data"
def test_delete(sqlalchemy_storage):
resource_id = MockResourceIdentifier("1")
sqlalchemy_storage.delete(resource_id)
# Make sure the item is deleted
assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None
def test_query_with_various_conditions(sqlalchemy_storage):
# Add multiple items for testing
for i in range(5):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
# Test query with single condition
query_spec = QuerySpec(conditions={"data": "test_data_2"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
assert results[0].data == "test_data_2"
# Test not existing condition
query_spec = QuerySpec(conditions={"data": "nonexistent"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 0
# Test query with multiple conditions
query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
def test_query_nonexistent_item(sqlalchemy_storage):
query_spec = QuerySpec(conditions={"data": "nonexistent"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 0
def test_count_items(sqlalchemy_storage):
for i in range(5):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
# Test count without conditions
query_spec = QuerySpec(conditions={})
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
assert total_count == 5
# Test count with conditions
query_spec = QuerySpec(conditions={"data": "test_data_2"})
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
assert total_count == 1
def test_paginate_query(sqlalchemy_storage):
for i in range(10):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
page_size = 3
page_number = 2
query_spec = QuerySpec(conditions={})
page_result = sqlalchemy_storage.paginate_query(
page_number, page_size, MockStorageItem, query_spec
)
assert len(page_result.items) == page_size
assert page_result.page == page_number
assert page_result.total_pages == 4
assert page_result.total_count == 10

View File

@@ -0,0 +1,219 @@
from typing import Optional
import os
import logging
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session, DeclarativeMeta
from alembic import command
from alembic.util.exc import CommandError
from alembic.config import Config as AlembicConfig
logger = logging.getLogger(__name__)
def create_alembic_config(
alembic_root_path: str,
engine: Engine,
base: DeclarativeMeta,
session: Session,
alembic_ini_path: Optional[str] = None,
script_location: Optional[str] = None,
) -> AlembicConfig:
"""Create alembic config.
Args:
alembic_root_path: alembic root path
engine: sqlalchemy engine
base: sqlalchemy base
session: sqlalchemy session
alembic_ini_path (Optional[str]): alembic ini path
script_location (Optional[str]): alembic script location
Returns:
alembic config
"""
alembic_ini_path = alembic_ini_path or os.path.join(
alembic_root_path, "alembic.ini"
)
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
script_location = script_location or os.path.join(alembic_root_path, "alembic")
versions_dir = os.path.join(script_location, "versions")
os.makedirs(script_location, exist_ok=True)
os.makedirs(versions_dir, exist_ok=True)
alembic_cfg.set_main_option("script_location", script_location)
alembic_cfg.attributes["target_metadata"] = base.metadata
alembic_cfg.attributes["session"] = session
return alembic_cfg
def create_migration_script(
alembic_cfg: AlembicConfig, engine: Engine, message: str = "New migration"
) -> None:
"""Create migration script.
Args:
alembic_cfg: alembic config
engine: sqlalchemy engine
message: migration message
"""
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
command.revision(alembic_cfg, message, autogenerate=True)
def upgrade_database(
alembic_cfg: AlembicConfig, engine: Engine, target_version: str = "head"
) -> None:
"""Upgrade database to target version.
Args:
alembic_cfg: alembic config
engine: sqlalchemy engine
target_version: target version, default is head(latest version)
"""
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
# Will create tables if not exists
command.upgrade(alembic_cfg, target_version)
def downgrade_database(
alembic_cfg: AlembicConfig, engine: Engine, revision: str = "-1"
):
"""Downgrade the database by one revision.
Args:
alembic_cfg: Alembic configuration object.
engine: SQLAlchemy engine instance.
revision: Revision identifier, default is "-1" which means one revision back.
"""
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
command.downgrade(alembic_cfg, revision)
def clean_alembic_migration(alembic_cfg: AlembicConfig, engine: Engine) -> None:
"""Clean Alembic migration scripts and history.
Args:
alembic_cfg: Alembic config object
engine: SQLAlchemy engine instance
"""
import shutil
# Get migration script location
script_location = alembic_cfg.get_main_option("script_location")
print(f"Delete migration script location: {script_location}")
# Delete all migration script files
for file in os.listdir(script_location):
if file.startswith("versions"):
filepath = os.path.join(script_location, file)
print(f"Delete migration script file: {filepath}")
if os.path.isfile(filepath):
os.remove(filepath)
else:
shutil.rmtree(filepath, ignore_errors=True)
# Delete Alembic version table if exists
version_table = alembic_cfg.get_main_option("version_table") or "alembic_version"
if version_table:
with engine.connect() as connection:
print(f"Delete Alembic version table: {version_table}")
connection.execute(text(f"DROP TABLE IF EXISTS {version_table}"))
print("Cleaned Alembic migration scripts and history")
_MIGRATION_SOLUTION = """
**Solution 1:**
Run the following command to upgrade the database.
```commandline
dbgpt db migration upgrade
```
**Solution 2:**
Run the following command to clean the migration script and migration history.
```commandline
dbgpt db migration clean -y
```
**Solution 3:**
If you have already run the above commands but the error persists, you can try the
following command to clean the migration scripts, migration history, and your data.
Warning: this command will delete all your data! Please use it with caution.
```commandline
dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables
```
or
```commandline
rm -rf pilot/meta_data/alembic/versions/*
rm -rf pilot/meta_data/alembic/dbgpt.db
```
"""
def _ddl_init_and_upgrade(
default_meta_data_path: str,
disable_alembic_upgrade: bool,
alembic_ini_path: Optional[str] = None,
script_location: Optional[str] = None,
):
"""Initialize and upgrade database metadata
Args:
default_meta_data_path (str): default meta data path
        disable_alembic_upgrade (bool): If True, skip initializing and upgrading database metadata with alembic
alembic_ini_path (Optional[str]): alembic ini path
script_location (Optional[str]): alembic script location
"""
if disable_alembic_upgrade:
logger.info(
"disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
)
return
    else:
        warn_msg = (
            "Initializing and upgrading database metadata with alembic. "
            "Run this only in a development environment; if you deploy in a production "
            "environment, start the webserver with --disable_alembic_upgrade "
            "(`python dbgpt/app/dbgpt_server.py --disable_alembic_upgrade`).\n"
            "We suggest using `dbgpt db migration` to initialize and upgrade database "
            "metadata with alembic; run `dbgpt db migration --help` for more information."
        )
        logger.warning(warn_msg)
from dbgpt.storage.metadata.db_manager import db
alembic_cfg = create_alembic_config(
default_meta_data_path,
db.engine,
db.Model,
db.session(),
alembic_ini_path,
script_location,
)
try:
create_migration_script(alembic_cfg, db.engine)
upgrade_database(alembic_cfg, db.engine)
except CommandError as e:
if "Target database is not up to date" in str(e):
logger.error(
f"Initialize and upgrade database metadata with alembic failed, error detail: {str(e)} "
f"you can try the following solutions:\n{_MIGRATION_SOLUTION}\n"
)
raise Exception(
"Initialize and upgrade database metadata with alembic failed, "
"you can see the error and solutions above"
) from e
else:
raise e
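
Putting these helpers together, the intended flow mirrors what `_ddl_init_and_upgrade` does above; a minimal sketch, assuming the global `db` manager has already been initialized and using an illustrative root path:

```python
from dbgpt.storage.metadata.db_manager import db

alembic_cfg = create_alembic_config(
    "/tmp/dbgpt_meta_data",  # hypothetical alembic root path
    db.engine,
    db.Model,
    db.session(),
)
create_migration_script(alembic_cfg, db.engine, message="add user table")
upgrade_database(alembic_cfg, db.engine)  # apply all revisions up to head
# downgrade_database(alembic_cfg, db.engine, "-1")  # roll back one revision if needed
```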

View File

@@ -39,6 +39,31 @@ def PublicAPI(*args, **kwargs):
    return decorator
def DeveloperAPI(*args, **kwargs):
"""Decorator to mark a function or class as a developer API.
    Developer APIs are low-level APIs for advanced users and may change across major versions.
Examples:
>>> from dbgpt.util.annotations import DeveloperAPI
>>> @DeveloperAPI
... def foo():
... pass
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return DeveloperAPI()(args[0])
def decorator(obj):
_modify_docstring(
obj,
"**DeveloperAPI:** This API is for advanced users and may change cross major versions.",
)
return obj
return decorator
def _modify_docstring(obj, message: str = None):
    if not message:
        return

View File

@@ -0,0 +1,14 @@
from typing import TypeVar, Generic, List
from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
class PaginationResult(BaseModel, Generic[T]):
"""Pagination result"""
items: List[T] = Field(..., description="The items in the current page")
total_count: int = Field(..., description="Total number of items")
total_pages: int = Field(..., description="total number of pages")
page: int = Field(..., description="Current page number")
page_size: int = Field(..., description="Number of items per page")
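
As a quick illustration, page 2 of ten items with a page size of 3 might be assembled like this (the values are illustrative):

```python
import math

page = PaginationResult(
    items=["item_3", "item_4", "item_5"],
    total_count=10,
    total_pages=math.ceil(10 / 3),  # 4
    page=2,
    page_size=3,
)
```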

View File

@@ -41,4 +41,6 @@ class JsonSerializer(Serializer):
        # Convert bytes back to JSON and then to the specified class
        json_data = json.loads(data.decode(JSON_ENCODING))
        # Assume that the cls has an __init__ that accepts a dictionary
-        return cls(**json_data)
+        obj = cls(**json_data)
+        obj.set_serializer(self)
+        return obj

View File

@@ -73,9 +73,11 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
    return match_map


-if __name__ == "__main__":
-    s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
-    s1 = "123"
-    s2 = "456"
-    print(extract_content_open_ending(s, s1, s2, True))
+def _to_str(x, charset="utf8", errors="strict"):
+    if x is None or isinstance(x, str):
+        return x
+    if isinstance(x, bytes):
+        return x.decode(charset, errors)
+    return str(x)

View File

@@ -71,3 +71,66 @@ Download and install `Microsoft C++ Build Tools` from [visual-cpp-build-tools](h
1. update your mysql username and password in docker/examples/metadata/duckdb2mysql.py
2. python docker/examples/metadata/duckdb2mysql.py
```
##### Q8: `How to manage and migrate my database`
You can use the `dbgpt db migration` command to manage and migrate your database.
See the following command for details.
```commandline
dbgpt db migration --help
```
First, you need to create a migration script (just once, unless you clean it).
This command will create an `alembic` directory in your `pilot/meta_data` directory and an initial migration script in it.
```commandline
dbgpt db migration init
```
Then you can upgrade your database with the following command.
```commandline
dbgpt db migration upgrade
```
Every time you change your models or pull the latest code from the DB-GPT repository, you need to create a new migration script.
```commandline
dbgpt db migration migrate -m "your message"
```
Then you can upgrade your database with the following command.
```commandline
dbgpt db migration upgrade
```
##### Q9: `alembic.util.exc.CommandError: Target database is not up to date.`
**Solution 1:**
Run the following command to upgrade the database.
```commandline
dbgpt db migration upgrade
```
**Solution 2:**
Run the following command to clean the migration script and migration history.
```commandline
dbgpt db migration clean -y
```
**Solution 3:**
If you have already run the above commands but the error persists, you can try the
following command to clean the migration scripts, migration history, and your data.
Warning: this command will delete all your data! Please use it with caution.
```commandline
dbgpt db migration clean --drop_all_tables -y --confirm_drop_all_tables
```
or
```commandline
rm -rf pilot/meta_data/alembic/versions/*
rm -rf pilot/meta_data/alembic/dbgpt.db
```

View File

@@ -97,6 +97,8 @@ Configure the proxy and modify LLM_MODEL, PROXY_API_URL and API_KEY in the `.env
LLM_MODEL=chatgpt_proxyllm
PROXY_API_KEY={your-openai-sk}
PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
# If you use gpt-4
# PROXYLLM_BACKEND=gpt-4
```
</TabItem>
<TabItem value="qwen" label="通义千问">

View File

@@ -3,7 +3,7 @@ from sqlalchemy import pool
from alembic import context
-from dbgpt.storage.metadata.meta_data import Base, engine
+from dbgpt.storage.metadata.db_manager import db

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -13,8 +13,7 @@ config = context.config
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
-target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
@@ -34,6 +33,8 @@ def run_migrations_offline() -> None:
    script output.
    """
+    engine = db.engine
+    target_metadata = db.metadata
    url = config.get_main_option(engine.url)
    context.configure(
        url=url,
@@ -53,12 +54,8 @@ def run_migrations_online() -> None:
    and associate a connection with the context.
    """
-    connectable = engine_from_config(
-        config.get_section(config.config_ini_section, {}),
-        prefix="sqlalchemy.",
-        poolclass=pool.NullPool,
-    )
+    engine = db.engine
+    target_metadata = db.metadata
    with engine.connect() as connection:
        if engine.dialect.name == "sqlite":
            context.configure(